diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index c79efee65e5c5..a44e6c581104a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -536,40 +536,42 @@ Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const shader.AddOutput("scales", ShaderUsage::UseUniform); shader.AdditionalImplementation() << R"ADDNL_FN( - var max_values : array; + var a_values : array; )ADDNL_FN"; shader.MainFunctionBody() << R"MAIN_FN( - var local_a = input_a[global_idx]; - var max_val = subgroupMax(abs(local_a)); - var max_temp = max(max_val.xy, max_val.zw); - var scale = max(max_temp[0], max_temp[1]); - if (local_idx % sg_size == 0) { - max_values[local_idx / sg_size] = scale; - } + a_values[local_idx] = input_a[global_idx]; workgroupBarrier(); - - if (sg_size == 8) - { - scale = max(max_values[0], max_values[1]); - scale = max(scale, max_values[2]); - scale = max(scale, max_values[3]); - } - else if (sg_size == 16) - { - scale = max(max_values[0], max_values[1]); - } - else - { - scale = max_values[0]; - } - - var norm_a = local_a/scale; - output[global_idx] = pack4x8snorm(vec4(norm_a)); - if (local_idx == 0) - { - // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f. - scales[workgroup_idx] = scale/127; + if (local_idx < 32) { + var max_val = input_a_value_t(0); + for (var i = 0; i < 32; i++) + { + max_val = max(max_val, abs(a_values[i])); + } + var max_temp = max(max_val.xy, max_val.zw); + var scale = max(max_temp[0], max_temp[1]); + var norm_a = a_values[local_idx]/scale; + output[global_idx] = pack4x8snorm(vec4(norm_a)); + if (local_idx == 0) + { + // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f. + scales[workgroup_idx].x = scale/127; + } + } else { + var max_val = input_a_value_t(0); + for (var i = 32; i < 64; i++) + { + max_val = max(max_val, abs(a_values[i])); + } + var max_temp = max(max_val.xy, max_val.zw); + var scale = max(max_temp[0], max_temp[1]); + var norm_a = a_values[local_idx]/scale; + output[global_idx] = pack4x8snorm(vec4(norm_a)); + if (local_idx == 32) + { + // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f. + scales[workgroup_idx].y = scale/127; + } } )MAIN_FN"; return Status::OK(); @@ -838,15 +840,15 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context constexpr uint32_t kBlockSizeA = 128; DP4AMatMulQuantizeProgram quantize_program; - quantize_program.SetWorkgroupSize(32); - quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1); + quantize_program.SetWorkgroupSize(64); + quantize_program.SetDispatchGroupSize(M * K / 256, 1, 1); TensorShape a_quant_shape{1, M, K / kU32Components}; Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType(), a_quant_shape); TensorShapeVector a_scales_dims({1, 1, M, K / kBlockSizeA}); Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims); quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec4Components)}}) .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), gsl::narrow(1)}, - {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow(1)}}); + {&a_scale, ProgramTensorMetadataDependency::Rank, gsl::narrow(2)}}); ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); constexpr uint32_t kTileSize = 64;