diff --git a/packages/paddlejs-backend-webgl/src/ops/index.ts b/packages/paddlejs-backend-webgl/src/ops/index.ts index 0fee2cdb..66195dd4 100644 --- a/packages/paddlejs-backend-webgl/src/ops/index.ts +++ b/packages/paddlejs-backend-webgl/src/ops/index.ts @@ -64,6 +64,10 @@ import { imgFeed, pack_out, nhwc_2_nchw, unpacked_2_packed, packed_2_unpacked, feedPost } from './shader/custom'; +import connect_mul from './shader/connect_mul'; +import instancenorm from './shader/instancenorm'; +import instancenorm_variance from './shader/instancenorm_variance'; +import instancenorm_mean from './shader/instancenorm_mean'; const ops = { @@ -140,7 +144,12 @@ const ops = { density_prior_box, prior_box, stack, - slice + slice, + 'conv2d-elementwise_add-leaky_relu': conv2d_elementwise_add, + connect_mul, + instance_norm: instancenorm, + instance_norm_mean: instancenorm_mean, + instance_norm_variance: instancenorm_variance }; export { ops diff --git a/packages/paddlejs-backend-webgl/src/ops/shader/batchnorm.ts b/packages/paddlejs-backend-webgl/src/ops/shader/batchnorm.ts index 996e1f53..ee20fcba 100644 --- a/packages/paddlejs-backend-webgl/src/ops/shader/batchnorm.ts +++ b/packages/paddlejs-backend-webgl/src/ops/shader/batchnorm.ts @@ -14,11 +14,11 @@ function mainFunc( float o = getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a); // 归一化数据 - vec4 scale = getPixelsFromTexturePos_scale(vec2( float(oPos.g) / float(${scale.width_texture}) + 0.00001, 0.0)); - vec4 bias = getPixelsFromTexturePos_bias(vec2( float(oPos.g) / float(${bias.width_texture}) + 0.00001, 0.0)); - vec4 mean = getPixelsFromTexturePos_mean(vec2((float(oPos.g)) / float(${mean.width_texture}) + 0.00001, 0.0)); + vec4 scale = getPixelsFromTexturePos_scale(vec2(float(oPos.g) / float(${scale.width_texture}) + 0.00001, 0.0)); + vec4 bias = getPixelsFromTexturePos_bias(vec2(float(oPos.g) / float(${bias.width_texture}) + 0.00001, 0.0)); + vec4 mean = getPixelsFromTexturePos_mean(vec2(float(oPos.g) / float(${mean.width_texture}) + 0.00001, 0.0)); vec4 variance = getPixelsFromTexturePos_variance( - vec2((float(oPos.g)) / float(${variance.width_texture}) + 0.00001, + vec2(float(oPos.g) / float(${variance.width_texture}) + 0.00001, 0.0) ); diff --git a/packages/paddlejs-backend-webgl/src/ops/shader/connect_mul.ts b/packages/paddlejs-backend-webgl/src/ops/shader/connect_mul.ts new file mode 100644 index 00000000..abbd2fe9 --- /dev/null +++ b/packages/paddlejs-backend-webgl/src/ops/shader/connect_mul.ts @@ -0,0 +1,53 @@ + +/** + * @file concat + */ + +import { reduceShape } from '../../utils/dataProcess'; + +function mainFunc( + { origin, counter, out }, + {} +) { + const { total_shape, width_shape, height_shape, channel } = out; + const reducedShape = reduceShape([ + total_shape / (width_shape * height_shape * channel), + channel, + height_shape, + width_shape + ]); + return ` + // start函数 + void main(void) { + ivec4 oPos = getOutputTensorPos(); + float o = 0.0; + ivec4 co; + int sumVal = oPos.b * ${reducedShape[2]} + oPos.a; + if (sumVal < ${origin.total_shape}) { + // from origin + co = getTensorPosFromArrayIndex_origin(sumVal); + o = getValueFromTensorPos_origin(co.r, co.g, co.b, co.a); + + } + else if (sumVal > ${origin.total_shape} && sumVal < ${origin.total_shape + counter.total_shape}) { + co = getTensorPosFromArrayIndex_counter(sumVal - ${origin.total_shape}); + o = getValueFromTensorPos_counter(co.r, co.g, co.b, co.a); + } + else { + // from appender + co = getTensorPosFromArrayIndex_appender(sumVal - ${origin.total_shape + counter.total_shape}); + o = getValueFromTensorPos_appender(co.r, co.g, co.b, co.a); + } + setOutput(float(o)); + } + `; +} +export default { + mainFunc, + params: [], + textureFuncConf: { + origin: ['getValueFromTensorPos', 'getTensorPosFromArrayIndex'], + counter: ['getValueFromTensorPos', 'getTensorPosFromArrayIndex'], + appender: ['getValueFromTensorPos', 'getTensorPosFromArrayIndex'] + } +}; diff --git a/packages/paddlejs-backend-webgl/src/ops/shader/greater_than.ts b/packages/paddlejs-backend-webgl/src/ops/shader/greater_than.ts index d6a63896..2c2ca25f 100644 --- a/packages/paddlejs-backend-webgl/src/ops/shader/greater_than.ts +++ b/packages/paddlejs-backend-webgl/src/ops/shader/greater_than.ts @@ -3,19 +3,16 @@ * @file greater_than return x >= y */ -function mainFunc( - {}, - {} -) { +function mainFunc() { return ` // start函数 void main(void) { ivec4 oPos = getOutputTensorPos(); // 输出坐标转换为输入坐标 - float x = getValueFromTensorPos_input(oPos.r, oPos.g, oPos.b, oPos.a); + float x = getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a); float y = getValueFromTensorPos_counter(oPos.r, oPos.g, oPos.b, oPos.a); - setOutput(bool(x >= y)); + setOutput(float(bool(x >= y))); } `; } diff --git a/packages/paddlejs-backend-webgl/src/ops/shader/instancenorm.ts b/packages/paddlejs-backend-webgl/src/ops/shader/instancenorm.ts new file mode 100644 index 00000000..61efab96 --- /dev/null +++ b/packages/paddlejs-backend-webgl/src/ops/shader/instancenorm.ts @@ -0,0 +1,42 @@ +/** + * @file batchnorm + */ + +function mainFunc( + { bias, scale }, + { } +) { + return ` + + // start函数 + void main(void) { + // 输出数据 + ivec4 oPos = getOutputTensorPos(); + float o = getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a); + + // 归一化数据 + vec4 scale = getPixelsFromTexturePos_scale(vec2(float(oPos.g) / float(${scale.width_texture}) + 0.00001, 0.0)); + vec4 bias = getPixelsFromTexturePos_bias(vec2(float(oPos.g) / float(${bias.width_texture}) + 0.00001, 0.0)); + float mean = getValueFromTensorPos_mean(0, 0, oPos.r, oPos.g); + float variance = getValueFromTensorPos_variance(0, 0, oPos.r, oPos.g); + + float res = (o - mean) * variance; + // setOutput(res); + + setOutput(res); + } + `; +} +export default { + mainFunc, + params: [ + 'epsilon' + ], + textureFuncConf: { + origin: ['getValueFromTensorPos'], + scale: ['getPixelsFromTexturePos'], + bias: ['getPixelsFromTexturePos'], + mean: ['getValueFromTensorPos'], + variance: ['getValueFromTensorPos'] + } +}; diff --git a/packages/paddlejs-backend-webgl/src/ops/shader/instancenorm_mean.ts b/packages/paddlejs-backend-webgl/src/ops/shader/instancenorm_mean.ts new file mode 100644 index 00000000..3a1e7bfb --- /dev/null +++ b/packages/paddlejs-backend-webgl/src/ops/shader/instancenorm_mean.ts @@ -0,0 +1,39 @@ +/** + * @file batchnorm + */ + +function mainFunc( + { origin }, + {} +) { + const { height_shape, width_shape } = origin; + + return ` + + // start函数 + void main(void) { + // 输出数据 + ivec4 oPos = getOutputTensorPos(); + float o = 0.0; + for (int i = 0; i < ${height_shape}; i++) { + float inner = 0.0; + for (int j = 0; j < ${width_shape}; j++) { + inner += getValueFromTensorPos_origin(oPos.b, oPos.a, i, j); + } + + o += (inner / float(${width_shape})); + } + + o = o / float(${height_shape}); + setOutput(o); + } + `; +} +export default { + mainFunc, + params: [ + ], + textureFuncConf: { + origin: ['getValueFromTensorPos'] + } +}; diff --git a/packages/paddlejs-backend-webgl/src/ops/shader/instancenorm_variance.ts b/packages/paddlejs-backend-webgl/src/ops/shader/instancenorm_variance.ts new file mode 100644 index 00000000..b6a300e0 --- /dev/null +++ b/packages/paddlejs-backend-webgl/src/ops/shader/instancenorm_variance.ts @@ -0,0 +1,46 @@ +/** + * @file batchnorm + */ + +function mainFunc( + { origin }, + { epsilon } +) { + const { height_shape, width_shape } = origin; + + return ` + + // start函数 + void main(void) { + // 输出数据 + ivec4 oPos = getOutputTensorPos(); + + float variance = 0.0; + float sum = 0.0; + for (int i = 0; i < ${height_shape}; i++) { + float inner = 0.0; + for (int j = 0; j < ${width_shape}; j++) { + float o = getValueFromTensorPos_origin(oPos.b, oPos.a, i, j); + float m = getValueFromTensorPos_mean(oPos.r, oPos.g, oPos.b, oPos.a); + float diff = o - m; + inner += diff * diff; + } + + sum += inner / float(${width_shape}); + } + variance = 1.0 / sqrt(sum / float(${height_shape}) + ${epsilon}); + + setOutput(variance); + } + `; +} +export default { + mainFunc, + params: [ + 'epsilon' + ], + textureFuncConf: { + origin: ['getValueFromTensorPos'], + mean: ['getValueFromTensorPos'] + } +}; diff --git a/packages/paddlejs-backend-webgl/src/ops/shader/reduce_mean.ts b/packages/paddlejs-backend-webgl/src/ops/shader/reduce_mean.ts index 7430c777..415e3ee0 100644 --- a/packages/paddlejs-backend-webgl/src/ops/shader/reduce_mean.ts +++ b/packages/paddlejs-backend-webgl/src/ops/shader/reduce_mean.ts @@ -3,31 +3,81 @@ * @file reduce_mean */ +function getGcd(a, b) { + const max = Math.max(a, b); + const min = Math.min(a, b); + if (max % min === 0) { + return min; + } + return getGcd(max % min, min); +} + +function getLcm(a, b) { + return a * b / getGcd(a, b); +} function mainFunc( - {}, - { inputs_dim, dim } + { origin }, + { dim } ) { + const { total_shape, height_shape, width_shape, channel } = origin; + const batch_shape = total_shape / (width_shape * height_shape * channel); + const shape = [batch_shape, channel, height_shape, width_shape]; + let dimArr = []; + if (dim instanceof Array) { + dimArr = dim; + } + else { + dimArr.push(dim); + } + const dimShape = dimArr.map(item => shape[item]); + const totalDimShape = dimShape.reduce((prev, cur) => prev * cur); + const arrGcd = dimShape.reduce((prev, cur) => getLcm(prev, cur)); + const remainV = totalDimShape / arrGcd; + + let codeStr = 'float sum = 0.0;'; + const strArr = [` + sum += getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a) / float(${arrGcd}); + `]; + for (let i = 0; i < dimArr.length; i++) { + const curDim = dimArr[i]; + const curDimShape = shape[dimArr[i]]; + const vname = `i${i}`; + strArr.unshift(` + for (int ${vname} = 0; ${vname} < ${curDimShape}; ${vname}++) { + oPos[${curDim}] = ${vname}; + `); + strArr.push( + ` + } + ` + ); + } + + codeStr += strArr.join('\n'); + codeStr += ` + o = sum / float(${remainV}); + `; + return ` // start函数 void main(void) { ivec4 oPos = getOutputTensorPos(); // 输出坐标转换为输入坐标 float o = 0.0; - for (int i = 0; i < ${inputs_dim}; i++) { - oPos[${dim}] = i; - o += getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a); - } - o = o / float(${inputs_dim}); + ${codeStr} setOutput(o); } `; } export default { mainFunc, + params: [ + 'dim' + ], textureFuncConf: { origin: ['getValueFromTensorPos'] }, behaviors: [ - 'normalizeDim' + ] }; diff --git a/packages/paddlejs-backend-webgl/src/ops/shader/reduce_sum.ts b/packages/paddlejs-backend-webgl/src/ops/shader/reduce_sum.ts index 74fb87e2..4d9fbae2 100644 --- a/packages/paddlejs-backend-webgl/src/ops/shader/reduce_sum.ts +++ b/packages/paddlejs-backend-webgl/src/ops/shader/reduce_sum.ts @@ -1,32 +1,64 @@ /** - * @file concat + * @file reduce_sum */ function mainFunc( - {}, - { inputs_dim, dim } + { origin }, + { dim } ) { + const { total_shape, height_shape, width_shape, channel } = origin; + const batch_shape = total_shape / (width_shape * height_shape * channel); + const shape = [batch_shape, channel, height_shape, width_shape]; + let dimArr = []; + if (dim instanceof Array) { + dimArr = dim; + } + else { + dimArr.push(dim); + } + + let codeStr = ''; + const strArr = [` + o += getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a); + `]; + for (let i = 0; i < dimArr.length; i++) { + const curDim = dimArr[i]; + const curDimShape = shape[dimArr[i]]; + const vname = `i${i}`; + strArr.unshift(` + for (int ${vname} = 0; ${vname} < ${curDimShape}; ${vname}++) { + oPos[${curDim}] = ${vname}; + `); + strArr.push( + ` + } + ` + ); + } + + codeStr += strArr.join('\n'); + return ` // start函数 void main(void) { ivec4 oPos = getOutputTensorPos(); // 输出坐标转换为输入坐标 float o = 0.0; - for (int i = 0; i < ${inputs_dim}; i++) { - oPos[${dim}] = i; - o += getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a);; - } - setOutput(float(o)); + ${codeStr} + setOutput(o); } `; } export default { mainFunc, + params: [ + 'dim' + ], textureFuncConf: { origin: ['getValueFromTensorPos'] }, behaviors: [ - 'normalizeDim' + ] }; diff --git a/packages/paddlejs-backend-webgl/src/ops/shader/where.ts b/packages/paddlejs-backend-webgl/src/ops/shader/where.ts index 7f7e65a4..59b6a84c 100644 --- a/packages/paddlejs-backend-webgl/src/ops/shader/where.ts +++ b/packages/paddlejs-backend-webgl/src/ops/shader/where.ts @@ -3,16 +3,13 @@ * @file where return condition ? x : y */ -function mainFunc( - {}, - {} -) { +function mainFunc() { return ` // start函数 void main(void) { ivec4 oPos = getOutputTensorPos(); // 输出坐标转换为输入坐标 - float x = getValueFromTensorPos_input(oPos.r, oPos.g, oPos.b, oPos.a); + float x = getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a); float y = getValueFromTensorPos_counter(oPos.r, oPos.g, oPos.b, oPos.a); float condition = getValueFromTensorPos_condition(oPos.r, oPos.g, oPos.b, oPos.a); float o = 0.0; diff --git a/packages/paddlejs-backend-webgl/src/ops/utils.ts b/packages/paddlejs-backend-webgl/src/ops/utils.ts index 39974166..14981a3b 100644 --- a/packages/paddlejs-backend-webgl/src/ops/utils.ts +++ b/packages/paddlejs-backend-webgl/src/ops/utils.ts @@ -20,7 +20,11 @@ const tensorParams = [ 'offset_y', 'channel', 'total_shape', - 'numbers_shape' + 'numbers_shape', + 'scale', + 'bias', + 'mean', + 'variance' ]; diff --git a/packages/paddlejs-converter/convertModel.py b/packages/paddlejs-converter/convertModel.py index 098079a1..77e06e59 100644 --- a/packages/paddlejs-converter/convertModel.py +++ b/packages/paddlejs-converter/convertModel.py @@ -70,6 +70,17 @@ def validateShape(shape, name): if len(shape) > 4: newShape = shape[-4:] print('\033[31m ' + name + ' tensor shape length > 4, 处理为丢弃头部shape \033[0m') + for index, value in enumerate(modelInfo["ops"]): + if 'X' in value['inputs'].keys(): + # squeeze2 axes纠正 + if value['type'] == 'squeeze2' and name == value['inputs']['X'][0] and 'axes' in value['attrs'].keys(): + modelInfo["ops"][index]['attrs']['axes'][0] = modelInfo["ops"][index]['attrs']['axes'][0] - 1 + # transpose2 axis纠正 + if value['type'] == 'transpose2' and name == value['inputs']['X'][0] and 'axis'in value['attrs'].keys(): + for i, v in enumerate(value['attrs']['axis']): + if i > 0: + modelInfo["ops"][index]['attrs']['axis'][i] = v - 1 + modelInfo["ops"][index]['attrs']['axis'] = modelInfo["ops"][index]['attrs']['axis'][-4:] return newShape return shape diff --git a/packages/paddlejs-core/src/graph.ts b/packages/paddlejs-core/src/graph.ts index c7ea553f..f357e376 100644 --- a/packages/paddlejs-core/src/graph.ts +++ b/packages/paddlejs-core/src/graph.ts @@ -101,27 +101,34 @@ export default class ModelGraph { } private arrangeMap() { - const executed: object = {}; const inIndex: number[] = []; const idtoindex: object = {}; + const indexMap: object = {}; for (let index = 0; index < this.weightMap.length; index++) { const item = this.weightMap[index]; - for (let index = 0; index < item.outputsName.length; index++) { - const outputName = item.outputsName[index]; - executed[outputName] = true; + for (let i = 0; i < item.outputsName.length; i++) { + const outputName = item.outputsName[i]; + if (!indexMap[outputName]) { + indexMap[outputName] = 1; + } + else { + indexMap[outputName]++; + } + idtoindex[item.id] = index; } + } + + for (let index = 0; index < this.weightMap.length; index++) { + const item = this.weightMap[index]; + inIndex[index] = 0; - idtoindex[item.id] = index; - if (item.inputsName.length > 1) { - item.inputsName.forEach(i => { - if (executed[i] === true) { - inIndex[index]++; - } - }); - } - else { - inIndex[index] = item.inputsName.length; + + for (let i = 0; i < item.inputsName.length; i++) { + const inputName = item.inputsName[i]; + if (indexMap[inputName]) { + inIndex[index]++; + } } } @@ -129,15 +136,37 @@ export default class ModelGraph { } private topoSort(ops: OpExecutor[], inIndex: number[], idtoindex: object) { - const inline: OpExecutor[] = []; - inline.push(ops[0]); - const ops_temp = ops.slice(0); - let prev: OpExecutor = null; - let iterator: OpExecutor = ops[0]; + const zeroIndexList = []; + const bpOp = ops.slice(0); + for (let i = 0; i < inIndex.length; i++) { + if (inIndex[i] === 0) { + zeroIndexList.push(ops[i]); + } + } + + let preOp = null; + for (const curOp of zeroIndexList) { + preOp = this.topoSortInner(ops, inIndex, idtoindex, curOp, bpOp, preOp); + } + } + + private topoSortInner( + ops: OpExecutor[], + inIndex: number[], + idtoindex: object, + curOp: OpExecutor, + ops_temp: OpExecutor[], + preOp?: OpExecutor + ) { + const inline = [curOp]; + let prev: OpExecutor = preOp; + let iterator: OpExecutor = curOp; + while (inline.length > 0) { - if (prev != null) { + if (prev) { ops[idtoindex[prev.id]].next = iterator.id; } + prev = iterator; iterator = inline.pop() || {} as OpExecutor; @@ -157,6 +186,9 @@ export default class ModelGraph { } } } + + ops[idtoindex[prev.id]].next = iterator.id; + return iterator; } /** diff --git a/packages/paddlejs-core/src/opFactory/opDataBuilder.ts b/packages/paddlejs-core/src/opFactory/opDataBuilder.ts index 36bbbd75..df1034e2 100644 --- a/packages/paddlejs-core/src/opFactory/opDataBuilder.ts +++ b/packages/paddlejs-core/src/opFactory/opDataBuilder.ts @@ -105,7 +105,9 @@ export default class OpData { const tensorName = this.getExactTensorName(key, 'input'); if (tensorName) { const tensor = data[0]; - tensor.tensorName = tensorName; + if (tensor) { + tensor.tensorName = tensorName; + } this.tensorDataMap[tensorName] = { ...tensor, tensorName @@ -187,6 +189,16 @@ export default class OpData { this.name = 'pool2d_max'; } + else if (this.name.indexOf('sync_batch_norm') > -1) { + this.name = 'batchnorm'; + } + else if (this.name.indexOf('bilinear_interp_v2') > -1) { + this.name = 'bilinear_interp'; + } + else if (this.name.indexOf('leaky_relu') > -1) { + this.name = 'conv2d_elementwise_add'; + } + // unique behavior const opKey = `${GLOBALS.backend}_${this.name}`; const behaviorKeys = GLOBALS.opRegistry.ops[opKey] diff --git a/packages/paddlejs-core/src/transform/formatInstanceNorm.ts b/packages/paddlejs-core/src/transform/formatInstanceNorm.ts new file mode 100644 index 00000000..5b4d74ed --- /dev/null +++ b/packages/paddlejs-core/src/transform/formatInstanceNorm.ts @@ -0,0 +1,54 @@ + +import Transformer from './transformer'; +import { ModelOp } from '../commons/interface'; + +export default class FormateInstanceNorm extends Transformer { + constructor() { + super('formateInstanceNorm'); + } + + transform(...args: any) { + const [ops] = args; + + // the array length > 4 of inputs.X + for (let index = 0, len = ops.length; index < len; index++) { + const opInfo = ops[index]; + if (opInfo.type === 'instance_norm') { + // 取到input mean vavariance key + const { X, Variance, Mean } = opInfo.inputs; + + // 构造instance_norm_mean 和instance_norm_variance 算子 + const instanceNormMeanOp: ModelOp = { + attrs: { + }, + inputs: { + X: X + }, + outputs: { + Out: [Mean[0]] + }, + type: 'instance_norm_mean' + }; + + const instanceNormVarianceOp: ModelOp = { + attrs: { + epsilon: opInfo.attrs.epsilon || 0.000009999999747378752 + }, + inputs: { + X: X, + Mean: [Mean[0]] + }, + outputs: { + Out: [Variance[0]] + }, + type: 'instance_norm_variance' + }; + + ops.splice(index, 0, instanceNormVarianceOp); + ops.splice(index, 0, instanceNormMeanOp); + index += 2; + len += 2; + } + } + } +} \ No newline at end of file diff --git a/packages/paddlejs-core/src/transform/index.ts b/packages/paddlejs-core/src/transform/index.ts index 45b47e34..0a8ad929 100644 --- a/packages/paddlejs-core/src/transform/index.ts +++ b/packages/paddlejs-core/src/transform/index.ts @@ -4,6 +4,8 @@ import type Transformer from './transformer'; import SplitOp from './splitOp'; import PackOutOp from './packOutOp'; import FeedProcess from './feedProcess'; +import OptModel from './optModel'; +import FormateInstanceNorm from './formatInstanceNorm'; interface TransformerAction { preTransforms: Transformer[]; @@ -14,8 +16,10 @@ interface TransformerAction { const actions: TransformerAction = { preTransforms: [ new SplitOp(), + new FormateInstanceNorm(), new PackOutOp(), - new FeedProcess() + new FeedProcess(), + new OptModel() ], transforms: [ new FormatInputsX(), diff --git a/packages/paddlejs-core/src/transform/optModel.ts b/packages/paddlejs-core/src/transform/optModel.ts new file mode 100644 index 00000000..7256ccdc --- /dev/null +++ b/packages/paddlejs-core/src/transform/optModel.ts @@ -0,0 +1,50 @@ +import { findVarByKey } from '../commons/utils'; +import Transformer from './transformer'; + + +export default class OptModel extends Transformer { + constructor() { + super('OptModel'); + } + + transform(...args: any) { + const [ops, vars] = args; + for (let opIndex = 0; opIndex < ops.length; opIndex++) { + const op = ops[opIndex]; + + // todo:这样取可能有问题 + let curInputName = ''; + if ((op.type === 'size' && ops[opIndex + 1].type === 'cast') || op.type === 'shape') { + curInputName = op.inputs.Input + ? op.inputs.Input[0] + : op.inputs.X[0]; + } + + // 当size、cast算子连续使用的时候,把size和cast算子剔除,直接计算cast的输出 + if (op.type === 'size' && ops[opIndex + 1].type === 'cast') { + const currentOpName = ops[opIndex + 1].outputs.Out[0]; + const inputVar = findVarByKey(vars, curInputName); + const total = inputVar.shape.reduce((pre, cur) => pre * cur); + const outputVar = findVarByKey(vars, currentOpName); + outputVar.data = [total]; + outputVar.persistable = true; + ops.splice(opIndex, 1); + ops.splice(opIndex, 1); + opIndex = opIndex - 1; + } + else if (op.type === 'shape') { + const curOutputName = op.outputs.Out[0]; + const inputVar = findVarByKey(vars, curInputName); + const outputVar = findVarByKey(vars, curOutputName); + const inputShape = inputVar.shape; + outputVar.data = inputShape; + outputVar.persistable = true; + + ops.splice(opIndex, 1); + opIndex = opIndex - 1; + } + + } + } + +}