Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[webgpu] Optimize MatMulNBits f16 prefill shader for subgroup size 32 #23773

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

daijh
Copy link
Contributor

@daijh daijh commented Feb 21, 2025

This commit optimizes the MatMulNBits f16 prefill shader for devices with a subgroup size of 32.

Testing on Lunar Lake shows a ~5x improvement in prompt processing performance, increasing from 14.02 tps to 69.40 tps.

Before:
model_benchmark.exe -l 1000 -i Phi-3.5-mini-instruct-onnx-web

Batch size: 1, prompt tokens: 1001, tokens to generate: 128 Prompt processing (time to first token):
avg (us): 7.13811e+07
avg (tokens/s): 14.0233
p50 (us): 7.13158e+07
stddev (us): 120674
n: 5 * 1001 token(s)

After:
model_benchmark.exe -l 1000 -i Phi-3.5-mini-instruct-onnx-web

Batch size: 1, prompt tokens: 1001, tokens to generate: 128 Prompt processing (time to first token):
avg (us): 1.44234e+07
avg (tokens/s): 69.4009
p50 (us): 1.44293e+07
stddev (us): 60263.9
n: 5 * 1001 token(s)

See above.

Description

Motivation and Context

This commit optimizes the MatMulNBits f16 prefill shader for devices
with a subgroup size of 32.

Testing on Lunar Lake shows a ~5x improvement in prompt processing
performance, increasing from 14.02 tps to 69.40 tps.

Before:
model_benchmark.exe -l 1000 -i Phi-3.5-mini-instruct-onnx-web

Batch size: 1, prompt tokens: 1001, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       7.13811e+07
        avg (tokens/s): 14.0233
        p50 (us):       7.13158e+07
        stddev (us):    120674
        n:              5 * 1001 token(s)

After:
model_benchmark.exe -l 1000 -i Phi-3.5-mini-instruct-onnx-web

Batch size: 1, prompt tokens: 1001, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       1.44234e+07
        avg (tokens/s): 69.4009
        p50 (us):       1.44293e+07
        stddev (us):    60263.9
        n:              5 * 1001 token(s)

See above.
@daijh
Copy link
Contributor Author

daijh commented Feb 21, 2025

prefill shader before:

enable f16;
enable subgroups_f16;
enable subgroups;
const workgroup_size_x: u32 = 8;
const workgroup_size_y: u32 = 8;
const workgroup_size_z: u32 = 1;
@group(0) @binding(0) var<storage, read> input_a: array<vec4<f16>>;
@group(0) @binding(1) var<storage, read> input_b: array<vec4<u32>>;
@group(0) @binding(2) var<storage, read> scales: array<f16>;
@group(0) @binding(3) var<storage, read_write> output: array<f16>;
struct Uniforms {
  input_a_shape: vec3<u32>,
  input_a_stride: vec2<u32>,
  input_b_shape: vec3<u32>,
  input_b_stride: vec2<u32>,
  output_shape: vec3<u32>,
  output_stride: vec2<u32>,
  block_size: u32
};
@group(0) @binding(4) var<uniform> uniforms: Uniforms;

alias input_a_value_t = vec4<f16>;
alias input_a_indices_t = vec3<u32>;
fn i2o_input_a(indices : input_a_indices_t)->u32 {
  return indices[0] * uniforms.input_a_stride[0] + indices[1] * uniforms.input_a_stride[1] + indices[2];
}
fn get_input_a_by_indices(indices: input_a_indices_t)->input_a_value_t {
  return input_a[i2o_input_a(indices)];
}
alias input_b_value_t = vec4<u32>;
alias input_b_indices_t = vec3<u32>;
fn i2o_input_b(indices : input_b_indices_t)->u32 {
  return indices[0] * uniforms.input_b_stride[0] + indices[1] * uniforms.input_b_stride[1] + indices[2];
}
fn get_input_b_by_indices(indices: input_b_indices_t)->input_b_value_t {
  return input_b[i2o_input_b(indices)];
}
alias output_value_t = f16;
alias output_indices_t = vec3<u32>;
alias output_element_t = f16;
fn i2o_output(indices : output_indices_t)->u32 {
  return indices[0] * uniforms.output_stride[0] + indices[1] * uniforms.output_stride[1] + indices[2];
}
fn set_output_by_indices(indices: output_indices_t, value: output_value_t) {
  output[i2o_output(indices)]=value;
}

fn mm_readA(batch : u32, row : u32, col : u32) -> input_a_value_t {
  if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {
    return get_input_a_by_indices(input_a_indices_t(batch, row, col));
  } else {
    return input_a_value_t(0);
  }
}
var<workgroup> sub_b: array<array<input_b_value_t, 8>, 8>;
var<workgroup> sub_scale: array<array<output_value_t, 8>, 8>;
var<workgroup> inter_results: array<array<array<output_value_t, 8>, 8>,4>;
@compute @workgroup_size(workgroup_size_x, workgroup_size_y, workgroup_size_z)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>,
        @builtin(workgroup_id) workgroup_id : vec3<u32>,
        @builtin(local_invocation_index) local_idx : u32,
        @builtin(local_invocation_id) local_id : vec3<u32>,
 @builtin(subgroup_invocation_id) sg_id : u32,
 @builtin(subgroup_size) sg_size : u32,
        @builtin(num_workgroups) num_workgroups : vec3<u32>) {
  let workgroup_idx = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x;
  let global_idx = workgroup_idx * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;
  let col = workgroup_id.x * 8;
  let row = workgroup_id.y * 4;
  let batch = workgroup_id.z;
  let n_blocks_per_col = uniforms.input_b_shape[1];
  let num_tiles = (n_blocks_per_col - 1) / 8 + 1;
  for (var tile: u32 = 0; tile < num_tiles; tile += 1) {
    // load one tile B/scale data into shared memory.
    let b_col = col + local_id.y;
    let block = tile * 8 + local_id.x;
    if (b_col < uniforms.input_b_shape[0] && block < n_blocks_per_col) {
      sub_b[local_id.y][local_id.x] = get_input_b_by_indices(input_b_indices_t(b_col, block, 0));
      sub_scale[local_id.y][local_id.x] = scales[b_col * n_blocks_per_col + block];
    } else {
      sub_b[local_id.y][local_id.x] = input_b_value_t(0);
      sub_scale[local_id.y][local_id.x] = output_value_t(0);
    }
    workgroupBarrier();
    var in_y = (local_idx % 32) / 4;
    var in_x = (local_idx / 32) * 4 + local_idx % 4;
    var word_offset = (local_idx % 4) * 8;
    if (sg_size == 8u) {
      in_y = local_idx % 8;
      in_x = local_idx / 8;
      word_offset = 0u;
    } else if (sg_size == 16u) {
      in_y = (local_idx % 16) / 2;
      in_x = (local_idx / 16) * 2 + local_idx % 2;
      word_offset = (local_idx % 2) * 8;
    } else if (sg_size == 32u) {
      in_y = (local_idx % 32) / 4;
      in_x = (local_idx / 32) * 4 + local_idx % 4;
      word_offset = (local_idx % 4) * 8;
    } else if (sg_size == 64u) {
      in_y = local_idx / 8;
      in_x = local_idx % 8;
      word_offset = (local_idx % 8) * 8;
    }
    let zero_point = output_element_t(8.0);
    let scale = sub_scale[in_y][in_x];
    let b_data = sub_b[in_y][in_x];
    let a_col_start = tile * 64;
    let a_data0 = mm_readA(batch, row + 0, a_col_start + local_idx);
    let a_data1 = mm_readA(batch, row + 1, a_col_start + local_idx);
    let a_data2 = mm_readA(batch, row + 2, a_col_start + local_idx);
    let a_data3 = mm_readA(batch, row + 3, a_col_start + local_idx);
    if (sg_size == 8u) {
      for (var i: u32 = 0; i < 4; i++) {
        let b_value = b_data[i];
        let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
        let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
        let b_quantized_values = mat2x4<output_element_t>(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));
        let b_dequantized_values = (b_quantized_values - mat2x4<output_element_t>(zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point)) * scale;
        var a0 = subgroupShuffle(a_data0, i * 2);
        var a1 = subgroupShuffle(a_data0, i * 2 + 1);
        inter_results[0][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);
        a0 = subgroupShuffle(a_data1, i * 2);
        a1 = subgroupShuffle(a_data1, i * 2 + 1);
        inter_results[1][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);
        a0 = subgroupShuffle(a_data2, i * 2);
        a1 = subgroupShuffle(a_data2, i * 2 + 1);
        inter_results[2][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);
        a0 = subgroupShuffle(a_data3, i * 2);
        a1 = subgroupShuffle(a_data3, i * 2 + 1);
        inter_results[3][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);
      }
    } else if (sg_size == 16u) {
      for (var i: u32 = 0; i < 4; i++) {
        let b_value = b_data[i];
        let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
        let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
        let b_quantized_values = mat2x4<output_element_t>(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));
        let b_dequantized_values = (b_quantized_values - mat2x4<output_element_t>(zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point)) * scale;
        var a0 = subgroupShuffle(a_data0, i * 2);
        var a00 = subgroupShuffle(a_data0, i * 2 + 8);
        var a1 = subgroupShuffle(a_data0, i * 2 + 1);
        var a11 = subgroupShuffle(a_data0, i * 2 + 9);
        inter_results[0][in_y][in_x] += dot(select(a00, a0, local_idx % 2 == 0), b_dequantized_values[0]) + dot(select(a11, a1, local_idx % 2 == 0), b_dequantized_values[1]);
        a0 = subgroupShuffle(a_data1, i * 2);
        a00 = subgroupShuffle(a_data1, i * 2 + 8);
        a1 = subgroupShuffle(a_data1, i * 2 + 1);
        a11 = subgroupShuffle(a_data1, i * 2 + 9);
        inter_results[1][in_y][in_x] += dot(select(a00, a0, local_idx % 2 == 0), b_dequantized_values[0]) + dot(select(a11, a1, local_idx % 2 == 0), b_dequantized_values[1]);
        a0 = subgroupShuffle(a_data2, i * 2);
        a00 = subgroupShuffle(a_data2, i * 2 + 8);
        a1 = subgroupShuffle(a_data2, i * 2 + 1);
        a11 = subgroupShuffle(a_data2, i * 2 + 9);
        inter_results[2][in_y][in_x] += dot(select(a00, a0, local_idx % 2 == 0), b_dequantized_values[0]) + dot(select(a11, a1, local_idx % 2 == 0), b_dequantized_values[1]);
        a0 = subgroupShuffle(a_data3, i * 2);
        a00 = subgroupShuffle(a_data3, i * 2 + 8);
        a1 = subgroupShuffle(a_data3, i * 2 + 1);
        a11 = subgroupShuffle(a_data3, i * 2 + 9);
        inter_results[3][in_y][in_x] += dot(select(a00, a0, local_idx % 2 == 0), b_dequantized_values[0]) + dot(select(a11, a1, local_idx % 2 == 0), b_dequantized_values[1]);
        word_offset += 2;
      }
    } else {
      for (var i: u32 = 0; i < 4; i++) {
        let b_value = b_data[i];
        let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
        let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
        let b_quantized_values = mat2x4<output_element_t>(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));
        let b_dequantized_values = (b_quantized_values - mat2x4<output_element_t>(zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point)) * scale;
        var a0 = subgroupShuffle(a_data0, word_offset);
        var a1 = subgroupShuffle(a_data0, word_offset + 1);
        inter_results[0][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);
        a0 = subgroupShuffle(a_data1, word_offset);
        a1 = subgroupShuffle(a_data1, word_offset + 1);
        inter_results[1][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);
        a0 = subgroupShuffle(a_data2, word_offset);
        a1 = subgroupShuffle(a_data2, word_offset + 1);
        inter_results[2][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);
        a0 = subgroupShuffle(a_data3, word_offset);
        a1 = subgroupShuffle(a_data3, word_offset + 1);
        inter_results[3][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);
        word_offset += 2;
      }
    }
    workgroupBarrier();
  }
  if (local_idx < 32) {
    let inner_row = local_idx / 8;
    let inner_col = local_idx % 8;
    var output_value = output_value_t(0);
    for (var b = 0u; b < 8; b++) {
      output_value += inter_results[inner_row][inner_col][b];
    }
    if (row + inner_row < uniforms.output_shape[1] && col + inner_col < uniforms.output_shape[2]) {
      set_output_by_indices(output_indices_t(batch, row + inner_row, col + inner_col), output_value);;
    }
  }

}

@daijh
Copy link
Contributor Author

daijh commented Feb 21, 2025

prefill shader after:


enable f16;
enable subgroups_f16;
enable subgroups;
const workgroup_size_x: u32 = 8;
const workgroup_size_y: u32 = 8;
const workgroup_size_z: u32 = 1;
@group(0) @binding(0) var<storage, read> input_a: array<vec4<f16>>;
@group(0) @binding(1) var<storage, read> input_b: array<vec4<u32>>;
@group(0) @binding(2) var<storage, read> scales: array<f16>;
@group(0) @binding(3) var<storage, read_write> output: array<f16>;
struct Uniforms {
  input_a_shape: vec3<u32>,
  input_a_stride: vec2<u32>,
  input_b_shape: vec3<u32>,
  input_b_stride: vec2<u32>,
  output_shape: vec3<u32>,
  output_stride: vec2<u32>,
  block_size: u32
};
@group(0) @binding(4) var<uniform> uniforms: Uniforms;

alias input_a_value_t = vec4<f16>;
alias input_a_indices_t = vec3<u32>;
fn i2o_input_a(indices : input_a_indices_t)->u32 {
  return indices[0] * uniforms.input_a_stride[0] + indices[1] * uniforms.input_a_stride[1] + indices[2];
}
fn get_input_a_by_indices(indices: input_a_indices_t)->input_a_value_t {
  return input_a[i2o_input_a(indices)];
}
alias input_b_value_t = vec4<u32>;
alias input_b_indices_t = vec3<u32>;
fn i2o_input_b(indices : input_b_indices_t)->u32 {
  return indices[0] * uniforms.input_b_stride[0] + indices[1] * uniforms.input_b_stride[1] + indices[2];
}
fn get_input_b_by_indices(indices: input_b_indices_t)->input_b_value_t {
  return input_b[i2o_input_b(indices)];
}
alias output_value_t = f16;
alias output_indices_t = vec3<u32>;
alias output_element_t = f16;
fn i2o_output(indices : output_indices_t)->u32 {
  return indices[0] * uniforms.output_stride[0] + indices[1] * uniforms.output_stride[1] + indices[2];
}
fn set_output_by_indices(indices: output_indices_t, value: output_value_t) {
  output[i2o_output(indices)]=value;
}

fn mm_readA(batch : u32, row : u32, col : u32) -> input_a_value_t {
  if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {
    return get_input_a_by_indices(input_a_indices_t(batch, row, col));
  } else {
    return input_a_value_t(0);
  }
}

var<workgroup> sub_b: array<array<input_b_value_t, 8>, 8>;
var<workgroup> sub_scale: array<array<output_value_t, 8>, 8>;
var<workgroup> inter_results: array<array<array<output_value_t, 8>, 8>,4>;

@compute @workgroup_size(workgroup_size_x, workgroup_size_y, workgroup_size_z)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>,
        @builtin(workgroup_id) workgroup_id : vec3<u32>,
        @builtin(local_invocation_index) local_idx : u32,
        @builtin(local_invocation_id) local_id : vec3<u32>,
 @builtin(subgroup_invocation_id) sg_id : u32,
 @builtin(subgroup_size) sg_size : u32,
        @builtin(num_workgroups) num_workgroups : vec3<u32>) {
  let workgroup_idx = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x;
  let global_idx = workgroup_idx * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;
  let col = workgroup_id.x * 8;
  let row = workgroup_id.y * 4;
  let batch = workgroup_id.z;
  let n_blocks_per_col = uniforms.input_b_shape[1];
  let num_tiles =  (n_blocks_per_col - 1) / 8 + 1;

  for (var tile = 0u; tile < num_tiles; tile++) {
    // load one tile B/scale data into shared memory.
    let b_col = col + local_id.y;
    let block = tile * 8 + local_id.x;
    if (b_col < uniforms.input_b_shape[0] && block < n_blocks_per_col) {
      sub_b[local_id.y][local_id.x] = get_input_b_by_indices(input_b_indices_t(b_col, block, 0));
      sub_scale[local_id.y][local_id.x] = scales[b_col * n_blocks_per_col + block];
    } else {
      sub_b[local_id.y][local_id.x] = input_b_value_t(0);
      sub_scale[local_id.y][local_id.x] = output_value_t(0);
    }
    workgroupBarrier();

    var in_y = 0u;
    var in_x = 0u;
    if (sg_size == 8u) {
      in_y = local_idx % 8;
      in_x = local_idx / 8;
    } else if (sg_size == 16u) {
      in_y = (local_idx % 16) / 2;
      in_x = (local_idx / 16) * 2 + local_idx % 2;
    } else if (sg_size == 32u) {
      in_y = (local_idx % 32) / 4;
      in_x = (local_idx / 32) * 4 + local_idx % 4;
    } else if (sg_size == 64u) {
      in_y = local_idx / 8;
      in_x = local_idx % 8;
    }

    let zero_point = output_element_t(8.0);
    let scale = sub_scale[in_y][in_x];
    let b_data = sub_b[in_y][in_x];
    let a_col_start = tile * 64;
    let a_data0 = mm_readA(batch, row + 0, a_col_start + local_idx);
    let a_data1 = mm_readA(batch, row + 1, a_col_start + local_idx);
    let a_data2 = mm_readA(batch, row + 2, a_col_start + local_idx);
    let a_data3 = mm_readA(batch, row + 3, a_col_start + local_idx);

    if (sg_size == 8u) {
      for (var i = 0u; i < 4u; i++) {
        let b_value = b_data[i];
        let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
        let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
        let b_quantized_values = mat2x4<output_element_t>(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));
        let b_dequantized_values = (b_quantized_values - mat2x4<output_element_t>(zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point)) * scale;

        var a0 = subgroupShuffle(a_data0, i * 2);
        var a1 = subgroupShuffle(a_data0, i * 2 + 1);
        inter_results[0][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);

        a0 = subgroupShuffle(a_data1, i * 2);
        a1 = subgroupShuffle(a_data1, i * 2 + 1);
        inter_results[1][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);

        a0 = subgroupShuffle(a_data2, i * 2);
        a1 = subgroupShuffle(a_data2, i * 2 + 1);
        inter_results[2][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);

        a0 = subgroupShuffle(a_data3, i * 2);
        a1 = subgroupShuffle(a_data3, i * 2 + 1);
        inter_results[3][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);
      }
    } else if (sg_size == 16u) {
      for (var i = 0u; i < 4u; i++) {
        let b_value = b_data[i];
        let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
        let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
        let b_quantized_values = mat2x4<output_element_t>(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));
        let b_dequantized_values = (b_quantized_values - mat2x4<output_element_t>(zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point)) * scale;

        var a0_0 = subgroupShuffle(a_data0, i * 2);
        var a0_1 = subgroupShuffle(a_data0, i * 2 + 8);
        var a1_0 = subgroupShuffle(a_data0, i * 2 + 1);
        var a1_1 = subgroupShuffle(a_data0, i * 2 + 9);
        inter_results[0][in_y][in_x] += dot(select(a0_1, a0_0, (in_x & 1u) == 0u), b_dequantized_values[0]) + dot(select(a1_1, a1_0, (in_x & 1u) == 0u), b_dequantized_values[1]);

        a0_0 = subgroupShuffle(a_data1, i * 2);
        a0_1 = subgroupShuffle(a_data1, i * 2 + 8);
        a1_0 = subgroupShuffle(a_data1, i * 2 + 1);
        a1_1 = subgroupShuffle(a_data1, i * 2 + 9);
        inter_results[1][in_y][in_x] += dot(select(a0_1, a0_0, (in_x & 1u) == 0u), b_dequantized_values[0]) + dot(select(a1_1, a1_0, (in_x & 1u) == 0u), b_dequantized_values[1]);

        a0_0 = subgroupShuffle(a_data2, i * 2);
        a0_1 = subgroupShuffle(a_data2, i * 2 + 8);
        a1_0 = subgroupShuffle(a_data2, i * 2 + 1);
        a1_1 = subgroupShuffle(a_data2, i * 2 + 9);
        inter_results[2][in_y][in_x] += dot(select(a0_1, a0_0, (in_x & 1u) == 0u), b_dequantized_values[0]) + dot(select(a1_1, a1_0, (in_x & 1u) == 0u), b_dequantized_values[1]);

        a0_0 = subgroupShuffle(a_data3, i * 2);
        a0_1 = subgroupShuffle(a_data3, i * 2 + 8);
        a1_0 = subgroupShuffle(a_data3, i * 2 + 1);
        a1_1 = subgroupShuffle(a_data3, i * 2 + 9);
        inter_results[3][in_y][in_x] += dot(select(a0_1, a0_0, (in_x & 1u) == 0u), b_dequantized_values[0]) + dot(select(a1_1, a1_0, (in_x & 1u) == 0u), b_dequantized_values[1]);
      }
    } else if (sg_size == 32u) {
      for (var i = 0u; i < 4u; i++) {
        let b_value = b_data[i];
        let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
        let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
        let b_quantized_values = mat2x4<output_element_t>(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));
        let b_dequantized_values = (b_quantized_values - mat2x4<output_element_t>(zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point)) * scale;

        let b_col_offset = in_x & 3u;

        var a0_0 = subgroupShuffle(a_data0, i * 2);
        var a0_1 = subgroupShuffle(a_data0, i * 2 + 8);
        var a0_2 = subgroupShuffle(a_data0, i * 2 + 16);
        var a0_3 = subgroupShuffle(a_data0, i * 2 + 24);
        var a1_0 = subgroupShuffle(a_data0, i * 2 + 1);
        var a1_1 = subgroupShuffle(a_data0, i * 2 + 9);
        var a1_2 = subgroupShuffle(a_data0, i * 2 + 17);
        var a1_3 = subgroupShuffle(a_data0, i * 2 + 25);
        if (b_col_offset == 0u) {
          inter_results[0][in_y][in_x] += dot(a0_0, b_dequantized_values[0]) + dot(a1_0, b_dequantized_values[1]);
        } else if (b_col_offset == 1u) {
          inter_results[0][in_y][in_x] += dot(a0_1, b_dequantized_values[0]) + dot(a1_1, b_dequantized_values[1]);
        } else if (b_col_offset == 2u) {
          inter_results[0][in_y][in_x] += dot(a0_2, b_dequantized_values[0]) + dot(a1_2, b_dequantized_values[1]);
        } else {
          inter_results[0][in_y][in_x] += dot(a0_3, b_dequantized_values[0]) + dot(a1_3, b_dequantized_values[1]);
        }

        a0_0 = subgroupShuffle(a_data1, i * 2);
        a0_1 = subgroupShuffle(a_data1, i * 2 + 8);
        a0_2 = subgroupShuffle(a_data1, i * 2 + 16);
        a0_3 = subgroupShuffle(a_data1, i * 2 + 24);
        a1_0 = subgroupShuffle(a_data1, i * 2 + 1);
        a1_1 = subgroupShuffle(a_data1, i * 2 + 9);
        a1_2 = subgroupShuffle(a_data1, i * 2 + 17);
        a1_3 = subgroupShuffle(a_data1, i * 2 + 25);
        if (b_col_offset == 0u) {
          inter_results[1][in_y][in_x] += dot(a0_0, b_dequantized_values[0]) + dot(a1_0, b_dequantized_values[1]);
        } else if (b_col_offset == 1u) {
          inter_results[1][in_y][in_x] += dot(a0_1, b_dequantized_values[0]) + dot(a1_1, b_dequantized_values[1]);
        } else if (b_col_offset == 2u) {
          inter_results[1][in_y][in_x] += dot(a0_2, b_dequantized_values[0]) + dot(a1_2, b_dequantized_values[1]);
        } else {
          inter_results[1][in_y][in_x] += dot(a0_3, b_dequantized_values[0]) + dot(a1_3, b_dequantized_values[1]);
        }

        a0_0 = subgroupShuffle(a_data2, i * 2);
        a0_1 = subgroupShuffle(a_data2, i * 2 + 8);
        a0_2 = subgroupShuffle(a_data2, i * 2 + 16);
        a0_3 = subgroupShuffle(a_data2, i * 2 + 24);
        a1_0 = subgroupShuffle(a_data2, i * 2 + 1);
        a1_1 = subgroupShuffle(a_data2, i * 2 + 9);
        a1_2 = subgroupShuffle(a_data2, i * 2 + 17);
        a1_3 = subgroupShuffle(a_data2, i * 2 + 25);
        if (b_col_offset == 0u) {
          inter_results[2][in_y][in_x] += dot(a0_0, b_dequantized_values[0]) + dot(a1_0, b_dequantized_values[1]);
        } else if (b_col_offset == 1u) {
          inter_results[2][in_y][in_x] += dot(a0_1, b_dequantized_values[0]) + dot(a1_1, b_dequantized_values[1]);
        } else if (b_col_offset == 2u) {
          inter_results[2][in_y][in_x] += dot(a0_2, b_dequantized_values[0]) + dot(a1_2, b_dequantized_values[1]);
        } else {
          inter_results[2][in_y][in_x] += dot(a0_3, b_dequantized_values[0]) + dot(a1_3, b_dequantized_values[1]);
        }

        a0_0 = subgroupShuffle(a_data3, i * 2);
        a0_1 = subgroupShuffle(a_data3, i * 2 + 8);
        a0_2 = subgroupShuffle(a_data3, i * 2 + 16);
        a0_3 = subgroupShuffle(a_data3, i * 2 + 24);
        a1_0 = subgroupShuffle(a_data3, i * 2 + 1);
        a1_1 = subgroupShuffle(a_data3, i * 2 + 9);
        a1_2 = subgroupShuffle(a_data3, i * 2 + 17);
        a1_3 = subgroupShuffle(a_data3, i * 2 + 25);
        if (b_col_offset == 0u) {
          inter_results[3][in_y][in_x] += dot(a0_0, b_dequantized_values[0]) + dot(a1_0, b_dequantized_values[1]);
        } else if (b_col_offset == 1u) {
          inter_results[3][in_y][in_x] += dot(a0_1, b_dequantized_values[0]) + dot(a1_1, b_dequantized_values[1]);
        } else if (b_col_offset == 2u) {
          inter_results[3][in_y][in_x] += dot(a0_2, b_dequantized_values[0]) + dot(a1_2, b_dequantized_values[1]);
        } else {
          inter_results[3][in_y][in_x] += dot(a0_3, b_dequantized_values[0]) + dot(a1_3, b_dequantized_values[1]);
        }
      }
    } else {
      var word_offset = (sg_id % (sg_size / 8u)) * 8u;
      for (var i = 0u; i < 4u; i++) {
        let b_value = b_data[i];
        let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
        let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
        let b_quantized_values = mat2x4<output_element_t>(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));
        let b_dequantized_values = (b_quantized_values - mat2x4<output_element_t>(zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point)) * scale;

        var a0 = subgroupShuffle(a_data0, word_offset);
        var a1 = subgroupShuffle(a_data0, word_offset + 1);
        inter_results[0][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);

        a0 = subgroupShuffle(a_data1, word_offset);
        a1 = subgroupShuffle(a_data1, word_offset + 1);
        inter_results[1][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);

        a0 = subgroupShuffle(a_data2, word_offset);
        a1 = subgroupShuffle(a_data2, word_offset + 1);
        inter_results[2][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);

        a0 = subgroupShuffle(a_data3, word_offset);
        a1 = subgroupShuffle(a_data3, word_offset + 1);
        inter_results[3][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);
        word_offset += 2u;
      }
    }
    workgroupBarrier();
  }

  if (local_idx < 32u) {
    let inner_row = local_idx / 8;
    let inner_col = local_idx % 8;
    var output_value = output_value_t(0);
    for (var b = 0u; b < 8u; b++) {
      output_value += inter_results[inner_row][inner_col][b];
    }
    if (row + inner_row < uniforms.output_shape[1] && col + inner_col < uniforms.output_shape[2]) {
      set_output_by_indices(output_indices_t(batch, row + inner_row, col + inner_col), output_value);;
    }
  }

}

@daijh
Copy link
Contributor Author

daijh commented Feb 21, 2025

@qjia7 @jchen10

@daijh
Copy link
Contributor Author

daijh commented Feb 21, 2025

@guschmue @fs-eire, please take a look.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant