diff --git a/docs/operators.md b/docs/operators.md index 47b30cee..5d848ae3 100644 --- a/docs/operators.md +++ b/docs/operators.md @@ -43,7 +43,7 @@ _This file is automatically generated from the def files via [this script](/tool | [Div](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Div) | [7+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Div-7) | [7+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Div-7) | [7+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Div-7) | | [Dropout](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Dropout) | [7-9](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-7), [10-11](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-10), [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-12) | | [7-9](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-7), [10-11](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-10), [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Dropout-12) | | [DynamicQuantizeLinear](https://github.com/onnx/onnx/blob/master/docs/Operators.md#DynamicQuantizeLinear) | | | | -| [Einsum](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Einsum) | | | | +| [Einsum](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Einsum) | [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Einsum-12) | [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Einsum-12) | [12+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Einsum-12) | | [Elu](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Elu) | [6+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Elu-6) | | [6+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Elu-6) | | [Equal](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Equal) | | | [7-10](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Equal-7), [11+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Equal-11) | | [Erf](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Erf) | | | | diff --git a/lib/backends/cpu/op-resolve-rules.ts b/lib/backends/cpu/op-resolve-rules.ts index c77a8b7e..2a978b01 100644 --- a/lib/backends/cpu/op-resolve-rules.ts +++ b/lib/backends/cpu/op-resolve-rules.ts @@ -11,6 +11,7 @@ import {CpuCast} from './ops/cast'; import {CpuConcat} from './ops/concat'; import {CpuConv} from './ops/conv'; import {CpuDropout} from './ops/dropout'; +import {CpuEinsum} from './ops/einsum'; import {CpuExpand} from './ops/expand'; import {CpuFlatten} from './ops/flatten'; import {CpuGather} from './ops/gather'; @@ -112,4 +113,5 @@ export const CPU_OP_RESOLVE_RULES: ReadonlyArray = [ ['Upsample', '', '7-8', () => new CpuUpsample()], ['Upsample', '', '9', () => new CpuUpsampleV9()], ['Xor', '', '7+', () => new CpuBinaryOp(['bool'], (e1, e2) => (e1 ^ e2))], + ['Einsum', '', '12+', () => new CpuEinsum()], ]; diff --git a/lib/backends/cpu/ops/einsum.ts b/lib/backends/cpu/ops/einsum.ts new file mode 100644 index 00000000..6e3378d4 --- /dev/null +++ b/lib/backends/cpu/ops/einsum.ts @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +import {Einsum} from '../../../ops/einsum'; +import {Tensor} from '../../../tensor'; +import {CpuInferenceHandler} from '../inference-handler'; + +import {ShapeUtil} from './../../../util'; + +export class CpuEinsum extends Einsum { + run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] { + const {outputShape, sizes, outputIndices, inputIndices} = this.prepareRun(inputs); + + const result = einsum(outputShape, inputs, sizes, outputIndices, inputIndices); + + return [result]; + } +} + +export function einsum( + outputShape: number[], inputs: Tensor[], sizes: number[], outputIndices: number[], + inputIndices: number[][]): Tensor { + const result = new Tensor(outputShape, inputs[0].type); + const totalSize = ShapeUtil.size(sizes); + let i = 0; + const index = new Array(sizes.length).fill(0); + + while (i < totalSize) { + const outputIx: number[] = []; + for (const outputIndex of outputIndices) { + outputIx.push(index[outputIndex]); + } + + let value = 1; + for (let i = 0; i < inputIndices.length; i++) { + const inputIx: number[] = []; + for (const inputIndex of inputIndices[i]) { + inputIx.push(index[inputIndex]); + } + value *= inputs[i].get(inputIx) as number; + } + + result.set(outputIx, result.get(outputIx) as number + value); + + i++; + ShapeUtil.incrementIndex(index, sizes); + } + + return result; +} diff --git a/lib/backends/wasm/op-resolve-rules.ts b/lib/backends/wasm/op-resolve-rules.ts index 92fb6141..8da862b7 100644 --- a/lib/backends/wasm/op-resolve-rules.ts +++ b/lib/backends/wasm/op-resolve-rules.ts @@ -7,6 +7,7 @@ import {WasmBatchNormalization} from './ops/batch-normalization'; import {WasmBinaryOp} from './ops/binary-op'; import {WasmClip} from './ops/clip'; import {WasmConv} from './ops/conv'; +import {WasmEinsum} from './ops/einsum'; import {WasmGemm} from './ops/gemm'; import {WasmInstanceNormalization} from './ops/instance-normalization'; import {WasmMatMul} from './ops/matmul'; @@ -36,4 +37,5 @@ export const WASM_OP_RESOLVE_RULES: ReadonlyArray = [ ['Sub', '', '7+', () => new WasmBinaryOp(['float32'], 'Sub')], ['Sum', '', '6+', () => new WasmSum()], // TODO: support multidirectional broadcast for Sum-8 ['Xor', '', '7+', () => new WasmBinaryOp(['bool'], 'Xor')], + ['Einsum', '', '12+', () => new WasmEinsum()], ]; diff --git a/lib/backends/wasm/ops/einsum.ts b/lib/backends/wasm/ops/einsum.ts new file mode 100644 index 00000000..49c6c194 --- /dev/null +++ b/lib/backends/wasm/ops/einsum.ts @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +import {Attribute} from '../../../attribute'; +import {Einsum} from '../../../ops/einsum'; +import {Tensor} from '../../../tensor'; +import {WasmBinding} from '../../../wasm-binding'; +import {WasmInferenceHandler} from '../inference-handler'; + +export class WasmEinsum extends Einsum { + initialize(attributes: Attribute): void { + super.initialize(attributes); + if (this.inputs.length > 2) { + throw new Error('Wasm implementation of Einsum currently supports at most 2 inputs'); + } + } + + run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] { + const {outputShape, sizes, outputIndices, inputIndices} = this.prepareRun(inputs); + + const y = new Tensor(outputShape, inputs[0].type); + + if (inputs.length === 2) { + WasmBinding.getInstance().ccall( + '_einsum_f32', + [inputs[0].floatData, 'float32ptr'], + [inputs[1].floatData, 'float32ptr'], + [y.floatData, 'float32ptr', 'inout'], + [sizes, 'int32ptr'], + [sizes.length, 'int32'], + [outputIndices, 'int32ptr'], + [outputIndices.length, 'int32'], + [inputIndices[0], 'int32ptr'], + [inputIndices[0].length, 'int32'], + [inputIndices[1], 'int32ptr'], + [inputIndices[2].length, 'int32'], + ); + } else { + WasmBinding.getInstance().ccall( + '_einsum_single_f32', + [inputs[0].floatData, 'float32ptr'], + [y.floatData, 'float32ptr', 'inout'], + [sizes, 'int32ptr'], + [sizes.length, 'int32'], + [outputIndices, 'int32ptr'], + [outputIndices.length, 'int32'], + [inputIndices[0], 'int32ptr'], + [inputIndices[1].length, 'int32'], + ); + } + + return [y]; + } + + checkInputTypes(inputs: Tensor[]): boolean { + // currently Wasm backend only supports 'float32' input type + if (inputs[0].type !== 'float32' || (inputs.length > 1 && inputs[1].type !== 'float32')) { + return false; + } + + return super.checkInputTypes(inputs); + } +} diff --git a/lib/backends/webgl/op-resolve-rules.ts b/lib/backends/webgl/op-resolve-rules.ts index 7b9d56fe..ec208f71 100644 --- a/lib/backends/webgl/op-resolve-rules.ts +++ b/lib/backends/webgl/op-resolve-rules.ts @@ -10,6 +10,7 @@ import {WebGLClip} from './ops/clip'; import {WebGLConcat} from './ops/concat'; import {WebGLConv} from './ops/conv'; import {WebGLDropout} from './ops/dropout'; +import {WebGLEinsum} from './ops/einsum'; import {WebGLElu} from './ops/elu'; import {WebGLFlatten} from './ops/flatten'; import {WebGLGather} from './ops/gather'; @@ -105,4 +106,5 @@ export const WEBGL_OP_RESOLVE_RULES: ReadonlyArray = [ ['Upsample', '', '7-8', () => new WebGLUpsample()], ['Unsqueeze', '', '1+', () => new WebGLUnsqueeze()], ['Xor', '', '7+', () => new binaryOps.WebGLBinaryOp(['bool'], binaryOps.glslXor())], + ['Einsum', '', '1+', () => new WebGLEinsum()], ]; diff --git a/lib/backends/webgl/ops/einsum.ts b/lib/backends/webgl/ops/einsum.ts new file mode 100644 index 00000000..233bd451 --- /dev/null +++ b/lib/backends/webgl/ops/einsum.ts @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +import {Einsum} from '../../../ops/einsum'; +import {Tensor} from '../../../tensor'; +import {WebGLInferenceHandler} from '../inference-handler'; +import {ProgramInfo, RunData, WebGLOperator} from '../types'; + +import {ShapeUtil} from './../../../util'; + +const samplerNames = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'.split(''); + +export class WebGLEinsum extends Einsum implements WebGLOperator { + run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { + return inferenceHandler.run(this, inputs); + } + + createProgramInfo(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo { + const {outputShape, sizes, outputIndices, inputIndices} = this.prepareRun(inputs); + + const sumDims = []; + const sumDimSizes = []; + for (let i = 0; i < sizes.length; i++) { + if (outputIndices.indexOf(i) === -1) { + sumDims.push(i); + sumDimSizes.push(sizes[i]); + } + } + const sumSize = ShapeUtil.size(sumDimSizes); + + let rank = outputShape.length; + // Webgl doesnt like 0 length arrays + if (rank === 0) { + rank = 1; + } + + const initIndex1 = outputIndices.map((x, i) => `index[${x}] = indices[${i}];`).join('\n'); + const initIndex2 = sumDims.map(x => `index[${x}] = 0;`).join('\n'); + + const findInputValues = inputs.map((_, i) => this.buildFindInputValueScript(i, inputIndices[i])).join('\n'); + + const incrementIndex = this.buildIncrementIndexScript(sumDims, sumDimSizes); + + const shaderSource = ` + float process(int indices[${rank}]) { + float value = 0.0; + + int index[${sizes.length}]; + ${initIndex1} + ${initIndex2} + + int i = 0; + while(i < ${sumSize}) { + float add = 1.0; + + ${findInputValues} + + value += add; + + ${incrementIndex} + i++; + } + + return value; + }`; + const inputLayouts = inputs.map(t => inferenceHandler.getOrCreateTextureLayout(t)); + return { + inputLayouts, + outputLayout: inferenceHandler.createTextureLayoutFromShape(outputShape), + samplers: samplerNames.slice(0, inputs.length), + shaderSource, + }; + } + + buildFindInputValueScript(i: number, indices: number[]): string { + const initInputIndex = indices.map((ix, indiceNum) => `input${i}Index[${indiceNum}] = index[${ix}];`).join('\n'); + + const script = `int input${i}Index[${indices.length}]; + ${initInputIndex} + add *= _${samplerNames[i]}(input${i}Index);`; + + return script; + } + + buildIncrementIndexScript(sumDims: number[], sumDimSizes: number[]): string { + let script = ''; + for (let i = 0; i < sumDims.length; i++) { + script += ` + index[${sumDims[i]}] += 1; + if (index[${sumDims[i]}] >= ${sumDimSizes[i]}) { + index[${sumDims[i]}] = 0; + `; + } + for (let i = 0; i < sumDims.length; i++) { + script += '}\n'; + } + + return script; + } + + createRunData(inferenceHandler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData { + const inputTDs = inputs.map((v, i) => inferenceHandler.getOrCreateTextureData(v, programInfo.inputLayouts[i])); + return { + inputTextureDatas: inputTDs, + outputTextureData: + inferenceHandler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type), + uniformData: {} + }; + } +} diff --git a/lib/ops/einsum.ts b/lib/ops/einsum.ts new file mode 100644 index 00000000..f9d352b9 --- /dev/null +++ b/lib/ops/einsum.ts @@ -0,0 +1,153 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +import {Attribute} from '../attribute'; +import {InferenceHandler} from '../backend'; +import {Operator} from '../operators'; +import {Tensor} from '../tensor'; + +export abstract class Einsum implements Operator { + abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise; + + prepareRun(inputs: Tensor[]) { + const dimensionSizeMap: {[name: string]: number} = {}; + this.matchInputs(inputs, dimensionSizeMap); + const outputShape = this.calculateOutputSize(dimensionSizeMap); + + let i = 0; + const sizes = []; + const nameToId: {[name: string]: number} = {}; + const idToName: {[id: number]: string} = {}; + + for (const name in dimensionSizeMap) { + sizes.push(dimensionSizeMap[name]); + nameToId[name] = i; + idToName[i] = name; + i++; + } + + const outputIndices: number[] = []; + const inputIndices: number[][] = []; + for (const outputName of this.outputNames) { + outputIndices.push(nameToId[outputName]); + } + for (let i = 0; i < this.inputs.length; i++) { + const indices = []; + for (const inputName of this.inputNames[i]) { + indices.push(nameToId[inputName]); + } + inputIndices.push(indices); + } + + return {outputShape, sizes, outputIndices, inputIndices}; + } + + initialize(attributes: Attribute): void { + this.equation = attributes.getString('equation'); + const split = this.equation.split('->'); + this.lhs = split[0].trim(); + if (split.length === 2) { + this.rhs = split[1].trim(); + this.implicit = false; + } else { + this.implicit = true; + } + + const lhsSplit = this.lhs.split(','); + this.inputs = lhsSplit.map(v => v.trim()); + + for (let i = 0; i < this.inputs.length; i++) { + this.inputNames.push([]); + this.parseEquationPart(this.inputs[i], this.inputNames[i]); + } + + if (this.rhs) { + this.parseEquationPart(this.rhs, this.outputNames); + } + } + + private parseEquationPart(part: string, indices: string[]) { + for (let i = 0; i < part.length; i++) { + const char = part.charAt(i); + + if (char === '.') { + throw new Error('Use of ellipsis (...) in einsum not yet supported'); + } + + indices.push(char); + } + } + + protected matchInputs(inputs: Tensor[], dimensionSizeMap: {[name: string]: number}) { + for (let i = 0; i < inputs.length; i++) { + this.matchDimensions(this.inputNames[i], inputs[i].dims, dimensionSizeMap); + } + } + + protected calculateOutputSize(dimensionSizeMap: {[name: string]: number}): number[] { + const result: number[] = []; + for (let i = 0; i < this.outputNames.length; i++) { + result.push(dimensionSizeMap[this.outputNames[i]]); + } + return result; + } + + checkInputs(inputs: Tensor[]): boolean { + const dimensionMap: {[id: string]: number} = {}; + + if (inputs.length !== this.inputs.length) { + return false; + } + + for (let i = 0; i < inputs.length; i++) { + if (!this.matchDimensions(this.inputNames[i], inputs[i].dims, dimensionMap)) { + return false; + } + } + + return this.checkInputTypes(inputs); + } + + protected matchDimensions(indices: string[], inputDims: readonly number[], dimensionMap: {[id: string]: number}): + boolean { + for (let j = 0; j < indices.length; j++) { + const ix = indices[j]; + if (dimensionMap[ix] && dimensionMap[ix] !== inputDims[j]) { + return false; + } else if (!dimensionMap[ix]) { + dimensionMap[ix] = inputDims[j]; + } + } + + return true; + } + + protected checkInputTypes(inputs: Tensor[]): boolean { + const allowedTypes = ['float32', 'float64', 'int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32']; + + if (inputs.find((v) => allowedTypes.indexOf(v.type) === -1) !== undefined) { + return false; + } + + const types = inputs.map(v => v.type); + if (types.find(v => v !== types[0]) !== undefined) { + return false; + } + + return true; + } + + protected equation: string; + protected lhs: string; + protected rhs?: string; + + protected inputs: string[] = []; + + // The i-th string[] Maps from input axis i to general axis id + protected inputNames: string[][] = []; + + // Maps from output axis to general axis id + protected outputNames: string[] = []; + + protected implicit: boolean; +} diff --git a/src/wasm-build-config.json b/src/wasm-build-config.json index 95a99ee8..5f6a5f92 100644 --- a/src/wasm-build-config.json +++ b/src/wasm-build-config.json @@ -22,6 +22,8 @@ "_clip_f32", "_instance_normalization_f32", "_sum_f32", - "_softmax_f32" + "_softmax_f32", + "_einsum_f32", + "_einsum_single_f32" ] } diff --git a/src/wasm-ops/einsum.cpp b/src/wasm-ops/einsum.cpp new file mode 100644 index 00000000..bdc0fb68 --- /dev/null +++ b/src/wasm-ops/einsum.cpp @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "einsum.h" +#include "common.h" +#include "utils/shape_utils.h" + +// Wasm interop method +void einsum_f32(void *data) { + uint32_t *dataIndex = static_cast(data); + uint32_t const argc = dataIndex[0]; + + const float *a = PARAM_FLOAT_PTR(data, dataIndex[1]); + const float *b = PARAM_FLOAT_PTR(data, dataIndex[2]); + float *y = PARAM_FLOAT_PTR(data, dataIndex[3]); + const int32_t *dims = PARAM_INT32_PTR(data, dataIndex[4]); + const int32_t rank = PARAM_INT32(data, dataIndex[5]); + const int32_t *outputIndices = PARAM_INT32_PTR(data, dataIndex[6]); + const int32_t outputRank = PARAM_INT32(data, dataIndex[7]); + const int32_t *input1Indices = PARAM_INT32_PTR(data, dataIndex[8]); + const int32_t input1Rank = PARAM_INT32(data, dataIndex[9]); + const int32_t *input2Indices = PARAM_INT32_PTR(data, dataIndex[10]); + const int32_t input2Rank = PARAM_INT32(data, dataIndex[11]); + + einsum_f32_imp(a, b, y, dims, rank, outputIndices, outputRank, input1Indices, + input1Rank, input2Indices, input2Rank); +} + +void einsum_f32_imp(const float *A, const float *B, float *Y, + const int32_t *dims, const int32_t rank, + const int32_t *outputIndices, int32_t outputRank, + const int32_t *input1Indices, int32_t input1Rank, + const int32_t *input2Indices, int32_t input2Rank) { + std::vector dimsVector(dims, dims + rank); + // std::vector strides = ShapeUtils::compute_strides(dimsVector); + size_t totalSize = ShapeUtils::size_from_dims(dimsVector); + size_t i = 0; + std::vector index(rank, 0); + + std::vector outputStrides(outputRank, 1); + for (size_t j = outputRank - 2; j >= 0; j--) { + outputStrides[j] = outputStrides[j + 1] * dimsVector[outputIndices[j]]; + } + + std::vector input1Strides(input1Rank, 1); + for (size_t j = input1Rank - 2; j >= 0; j--) { + input1Strides[j] = input1Strides[j + 1] * dimsVector[input1Indices[j]]; + } + + std::vector input2Strides(input2Rank, 1); + for (size_t j = input2Rank - 2; j >= 0; j--) { + input2Strides[j] = input2Strides[j + 1] * dimsVector[input2Indices[j]]; + } + + while (i < totalSize) { + size_t outputOffset = 0; + for (size_t j = 0; j < outputRank; j++) { + outputOffset += index[outputIndices[j]] * outputStrides[j]; + } + + size_t input1Offset = 0; + for (size_t j = 0; j < input1Rank; j++) { + input1Offset += index[input1Indices[j]] * input1Strides[j]; + } + + size_t input2Offset = 0; + for (size_t j = 0; j < input2Rank; j++) { + input2Offset += index[input2Indices[j]] * input2Strides[j]; + } + + Y[outputOffset] += A[input1Offset] * B[input2Offset]; + + i++; + ShapeUtils::increment_index(index, dimsVector, dimsVector.size()); + } +} + +void einsum_single_f32(void *data) { + uint32_t *dataIndex = static_cast(data); + uint32_t const argc = dataIndex[0]; + + const float *a = PARAM_FLOAT_PTR(data, dataIndex[1]); + float *y = PARAM_FLOAT_PTR(data, dataIndex[2]); + const int32_t *dims = PARAM_INT32_PTR(data, dataIndex[3]); + const int32_t rank = PARAM_INT32(data, dataIndex[4]); + const int32_t *outputIndices = PARAM_INT32_PTR(data, dataIndex[5]); + const int32_t outputRank = PARAM_INT32(data, dataIndex[6]); + const int32_t *inputIndices = PARAM_INT32_PTR(data, dataIndex[7]); + const int32_t inputRank = PARAM_INT32(data, dataIndex[8]); + + einsum_single_f32_imp(a, y, dims, rank, outputIndices, outputRank, + inputIndices, inputRank); +} + +// Core operator implementation +void einsum_single_f32_imp(const float *A, float *Y, const int32_t *dims, + const int32_t rank, const int32_t *outputIndices, + int32_t outputRank, const int32_t *inputIndices, + int32_t inputRank) { + std::vector dimsVector(dims, dims + rank); + // std::vector strides = ShapeUtils::compute_strides(dimsVector); + size_t totalSize = ShapeUtils::size_from_dims(dimsVector); + size_t i = 0; + std::vector index(rank, 0); + + std::vector outputStrides(outputRank, 1); + for (size_t j = outputRank - 2; j >= 0; j--) { + outputStrides[j] = outputStrides[j + 1] * dimsVector[outputIndices[j]]; + } + + std::vector inputStrides(inputRank, 1); + for (size_t j = inputRank - 2; j >= 0; j--) { + inputStrides[j] = inputStrides[j + 1] * dimsVector[inputIndices[j]]; + } + + while (i < totalSize) { + size_t outputOffset = 0; + for (size_t j = 0; j < outputRank; j++) { + outputOffset += index[outputIndices[j]] * outputStrides[j]; + } + + size_t input1Offset = 0; + for (size_t j = 0; j < inputRank; j++) { + input1Offset += index[inputIndices[j]] * inputStrides[j]; + } + + Y[outputOffset] += A[input1Offset]; + + i++; + ShapeUtils::increment_index(index, dimsVector, dimsVector.size()); + } +} diff --git a/src/wasm-ops/einsum.h b/src/wasm-ops/einsum.h new file mode 100644 index 00000000..017845cc --- /dev/null +++ b/src/wasm-ops/einsum.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include + +extern "C" { +void einsum_f32(void *); +void einsum_f32_imp(const float *A, const float *B, float *Y, + const int32_t *dims, const int32_t rank, + const int32_t *outputIndices, int32_t outputRank, + const int32_t *input1Indices, int32_t input1Rank, + const int32_t *input2Indices, int32_t input2Rank); +void einsum_single_f32(void *); +void einsum_single_f32_imp(const float *A, float *Y, const int32_t *dims, + const int32_t rank, const int32_t *outputIndices, + int32_t outputRank, const int32_t *inputIndices, + int32_t inputRank); +} diff --git a/src/wasm-ops/utils/shape_utils.cpp b/src/wasm-ops/utils/shape_utils.cpp index bd068b83..7d2e9e1d 100644 --- a/src/wasm-ops/utils/shape_utils.cpp +++ b/src/wasm-ops/utils/shape_utils.cpp @@ -86,3 +86,15 @@ void ShapeUtils::offset_to_indices(const std::vector &strides, } indices[indices.size() - 1] = offset; } + +void ShapeUtils::increment_index(std::vector &index, + const std::vector &dims, + size_t axisToIncrementOn) { + for (size_t i = axisToIncrementOn - 1; i >= 0; --i) { + index[i]++; + if (index[i] < dims[i]) { + break; + } + index[i] = 0; + } +} diff --git a/src/wasm-ops/utils/shape_utils.h b/src/wasm-ops/utils/shape_utils.h index 2d138136..9de9dd4b 100644 --- a/src/wasm-ops/utils/shape_utils.h +++ b/src/wasm-ops/utils/shape_utils.h @@ -18,4 +18,7 @@ std::vector offset_to_indices(const std::vector &strides, // Fills in values in the indices vector. Assumes it is of the required size. void offset_to_indices(const std::vector &strides, size_t offset, std::vector &indices); +void increment_index(std::vector &index, + const std::vector &dims, + size_t axisToIncrementOn); }; // namespace ShapeUtils diff --git a/test/data/ops/einsum.jsonc b/test/data/ops/einsum.jsonc new file mode 100644 index 00000000..d51d9c01 --- /dev/null +++ b/test/data/ops/einsum.jsonc @@ -0,0 +1,235 @@ +[ + { + "name": "Einsum batch matmul", + "operator": "Einsum", + "opsets": [ + { + "domain": "", + "version": 12 + } + ], + "attributes": [ + { "name": "equation", "data": "bij, bjk -> bik", "type": "string" } + ], + "cases": [ + { + "name": "BMM", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 2, 2], + "type": "float32" + }, + { + "data": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24 + ], + "dims": [3, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 11, + 14, + 17, + 20, + 23, + 30, + 37, + 44, + 123, + 134, + 145, + 156, + 167, + 182, + 197, + 212, + 363, + 382, + 401, + 420, + 439, + 462, + 485, + 508 + ], + "dims": [3, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Einsum transpose", + "operator": "Einsum", + "opsets": [ + { + "domain": "", + "version": 12 + } + ], + "attributes": [ + { "name": "equation", "data": "ij -> ji", "type": "string" } + ], + "cases": [ + { + "name": "Transpose", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12], + "dims": [4, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Einsum inner product", + "operator": "Einsum", + "opsets": [ + { + "domain": "", + "version": 12 + } + ], + "attributes": [{ "name": "equation", "data": "i,i", "type": "string" }], + "cases": [ + { + "name": "Inner product", + "inputs": [ + { + "data": [1, 2, 3, 4, 5], + "dims": [5], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5], + "dims": [5], + "type": "float32" + } + ], + "outputs": [ + { + "data": [55], + "dims": [], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Einsum batch diagonal", + "operator": "Einsum", + "opsets": [ + { + "domain": "", + "version": 12 + } + ], + "attributes": [{ "name": "equation", "data": "bii->bi", "type": "string" }], + "cases": [ + { + "name": "Diagonal", + "inputs": [ + { + "data": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18 + ], + "dims": [2, 3, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 5, 9, 10, 14, 18], + "dims": [2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Einsum sum", + "operator": "Einsum", + "opsets": [ + { + "domain": "", + "version": 12 + } + ], + "attributes": [{ "name": "equation", "data": "ij->i", "type": "string" }], + "cases": [ + { + "name": "Sum", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [10, 26, 42], + "dims": [3], + "type": "float32" + } + ] + } + ] + } +] diff --git a/test/test-suite-whitelist.jsonc b/test/test-suite-whitelist.jsonc index 9a1d2d63..93df3718 100644 --- a/test/test-suite-whitelist.jsonc +++ b/test/test-suite-whitelist.jsonc @@ -225,7 +225,11 @@ "test_xor_bcast4v4d", "test_xor2d", "test_xor3d", - "test_xor4d" + "test_xor4d", + "v12/test_einsum_batch_matmul", + "v12/test_einsum_inner_prod", + "v12/test_einsum_sum", + "v12/test_einsum_transpose" ], "ops": [ "abs.jsonc", @@ -265,7 +269,8 @@ "softmax.jsonc", "tan.jsonc", "transpose.jsonc", - "xor.jsonc" + "xor.jsonc", + "einsum.jsonc" ] }, "webgl": { @@ -485,7 +490,11 @@ "test_xor_bcast4v4d", "test_xor2d", "test_xor3d", - "test_xor4d" + "test_xor4d", + "v12/test_einsum_batch_matmul", + "v12/test_einsum_inner_prod", + "v12/test_einsum_sum", + "v12/test_einsum_transpose" ], "ops": [ "abs.jsonc", @@ -527,7 +536,8 @@ "sub.jsonc", "tan.jsonc", "transpose.jsonc", - "xor.jsonc" + "xor.jsonc", + "einsum.jsonc" ] }, "wasm": { @@ -624,7 +634,11 @@ "test_globalmaxpool_precomputed", "test_globalmaxpool", "test_instancenorm_epsilon", - "test_instancenorm_example" + "test_instancenorm_example", + "v12/test_einsum_batch_matmul", + "v12/test_einsum_inner_prod", + "v12/test_einsum_sum", + "v12/test_einsum_transpose" ], "ops": [ // Check in op tests that have native Wasm implementations @@ -639,7 +653,8 @@ "and.jsonc", "or.jsonc", "xor.jsonc", - "matmul.jsonc" + "matmul.jsonc", + "einsum.jsonc" ] } }