Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Commit

Permalink
operator: add Flatten ops for webgl and cpu backends (#93)
Browse files Browse the repository at this point in the history
* add Flatten ops for webgl and cpu backends

* use flattenShape function for both cpu and webgl backends

* strictly following onnx specs and checking for scalar tensor
  • Loading branch information
NTT123 authored and fs-eire committed Feb 28, 2019
1 parent 187779e commit 7ea55fa
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 0 deletions.
3 changes: 3 additions & 0 deletions lib/backends/cpu/ops-resolve.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {CpuBinaryOp} from './ops/binary-op';
import {CpuConcat} from './ops/concat';
import {CpuConv} from './ops/conv';
import {CpuDropout} from './ops/dropout';
import {CpuFlatten} from './ops/flatten';
import {CpuGather} from './ops/gather';
import {CpuGemm} from './ops/gemm';
import {CpuImageScaler} from './ops/image-scaler';
Expand Down Expand Up @@ -98,6 +99,8 @@ function createOperator(node: Graph.Node, domain: string, version: number): Oper
return new CpuConv();
case 'Dropout':
return new CpuDropout();
case 'Flatten':
return new CpuFlatten();
case 'Gemm':
return new CpuGemm();
case 'ImageScaler':
Expand Down
26 changes: 26 additions & 0 deletions lib/backends/cpu/ops/flatten.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {Flatten} from '../../../ops/flatten';
import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {CpuInferenceHandler} from '../inference-handler';

export class CpuFlatten extends Flatten {
run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
const output = flatten(inputs[0], this.axis);
return [output];
}
}

export function flatten(x: Tensor, axis: number): Tensor {
const outputDims = ShapeUtil.flattenShape(x.dims, axis);
const output = new Tensor(outputDims, x.type);

const X = x.numberData;
const Y = output.numberData;

Y.set(X);

return output;
}
17 changes: 17 additions & 0 deletions lib/backends/webgl/ops/flatten.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {Flatten} from '../../../ops/flatten';
import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {WebGLInferenceHandler} from '../inference-handler';

import {reshape} from './reshape';

export class WebGLFlatten extends Flatten {
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
const outputDims = ShapeUtil.flattenShape(inputs[0].dims, this.axis);

return [reshape(inferenceHandler, inputs[0], outputDims)];
}
}
3 changes: 3 additions & 0 deletions lib/backends/webgl/session-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import * as binaryOps from './ops/binary-op';
import {WebGLConcat} from './ops/concat';
import {WebGLConv} from './ops/conv';
import {WebGLDropout} from './ops/dropout';
import {WebGLFlatten} from './ops/flatten';
import {WebGLGather} from './ops/gather';
import {WebGLGemm} from './ops/gemm';
import {WebGLImageScaler} from './ops/image-scaler';
Expand Down Expand Up @@ -123,6 +124,8 @@ export class WebGLSessionHandler implements SessionHandler {
return new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslEqual(), undefined, 'bool');
case 'Exp':
return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslExp());
case 'Flatten':
return new WebGLFlatten();
case 'Floor':
return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslFloor());
case 'Gather':
Expand Down
42 changes: 42 additions & 0 deletions lib/ops/flatten.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {Attribute} from '../attribute';
import {InferenceHandler} from '../backend';
import {Operator} from '../operators';
import {Tensor} from '../tensor';

export abstract class Flatten implements Operator {
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;

initialize(attributes: Attribute): void {
this.axis = attributes.getInt('axis', 1); // default axis is 1
}

checkInputs(inputs: Tensor[]): boolean {
if (!inputs || inputs.length !== 1) {
return false;
}

if (inputs[0].dims.length === 0) {
return false; // scalar tensor is not supported
}

if (this.axis < 0 || this.axis > inputs[0].dims.length) {
return false;
}

return this.checkInputTypes(inputs);
}

protected checkInputTypes(inputs: Tensor[]): boolean {
// TODO: Support string type
if (inputs[0].type === 'string') {
return false;
}

return true;
}

protected axis: number;
}
13 changes: 13 additions & 0 deletions lib/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,19 @@ export class ShapeUtil {
return size;
}

/**
* Determines the shape of output tensor y = flatten(x, axis)
* @param dims - shape of input tensor
* @param axis - flatten axis
*/
static flattenShape(dims: ReadonlyArray<number>, axis: number): ReadonlyArray<number> {
const total = dims.reduce((x, y) => x * y, 1);
const right = dims.slice(axis).reduce((x, y) => x * y, 1);
const outputDims = [total / right, right];

return outputDims;
}

/**
* Determines the shape of output tensor y = squeeze(x, axes)
* @param dims - shape of input tensor
Expand Down
10 changes: 10 additions & 0 deletions test/unittest-whitelist.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@
"test_div",
"test_dropout_default",
"test_dropout_random",
"test_flatten_axis0",
"test_flatten_axis1",
"test_flatten_axis2",
"test_flatten_axis3",
"test_flatten_default_axis",
"test_gather_0",
"test_gather_1",
"test_gemm_broadcast",
Expand Down Expand Up @@ -286,6 +291,11 @@
"test_div",
"test_dropout_default",
"test_dropout_random",
"test_flatten_axis0",
"test_flatten_axis1",
"test_flatten_axis2",
"test_flatten_axis3",
"test_flatten_default_axis",
"test_gather_0",
"test_gather_1",
"test_gemm_nobroadcast",
Expand Down

0 comments on commit 7ea55fa

Please sign in to comment.