Skip to content

Commit

Permalink
[js/webgpu] optimize MatmulNBits (microsoft#21747)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
See 2x speedup for phi3 on the integrated intel gpu with this
optimization.

The optimization is mainly to store input A's data into local variable
instead of loading them from global memory each time when calculate them
with B data.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
qjia7 authored Aug 23, 2024
1 parent 4af6291 commit 87165b9
Showing 1 changed file with 124 additions and 187 deletions.
311 changes: 124 additions & 187 deletions js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { calculateTensorSizeInBytes, DataType } from '../../../wasm-common';
import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
Expand All @@ -14,7 +14,6 @@ import {
outputVariable,
ShaderHelper,
tensorTypeToWsglStorageType,
UniformsArrayType,
} from './common';

// TODO support quantization bits not equal to 4
Expand Down Expand Up @@ -60,41 +59,27 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt
export const createMatMulNBitsProgramInfo = (
inputs: readonly TensorView[],
attributes: MatMulNBitsAttributes,
maxComputeWorkgroupSizes: [number, number, number],
maxComputeWorkgroupStorageSize: number,
): ProgramInfo => {
const inputShape = inputs[0].dims;
const aRank = inputShape.length;
const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
const dimAOuter = inputShape[aRank - 2];
const dimInner = attributes.k;
const dimBOuter = attributes.n;
const batchDims = inputShape.slice(0, aRank - 2);
const batchSize = ShapeUtil.size(batchDims);
const blobSize = (attributes.blockSize / 8) * attributes.bits;
const blobSize = inputs[1].dims[2];
const blobSizeInWords = blobSize / 4;
const dataType = inputs[0].dataType;
const outputNumber = getMaxComponents(dimAOuter);
const aComponents = getMaxComponents(attributes.k);
const bComponents = getMaxComponents(blobSizeInWords);
const workgroupOutputSize = calculateTensorSizeInBytes(dataType, dimAOuter * nBlocksPerCol)!;
const maxNumberOfComponents = Math.floor(maxComputeWorkgroupStorageSize / workgroupOutputSize);
const useBlockwiseMatMulNBits = nBlocksPerCol <= maxComputeWorkgroupSizes[0] && maxNumberOfComponents > 0;
const components =
!useBlockwiseMatMulNBits || maxNumberOfComponents >= 4
? getMaxComponents(dimBOuter)
: maxNumberOfComponents >= 2 && getMaxComponents(dimBOuter) >= 2
? 2
: 1;
const components = getMaxComponents(dimBOuter);
const outputShape = batchDims.concat([dimAOuter, dimBOuter]);
const outputSize = ShapeUtil.size(outputShape) / components / outputNumber;
const outputNumber = dimAOuter > 1 && (dimBOuter / components) % 2 === 0 ? 2 : 1;
const dispatchSize = ShapeUtil.size(outputShape) / components / outputNumber;

const workgroupSize = 64;

const programUniforms: ProgramUniform[] = useBlockwiseMatMulNBits
? []
: [
{ type: DataType.uint32, data: outputSize },
{ type: DataType.uint32, data: attributes.blockSize },
];
const programUniforms: ProgramUniform[] = [];
const inputShapeTemp = [batchSize, dimAOuter, dimInner / aComponents];
const bShape = ShapeUtil.convertShape(inputs[1].dims).slice();
bShape.splice(-1, 1, blobSizeInWords / bComponents);
Expand All @@ -106,6 +91,7 @@ export const createMatMulNBitsProgramInfo = (
}
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
programUniforms.push(...createTensorShapeVariables(outputShapeTemp));

const getShaderSource = (shaderHelper: ShaderHelper) => {
const inputRank = inputShapeTemp.length;
const a = inputVariable('a', inputs[0].dataType, inputRank, aComponents);
Expand All @@ -119,10 +105,6 @@ export const createMatMulNBitsProgramInfo = (
}
const outputRank = outputShapeTemp.length;
const output = outputVariable('output', inputs[0].dataType, outputRank, components);
const uniforms: UniformsArrayType = [
{ name: 'output_size', type: 'u32' },
{ name: 'block_size', type: 'u32' },
];
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);

const qDqDataType = (() => {
Expand All @@ -138,187 +120,146 @@ export const createMatMulNBitsProgramInfo = (
}
})();

const processOneBlock = `
for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) {
${b.indicesSet('b_indices', '2', 'word')};
let b_data = ${b.getByIndices('b_indices')};
for (var i: u32 = 0; i < ${bComponents}; i++) {
let b_value: u32 = ${bComponents === 1 ? 'b_data' : 'b_data[word + i]'};
let b_mask: u32 = 0x0F0F0F0Fu;
let b_value_lower: vec4<u32> = unpack4xU8(b_value & b_mask);
let b_value_upper: vec4<u32> = unpack4xU8((b_value >> 4) & b_mask);
let b_quantized_values = ${qDqDataType}(${Array.from(
const processOneWord = (): string => {
let calcStr = `
// reuse a data
var input_offset = ${a.indicesToOffset(`${a.type.indices}(batch, row, word_offset)`)};
var a_data: ${qDqDataType};
for (var j: u32 = 0; j < ${8 / aComponents}; j++) {
a_data[j] = ${a.getByOffset('input_offset')};
input_offset++;
}
`;
for (let c = 0; c < components * outputNumber; c++) {
calcStr += `
b_value = ${bComponents === 1 ? `b${c}_data` : `b${c}_data[i]`};
b_value_lower = unpack4xU8(b_value & b_mask);
b_value_upper = unpack4xU8((b_value >> 4) & b_mask);
b_quantized_values = ${qDqDataType}(${Array.from(
{ length: 4 },
(_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
).join(', ')});
let b_dequantized_values = ${(() => {
b_dequantized_values = ${(() => {
if (aComponents === 1) {
return `${qDqDataType}(${Array.from(
{ length: 8 },
(_, i) => `(b_quantized_values[${i}] - zero_point) * scale`,
(_, i) => `(b_quantized_values[${i}] - ${zeroPoints ? `zero_point${c}` : 'zero_point'}) * scale${c}`,
).join(', ')});`;
} else {
return `(b_quantized_values - ${qDqDataType}(${Array(8).fill('zero_point').join(',')})) * scale;`;
return `(b_quantized_values - ${qDqDataType}(${Array(8)
.fill(`${zeroPoints ? `zero_point${c}` : 'zero_point'}`)
.join(',')})) * scale${c};`;
}
})()};
// Number of B elements per 32-bit word is 32/bits = 32/4 = 8
for (var m: u32 = 0; m < ${useBlockwiseMatMulNBits ? dimAOuter : outputNumber}u; m++) {
${a.indicesSet('a_indices', inputRank - 2, useBlockwiseMatMulNBits ? 'm' : `row * ${outputNumber} + m`)};
${a.indicesSet('a_indices', inputRank - 1, 'word_offset')};
var input_offset = ${a.indicesToOffset('a_indices')};
var a_data: ${qDqDataType};
for (var j: u32 = 0; j < ${8 / aComponents}; j++) {
a_data[j] = ${a.getByOffset('input_offset')};
input_offset++;
}
${useBlockwiseMatMulNBits ? 'workgroup_shared[workgroup_shared_offset + m]' : 'output_values[m]'}${
components > 1 ? '[c]' : ''
} += ${Array.from(
{ length: 8 / aComponents },
(_, i) =>
`${
aComponents === 1
? `a_data[${i}] * b_dequantized_values[${i}]`
: `dot(a_data[${i}], b_dequantized_values[${i}])`
}`,
).join(' + ')};
}
word_offset += ${8 / aComponents};
}
}`;
const updateZeroPointIndex = zeroPoints
? `
zero_point_offset += 4;
if (zero_point_offset == 32) {
zero_point_offset = 0;
zero_point_index++;
zero_point_word = ${zeroPoints.getByOffset('zero_point_index')};
}`
: '';

return useBlockwiseMatMulNBits
? `
var<workgroup> workgroup_shared: array<${output.type.value}, ${dimAOuter * nBlocksPerCol}>;
${shaderHelper.declareVariables(...inputVariables, output)}
${shaderHelper.mainStart([nBlocksPerCol, 1, 1])}
var a_indices: ${a.type.indices};
var block = local_id.x;
var col = workgroup_id.y;
var batch = workgroup_id.z;
${a.indicesSet('a_indices', '0', 'batch')};
// Two zero points are packed into one byte when uniforms.bits is 4.
for (var c: u32 = 0; c < ${components}; c++) {
let col_times_components_plus_c = col * ${components} + c;
${
zeroPoints
? `
var zero_point_bytes_per_col: u32 = (${nBlocksPerCol} + 1) / 2;
var zero_point_byte_count: u32 = col_times_components_plus_c * zero_point_bytes_per_col + (block >> 0x1u);
var zero_point_word_index: u32 = zero_point_byte_count >> 0x2u;
var zero_point_byte_offset: u32 = zero_point_byte_count & 0x3u;
var zero_point_nibble_offset: u32 = block & 0x1u;
var zero_point_bits_offset: u32 = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;`
: ''
}
var b_indices: ${b.type.indices};
${b.indicesSet('b_indices', '0', 'col_times_components_plus_c')};
// The scale and zero points are computed per block.
var scales_index = col_times_components_plus_c * ${nBlocksPerCol} + block;
let scale = ${scales.getByOffset('scales_index')};
workgroup_shared[local_id.x * ${outputNumber} + ${Math.floor(c / components)}]${components > 1 ? `[${c % components}]` : ''} += ${Array.from(
{ length: 8 / aComponents },
(_, i) =>
`${
aComponents === 1
? `a_data[${i}] * b_dequantized_values[${i}]`
: `dot(a_data[${i}], b_dequantized_values[${i}])`
}`,
).join(' + ')};
`;
}
return calcStr;
};
const prepareScaleAndZeroPoint = (): string => {
let calcStr = `
var col_index = col * ${components};
${
zeroPoints
? `
let zero_point_bytes_per_col = (nBlocksPerCol + 1) / 2;
var zero_point_byte_count: u32;
var zero_point_word_index: u32;
var zero_point_byte_offset: u32;
let zero_point_nibble_offset: u32 = block & 0x1u;
var zero_point_bits_offset: u32;
var zero_point_word: u32;`
: `
// The default zero point is 8 for unsigned 4-bit quantization.
let zero_point = ${dataType}(${zeroPoints ? '(zero_point_word) & 0xFu' : 8.0});
${b.indicesSet('b_indices', '1', 'block')};
var word_offset: u32 = block * ${attributes.blockSize / aComponents};
var workgroup_shared_offset: u32 = block * ${dimAOuter};
${processOneBlock}
}
workgroupBarrier();
var output_indices: ${output.type.indices};
var elements_per_thread: u32 = ${Math.ceil(dimAOuter / nBlocksPerCol)};
${output.indicesSet('output_indices', '0', 'batch')};
${output.indicesSet('output_indices', outputRank - 1, 'col')};
${output.indicesSet('output_indices', outputRank - 2, 'local_id.x * elements_per_thread')};
var output_offset = ${output.indicesToOffset('output_indices')};
for (var m: u32 = 0u; m < elements_per_thread; m++) {
var row = m + local_id.x * elements_per_thread;
if (row < ${dimAOuter}) {
var output_value: ${output.type.value} = ${output.type.value}(0);
var workgroup_shared_offset: u32 = row;
for (var b: u32 = 0u; b < ${nBlocksPerCol}u; b++) {
output_value += workgroup_shared[workgroup_shared_offset];
workgroup_shared_offset += ${dimAOuter};
}
${output.setByOffset('output_offset', 'output_value')};
output_offset += ${dimBOuter / components};
let zero_point = ${dataType}(${8.0});`
}
}
}`
: `
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
var output_values: array<${output.type.value}, ${outputNumber}>;
var output_indices = ${output.offsetToIndices('global_idx')};
var col = ${output.indicesGet('output_indices', outputRank - 1)};
var row = ${output.indicesGet('output_indices', outputRank - 2)};
var a_indices: ${a.type.indices} = output_indices;
// Two zero points are packed into one byte because uniforms.bits <= 4.
// zero_point_offset is either 0 or 4. It is bit offset within one byte.
// TODO support zero_point_offset for bits > 4
${
zeroPoints
? `
var zero_point_abs_offset = col * ${components} * ((${nBlocksPerCol} + 1) / 2);
var zero_point_index: u32 = zero_point_abs_offset / 4;
var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')};
var zero_point_offset: u32 = (zero_point_abs_offset % 4) * 8;`
: ''
}
var scale_index = col * ${nBlocksPerCol * components};
var b_indices: ${b.type.indices};
for (var c: u32 = 0; c < ${components}; c++) {
${b.indicesSet('b_indices', '0', `col * ${components} + c`)};
var block_offset: u32 = 0;
for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) {
// The scale and zero points are computed per block.
let scale = ${scales.getByOffset('scale_index')};
// The default zero point is 8 for unsigned 4-bit quantization.
let zero_point = ${dataType}(${zeroPoints ? 'extractBits(zero_point_word, zero_point_offset, 4)' : 8.0});
${b.indicesSet('b_indices', '1', 'block')};
var word_offset: u32 = block_offset;
${processOneBlock}
scale_index++;
${updateZeroPointIndex}
block_offset += uniforms.block_size / ${aComponents};
}
// Drop the trailing 4 bits if the zero_poit_offset is not a byte boundary to align with the next byte.
`;
for (let c = 0; c < components * outputNumber; c++) {
calcStr += `
let scale${c} = ${scales.getByOffset(`col_index * nBlocksPerCol + block`)};
${
zeroPoints
? `if (zero_point_offset % 8 > 0) {
${updateZeroPointIndex}
}`
? `
zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u);
zero_point_word_index = zero_point_byte_count >> 0x2u;
zero_point_byte_offset = zero_point_byte_count & 0x3u;
zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;
let zero_point${c} = ${dataType}((zero_point_word) & 0xFu);`
: ''
}
col_index += 1;`;
}
return calcStr;
};
const prepareBData = (): string => {
let calcStr = `col_index = col * ${components};`;
for (let c = 0; c < components * outputNumber; c++) {
calcStr += `
let b${c}_data = ${b.getByIndices(`${b.type.indices}(col_index, block, word)`)};
col_index += 1;`;
}
calcStr += `
var b_value: u32;
let b_mask: u32 = 0x0F0F0F0Fu;
var b_value_lower: vec4<u32>;
var b_value_upper: vec4<u32>;
var b_quantized_values: ${qDqDataType};
var b_dequantized_values: ${qDqDataType};`;
return calcStr;
};
return `
var<workgroup> workgroup_shared: array<${output.type.value}, ${outputNumber * workgroupSize}>;
${shaderHelper.declareVariables(...inputVariables, output)}
${shaderHelper.mainStart([workgroupSize, 1, 1])}
let output_indices = ${output.offsetToIndices(`(global_idx / ${workgroupSize}) * ${outputNumber}`)};
let col = output_indices[2];
let row = output_indices[1];
let batch = output_indices[0];
let nBlocksPerCol = uniforms.b_shape[1];
for (var block = local_id.x; block < nBlocksPerCol; block += ${workgroupSize}) {
//process one block
var word_offset: u32 = block * ${attributes.blockSize / aComponents};
${prepareScaleAndZeroPoint()}
for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) {
${prepareBData()}
for (var i: u32 = 0; i < ${bComponents}; i++) {
${processOneWord()}
word_offset += ${8 / aComponents};
}
}
for (var k: u32 = 0u; k < ${outputNumber}u; k++) {
${output.indicesSet('output_indices', outputRank - 2, `${outputNumber} * row + k`)};
${output.setByIndices('output_indices', 'output_values[k]')}
}
workgroupBarrier();
if (local_id.x < ${outputNumber}) {
var output_value: ${output.type.value} = ${output.type.value}(0);
var workgroup_shared_offset: u32 = local_id.x;
for (var b: u32 = 0u; b < ${workgroupSize}u; b++) {
output_value += workgroup_shared[workgroup_shared_offset];
workgroup_shared_offset += ${outputNumber};
}
${output.setByIndices(`${output.type.indices}(batch, row, col + local_id.x)`, 'output_value')};
}
}`;
};
return {
name: useBlockwiseMatMulNBits ? 'BlockwiseMatMulNBits' : 'MatMulNBits',
name: 'MatMulNBits',
shaderCache: {
hint: `${attributes.cacheKey};${dimAOuter};${dataType};${inputs.length}`,
hint: `${attributes.blockSize};${attributes.bits};${aComponents};${bComponents};${components};${outputNumber};${workgroupSize}`,
inputDependencies: Array(inputs.length).fill('rank'),
},
getRunData: () => ({
outputs: [{ dims: outputShape, dataType }],
name: useBlockwiseMatMulNBits ? 'BlockwiseMatMulNBits' : 'MatMulNBits',
dispatchGroup: useBlockwiseMatMulNBits
? { x: 1, y: Math.ceil(dimBOuter / components), z: batchSize }
: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
dispatchGroup: { x: dispatchSize },
programUniforms,
}),
getShaderSource,
Expand All @@ -327,11 +268,7 @@ export const createMatMulNBitsProgramInfo = (

export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => {
validateInputs(context.inputs, attributes);
const maxComputeWorkgroupSizes: [number, number, number] = context.getMaxComputeWorkgroupSizes();
const maxComputeWorkgroupStorageSize = context.getMaxComputeWorkgroupStoragesize();
context.compute(
createMatMulNBitsProgramInfo(context.inputs, attributes, maxComputeWorkgroupSizes, maxComputeWorkgroupStorageSize),
);
context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes));
};

export const parseMatMulNBitsAttributes = (attributes: Record<string, unknown>): MatMulNBitsAttributes =>
Expand Down

0 comments on commit 87165b9

Please sign in to comment.