Skip to content
This repository was archived by the owner on Nov 16, 2023. It is now read-only.

Commit 0c998cc

Browse files
authored
add texture cache support for pack/unpack mode (#290)
* squash change * uncomment test verification * clean up * clean up and remove unused files * adding comment * remove unused code * fix reshape merge * add guid as tensor id * add file * fix test failure
1 parent 290825d commit 0c998cc

File tree

15 files changed

+192
-77
lines changed

15 files changed

+192
-77
lines changed

lib/backends/webgl/glsl-coordinate-lib.ts

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ export class CoordsGlslLib extends GlslLib {
690690
return ${texFuncSnippet}(${unpackedCoordsSnippet});
691691
}
692692
`;
693-
return new GlslLibRoutine(source);
693+
return new GlslLibRoutine(source, ['coordinates.getOutputCoords']);
694694
}
695695

696696
/**
@@ -1216,4 +1216,27 @@ export class CoordsGlslLib extends GlslLib {
12161216
}
12171217
`;
12181218
}
1219+
1220+
/**
1221+
* Produces a packed value getter function for the name and rank given
1222+
* If a transpose is set proper offsetToCoords mapping will be used
1223+
* @param name name of the function
1224+
* @param rank rank of the input
1225+
* @param transpose whether or not should generate a transpose variation
1226+
*/
1227+
protected getPackedValueFrom(varName: string, rank: number, width: number, height: number, transpose: boolean):
1228+
string {
1229+
let name = `_${varName}_Pack`;
1230+
if (transpose) {
1231+
name = name + '_T';
1232+
}
1233+
const glsl = getGlsl(this.context.glContext.version);
1234+
return `
1235+
vec4 ${name}(int m[${rank}]) {
1236+
int offset = indicesToOffset_${varName}(m);
1237+
vec2 coords = offsetToCoords(offset, ${width}, ${height});
1238+
return ${glsl.texture2D}(${varName}, coords);
1239+
}
1240+
`;
1241+
}
12191242
}

lib/backends/webgl/inference-handler.ts

Lines changed: 67 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,16 @@ import {Artifact, RunData, TextureData, TextureLayout, WebGLOperator} from './ty
1616
import {getPackedShape} from './utils';
1717

1818
export class WebGLInferenceHandler implements InferenceHandler {
19-
private textureDataCache: Map<Tensor.Id, TextureData>;
19+
private packedTextureDataCache: Map<Tensor.Id, TextureData>;
20+
private unpackedTextureDataCache: Map<Tensor.Id, TextureData>;
21+
private pack2unpackMap: Map<Tensor.Id, Tensor.Id>;
22+
private unpack2packMap: Map<Tensor.Id, Tensor.Id>;
2023
constructor(public session: WebGLSessionHandler) {
21-
this.textureDataCache = new Map();
24+
this.packedTextureDataCache = new Map();
25+
this.unpackedTextureDataCache = new Map();
26+
27+
this.pack2unpackMap = new Map();
28+
this.unpack2packMap = new Map();
2229
}
2330

2431
run(op: WebGLOperator, inputs: Tensor[]): Tensor[] {
@@ -33,36 +40,18 @@ export class WebGLInferenceHandler implements InferenceHandler {
3340
return [runData.outputTextureData.tensor];
3441
}
3542

36-
/**
37-
* Check the runData's input texture mode with the program's artifact.
38-
* If the artifact expects a packed input, while the RunData's input
39-
* is unpacked, perform a pack operation on this input to align the
40-
* texture mode with artifact. Similar on unpacked input.
41-
*/
4243
checkAndUpdateTextureForm(artifact: Artifact, runData: RunData) {
4344
// pack/unpack inputs
44-
runData.inputTextureDatas.forEach(input => {
45+
for (let i = 0; i < runData.inputTextureDatas.length; ++i) {
46+
const input = runData.inputTextureDatas[i];
4547
if (input.isPacked && !artifact.programInfo.expectPackedInputs) {
46-
// unpack this input
47-
const unpacked = this.unpack(input);
48-
input.height = unpacked.height;
49-
input.isPacked = unpacked.isPacked;
50-
input.texture = unpacked.texture;
51-
input.width = unpacked.width;
52-
48+
runData.inputTextureDatas[i] = this.unpack(input);
5349
} else if (!input.isPacked && artifact.programInfo.expectPackedInputs) {
54-
// pack this input
55-
const packed = this.pack(input);
56-
input.height = packed.height;
57-
input.isPacked = packed.isPacked;
58-
input.texture = packed.texture;
59-
input.width = packed.width;
50+
runData.inputTextureDatas[i] = this.pack(input);
6051
}
61-
});
52+
}
6253
}
6354
runProgram(artifact: Artifact, runData: RunData) {
64-
// if the runData has different expected texture pack/unpack mode, process pack/unpack
65-
// operation on the texture before executing the kernel.
6655
this.checkAndUpdateTextureForm(artifact, runData);
6756

6857
// output should match
@@ -84,15 +73,28 @@ export class WebGLInferenceHandler implements InferenceHandler {
8473
* Creates a texture data object associated with the given tensor.
8574
* @param tensor the tensor with data to upload
8675
*/
87-
getOrCreateTextureData(tensor: Tensor, layout?: TextureLayout) {
88-
let td = this.getTextureData(tensor.dataId);
76+
getOrCreateTextureData(tensor: Tensor, layout?: TextureLayout, isPacked = false) {
77+
let td = this.getTextureData(tensor.dataId, isPacked);
8978
if (!td) {
9079
Logger.verbose('InferenceHandler', `Creating new TextureData for dims: [${tensor.dims}]`);
9180
if (!layout) {
9281
layout = this.createTextureLayoutFromShape(tensor.dims.slice());
9382
}
94-
// graph inputs or initializers
95-
td = this.createTextureData(layout, tensor.type, tensor.numberData, tensor, Encoder.Usage.UploadOnly);
83+
// if we don't find the texture data with specific pack mode in the cache, try with the different
84+
// pack mode to see if the tensor is cached using that pack mode. If succeed, we can return this
85+
// tensor data and later apply a pack/unpack op on this texture, no need to create a new one here.
86+
td = this.getTextureData(tensor.dataId, !isPacked);
87+
if (!td) {
88+
if (isPacked) {
89+
const unpackedTextureLayout = this.getOrCreateTextureLayout(tensor, 1, false, [], true);
90+
const unpackedTextureData = this.createTextureData(
91+
unpackedTextureLayout, tensor.type, tensor.numberData, tensor, Encoder.Usage.UploadOnly);
92+
td = this.pack(unpackedTextureData);
93+
} else {
94+
td = this.createTextureData(
95+
layout, tensor.type, tensor.numberData, tensor, Encoder.Usage.UploadOnly, isPacked);
96+
}
97+
}
9698
} else {
9799
Logger.verbose('InferenceHandler', `Retrieving TextureData from cache: [${tensor.dims}]`);
98100
}
@@ -104,7 +106,7 @@ export class WebGLInferenceHandler implements InferenceHandler {
104106
* Usage = Encoder.Usage.Default.
105107
* @param dataType the tensor data type
106108
*/
107-
createTextureDataFromLayout(layout: TextureLayout, dataType: Tensor.DataType): TextureData {
109+
createTextureDataFromLayout(layout: TextureLayout, dataType: Tensor.DataType, isPacked = false): TextureData {
108110
return this.createTextureData(layout, dataType);
109111
}
110112

@@ -118,13 +120,14 @@ export class WebGLInferenceHandler implements InferenceHandler {
118120
* @param tensor the tensor to bind. tensor's data is ignored.
119121
*/
120122
createTextureDataFromLayoutBindTensor(
121-
layout: TextureLayout, dataType: Tensor.DataType, data: Tensor.NumberType, tensor: Tensor): TextureData {
122-
return this.createTextureData(layout, dataType, data, tensor, Encoder.Usage.UploadOnly);
123+
layout: TextureLayout, dataType: Tensor.DataType, data: Tensor.NumberType, tensor: Tensor,
124+
isPacked = false): TextureData {
125+
return this.createTextureData(layout, dataType, data, tensor, Encoder.Usage.UploadOnly, isPacked);
123126
}
124127

125128
private createTextureData(
126129
layout: TextureLayout, dataType: Tensor.DataType, data?: Tensor.NumberType, tensor?: Tensor,
127-
usage?: Encoder.Usage): TextureData {
130+
usage?: Encoder.Usage, isPacked = false): TextureData {
128131
Logger.verbose('InferenceHandler', `Creating TextureData: layout:[${JSON.stringify(layout)}]`);
129132
const texture = this.session.textureManager.createTextureFromLayout(dataType, layout, data, usage);
130133
return this.createTextureDataFromTexture(layout, dataType, texture, tensor);
@@ -137,8 +140,9 @@ export class WebGLInferenceHandler implements InferenceHandler {
137140
* @param texture the WebGLTexture object to share
138141
* @param tensorId the tensor ID of the shared tensor data
139142
*/
140-
createSharedTextureData(layout: TextureLayout, dataType: Tensor.DataType, texture: WebGLTexture, tensorId: Tensor.Id):
141-
TextureData {
143+
createSharedTextureData(
144+
layout: TextureLayout, dataType: Tensor.DataType, texture: WebGLTexture, tensorId?: Tensor.Id,
145+
isPacked = false): TextureData {
142146
return this.createTextureDataFromTexture(layout, dataType, texture, undefined, tensorId);
143147
}
144148

@@ -155,29 +159,32 @@ export class WebGLInferenceHandler implements InferenceHandler {
155159
undefined, undefined, tensorId),
156160
texture
157161
};
158-
this.setTextureData(textureData.tensor.dataId, textureData);
162+
this.setTextureData(textureData.tensor.dataId, textureData, layout.isPacked);
159163
return textureData;
160164
}
161165

162-
getTextureData(tensorId: Tensor.Id): TextureData|undefined {
163-
return this.session.isInitializer(tensorId) ? this.session.getTextureData(tensorId) :
164-
this.textureDataCache.get(tensorId);
166+
getTextureData(tensorId: Tensor.Id, isPacked = false): TextureData|undefined {
167+
return this.session.isInitializer(tensorId) ?
168+
this.session.getTextureData(tensorId, isPacked) :
169+
isPacked ? this.packedTextureDataCache.get(tensorId) : this.unpackedTextureDataCache.get(tensorId);
165170
}
166-
setTextureData(tensorId: Tensor.Id, td: TextureData): void {
171+
setTextureData(tensorId: Tensor.Id, td: TextureData, isPacked = false): void {
167172
if (this.session.isInitializer(tensorId)) {
168-
this.session.setTextureData(tensorId, td);
173+
this.session.setTextureData(tensorId, td, isPacked);
169174
} else {
170-
this.textureDataCache.set(tensorId, td);
175+
isPacked ? this.packedTextureDataCache.set(tensorId, td) : this.unpackedTextureDataCache.set(tensorId, td);
171176
}
172177
}
173-
178+
isTextureLayoutCached(tensor: Tensor, isPacked = false): boolean {
179+
return !!this.getTextureData(tensor.dataId, isPacked);
180+
}
174181
/**
175182
* Create a TextureLayout object from a tensor. If a related texture data is found, returns the cached texture layout.
176183
*/
177184
getOrCreateTextureLayout(
178185
tensor: Tensor, channels: 1|4 = 1, isPacked = false, unpackedShape?: ReadonlyArray<number>,
179186
reverseWH = false): TextureLayout {
180-
const td = this.getTextureData(tensor.dataId);
187+
const td = this.getTextureData(tensor.dataId, isPacked);
181188
if (td) {
182189
return td;
183190
}
@@ -229,14 +236,17 @@ export class WebGLInferenceHandler implements InferenceHandler {
229236
isPacked,
230237
shape: inferredDims,
231238
strides: ShapeUtil.computeStrides(inferredDims),
232-
unpackedShape
239+
unpackedShape,
240+
reversedWH: (prefs && prefs.reverseWH)
233241
};
234242
}
235243

236244
dispose(): void {
237245
this.session.textureManager.clearActiveTextures();
238-
this.textureDataCache.forEach(td => this.session.textureManager.releaseTexture(td));
239-
this.textureDataCache = new Map();
246+
this.packedTextureDataCache.forEach(td => this.session.textureManager.releaseTexture(td));
247+
this.packedTextureDataCache = new Map();
248+
this.unpackedTextureDataCache.forEach(td => this.session.textureManager.releaseTexture(td));
249+
this.unpackedTextureDataCache = new Map();
240250
}
241251

242252
readTexture(textureData: TextureData): Tensor.NumberType {
@@ -252,6 +262,10 @@ export class WebGLInferenceHandler implements InferenceHandler {
252262
}
253263

254264
pack(input: TextureData): TextureData {
265+
const cachedId = this.unpack2packMap.get(input.tensor.dataId);
266+
if (cachedId) {
267+
return this.packedTextureDataCache.get(cachedId)!;
268+
}
255269
const key = `${input.shape}`;
256270
let op = this.session.packOpCache.get(key);
257271
if (!op) {
@@ -266,10 +280,15 @@ export class WebGLInferenceHandler implements InferenceHandler {
266280
}
267281
const runData = op.createRunData(this, artifact.programInfo, [input.tensor]);
268282
this.runProgram(artifact, runData);
283+
this.unpack2packMap.set(input.tensor.dataId, runData.outputTextureData.tensor.dataId);
269284
return runData.outputTextureData;
270285
}
271286

272287
unpack(input: TextureData): TextureData {
288+
const cachedId = this.pack2unpackMap.get(input.tensor.dataId);
289+
if (cachedId) {
290+
return this.unpackedTextureDataCache.get(cachedId)!;
291+
}
273292
// For unpacked kernel, cache it by using input's unpackedShape as cache key.
274293
// Note that we need to use input.unpackedShape instead of input.shape here,
275294
// as the shape infers the packed texture shape. Different unpackedShape can have the
@@ -290,6 +309,7 @@ export class WebGLInferenceHandler implements InferenceHandler {
290309
}
291310
const runData = op.createRunData(this, artifact.programInfo, [input.tensor]);
292311
this.runProgram(artifact, runData);
312+
this.pack2unpackMap.set(input.tensor.dataId, runData.outputTextureData.tensor.dataId);
293313
return runData.outputTextureData;
294314
}
295315
}

lib/backends/webgl/ops/binary-op.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ export class WebGLBinaryOp extends BinaryOp implements WebGLOperator {
2222
const inputLayouts = handler.session.pack ?
2323
inputs.map(t => handler.getOrCreateTextureLayout(t, 4, true, t.dims, true)) :
2424
inputs.map(t => handler.getOrCreateTextureLayout(t));
25+
const ouputLayout = handler.session.pack ?
26+
handler.createTextureLayoutFromShape(inputs[0].dims, 4, inputs[0].dims, {isPacked: true, reverseWH: true}) :
27+
handler.createTextureLayoutFromShape(inputs[0].dims);
28+
2529
const isBroadcast = !ShapeUtil.areEqual(inputs[0].dims, inputs[1].dims);
2630
if (isBroadcast) {
2731
const outputShape = BroadcastUtil.calcShape(inputs[0].dims, inputs[1].dims, false);
@@ -33,6 +37,8 @@ export class WebGLBinaryOp extends BinaryOp implements WebGLOperator {
3337
const bRank = inputs[1].dims.length !== 0 ? inputs[1].dims.length : 1;
3438
const aBcast = inputs[0].dims.length !== 0 ? `bcastIndices_A(indices, aindices);` : `aindices[0] = 0;`;
3539
const bBcast = inputs[1].dims.length !== 0 ? `bcastIndices_B(indices, bindices);` : `bindices[0] = 0;`;
40+
41+
// TODO: for packed tensors, we need to implement logic to caculate textCoords for broadcast tensor
3642
const shaderSource = `
3743
${this.glslFunc.body}
3844
float process(int indices[${outputRank}]) {
@@ -51,6 +57,8 @@ export class WebGLBinaryOp extends BinaryOp implements WebGLOperator {
5157
outputLayout,
5258
samplers: ['A', 'B'],
5359
shaderSource,
60+
expectPackedInputs: handler.session.pack,
61+
expectPackedOutputs: handler.session.pack
5462
};
5563
}
5664
const glsl = getGlsl(handler.session.backend.glContext.version);
@@ -67,7 +75,7 @@ export class WebGLBinaryOp extends BinaryOp implements WebGLOperator {
6775
return {
6876
hasMain: true,
6977
inputLayouts,
70-
outputLayout: handler.createTextureLayoutFromShape(inputs[0].dims),
78+
outputLayout: ouputLayout,
7179
samplers: ['A', 'B'],
7280
shaderSource,
7381
expectPackedInputs: true,

lib/backends/webgl/ops/reshape-packed.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,17 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator {
117117
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
118118
const inputTDs =
119119
[handler.getOrCreateTextureData(inputs[0], handler.getOrCreateTextureLayout(inputs[0], 1, false, [], false))];
120+
let outputLayout = this.originalOutputLayout;
121+
if (outputLayout === undefined) {
122+
const originInputShape = inputs[0].dims;
123+
const outputShape = ShapeUtil.calculateReshapedDims(originInputShape, inputs[1].integerData);
124+
outputLayout =
125+
handler.createTextureLayoutFromShape(outputShape, 4, outputShape, {isPacked: true, reverseWH: true});
126+
}
120127
// return run data for reshape. Here, we use the original calculate outputLayout to create the real output layout.
121128
return {
122129
inputTextureDatas: inputTDs,
123-
outputTextureData: handler.createTextureDataFromLayout(this.originalOutputLayout, inputTDs[0].tensor.type),
130+
outputTextureData: handler.createTextureDataFromLayout(outputLayout, inputTDs[0].tensor.type),
124131
uniformData: {}
125132
};
126133
}

lib/backends/webgl/ops/reshape.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,6 @@ export function reshape(
4343
unpackedShape: reshapedDims,
4444
};
4545

46-
const newTextureData = inferenceHandler.createSharedTextureData(newTextureLayout, input.type, inputTD.texture, {});
46+
const newTextureData = inferenceHandler.createSharedTextureData(newTextureLayout, input.type, inputTD.texture);
4747
return newTextureData.tensor;
4848
}

lib/backends/webgl/ops/uint8-encode.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ export class WebGLUint8Encode {
7777
const encoder = inferenceHandler.session.backend.glContext.getEncoder('byte', 4);
7878
const texture =
7979
inferenceHandler.session.backend.glContext.allocateTexture(outputLayout.width, outputLayout.height, encoder);
80-
const outputTextureData = inferenceHandler.createSharedTextureData(outputLayout, 'uint8', texture, {});
80+
const outputTextureData = inferenceHandler.createSharedTextureData(outputLayout, 'uint8', texture);
8181
const runData = {inputTextureDatas: [input], outputTextureData, uniformData: {}};
8282

8383
inferenceHandler.session.programManager.run(artifact, runData);

lib/backends/webgl/ops/unpack.ts

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import {getGlsl} from '../glsl-source';
66
import {WebGLInferenceHandler} from '../inference-handler';
77
import {ProgramInfo, RunData, WebGLOperator} from '../types';
88
import {getCoordsDataType} from '../utils';
9-
109
import {getChannels, unpackFromChannel} from './packing_utils';
1110

1211
export class WebGLUnpack implements WebGLOperator {
@@ -18,7 +17,7 @@ export class WebGLUnpack implements WebGLOperator {
1817
throw new Error(`Pack kernel should have input tensor count to 1.`);
1918
}
2019

21-
const inputTexture = handler.getTextureData(inputs[0].dataId);
20+
const inputTexture = handler.getTextureData(inputs[0].dataId, true);
2221
if (!inputTexture) {
2322
throw new Error(`packed input texture must exist`);
2423
}
@@ -49,7 +48,7 @@ export class WebGLUnpack implements WebGLOperator {
4948
`;
5049

5150
return {
52-
inputLayouts: [handler.getOrCreateTextureLayout(inputs[0])],
51+
inputLayouts: [handler.getOrCreateTextureLayout(inputs[0], 4, true, inputs[0].dims, true)],
5352
outputLayout,
5453
samplers: ['A'],
5554
shaderSource,
@@ -59,7 +58,7 @@ export class WebGLUnpack implements WebGLOperator {
5958
};
6059
}
6160
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
62-
const inputTDs = [handler.getOrCreateTextureData(inputs[0], programInfo.inputLayouts[0])];
61+
const inputTDs = [handler.getOrCreateTextureData(inputs[0], programInfo.inputLayouts[0], true)];
6362
return {
6463
inputTextureDatas: inputTDs,
6564
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),

lib/backends/webgl/program-manager.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,12 @@ ${fragShaderScript}
105105
return program;
106106
}
107107
bindOutput(td: TextureData): void {
108+
const width = td.width;
109+
const height = td.height;
108110
Logger.verbose(
109111
'ProrgramManager',
110-
`Binding output texture to Framebuffer: w/h=${td.width}/${td.height}, shape=${td.shape}, type=${
111-
td.tensor.type}`);
112-
this.glContext.attachFramebuffer(td.texture, td.width, td.height);
112+
`Binding output texture to Framebuffer: w/h=${width}/${height}, shape=${td.shape}, type=${td.tensor.type}`);
113+
this.glContext.attachFramebuffer(td.texture, width, height);
113114
}
114115
bindAttributes(attribLocations: Artifact.AttribLocations): void {
115116
const positionHandle = attribLocations.position;

0 commit comments

Comments
 (0)