@@ -16,9 +16,16 @@ import {Artifact, RunData, TextureData, TextureLayout, WebGLOperator} from './ty
1616import { getPackedShape } from './utils' ;
1717
1818export class WebGLInferenceHandler implements InferenceHandler {
/** Packed texture data of non-initializer tensors, keyed by tensor data id. */
private packedTextureDataCache = new Map<Tensor.Id, TextureData>();
/** Unpacked texture data of non-initializer tensors, keyed by tensor data id. */
private unpackedTextureDataCache = new Map<Tensor.Id, TextureData>();
/** Data id of a packed input tensor -> data id of the tensor produced by unpack() for it. */
private pack2unpackMap = new Map<Tensor.Id, Tensor.Id>();
/** Data id of an unpacked input tensor -> data id of the tensor produced by pack() for it. */
private unpack2packMap = new Map<Tensor.Id, Tensor.Id>();

/**
 * @param session the session handler this inference handler works on behalf of
 */
constructor(public session: WebGLSessionHandler) {
  // All caches start empty; they are set up by the field initializers above.
}
2330
2431 run ( op : WebGLOperator , inputs : Tensor [ ] ) : Tensor [ ] {
@@ -33,36 +40,18 @@ export class WebGLInferenceHandler implements InferenceHandler {
3340 return [ runData . outputTextureData . tensor ] ;
3441 }
3542
36- /**
37- * Check the runData's input texture mode with the program's artifact.
38- * If the artifact expects a packed input, while the RunData's input
39- * is unpacked, perform a pack operation on this input to align the
40- * texture mode with artifact. Similar on unpacked input.
41- */
/**
 * Aligns the pack mode of every input texture in `runData` with what the program expects.
 * A packed input given to a program that does not expect packed inputs is replaced by the
 * result of unpack(); an unpacked input given to a program that expects packed inputs is
 * replaced by the result of pack(). The entries of runData.inputTextureDatas are overwritten
 * rather than mutated, so the original TextureData objects stay untouched.
 * @param artifact the program artifact whose programInfo.expectPackedInputs is consulted
 * @param runData the run data whose input textures are checked
 */
checkAndUpdateTextureForm(artifact: Artifact, runData: RunData) {
  runData.inputTextureDatas.forEach((textureData, index, textureDatas) => {
    const expectsPacked = artifact.programInfo.expectPackedInputs;
    if (textureData.isPacked) {
      if (!expectsPacked) {
        // packed texture, unpacked program
        textureDatas[index] = this.unpack(textureData);
      }
    } else if (expectsPacked) {
      // unpacked texture, packed program
      textureDatas[index] = this.pack(textureData);
    }
  });
}
6354 runProgram ( artifact : Artifact , runData : RunData ) {
64- // if the runData has different expected texture pack/unpack mode, process pack/unpack
65- // operation on the texture before executing the kernel.
6655 this . checkAndUpdateTextureForm ( artifact , runData ) ;
6756
6857 // output should match
@@ -84,15 +73,28 @@ export class WebGLInferenceHandler implements InferenceHandler {
8473 * Creates a texture data object associated with the given tensor.
8574 * @param tensor the tensor with data to upload
8675 */
87- getOrCreateTextureData ( tensor : Tensor , layout ?: TextureLayout ) {
88- let td = this . getTextureData ( tensor . dataId ) ;
76+ getOrCreateTextureData ( tensor : Tensor , layout ?: TextureLayout , isPacked = false ) {
77+ let td = this . getTextureData ( tensor . dataId , isPacked ) ;
8978 if ( ! td ) {
9079 Logger . verbose ( 'InferenceHandler' , `Creating new TextureData for dims: [${ tensor . dims } ]` ) ;
9180 if ( ! layout ) {
9281 layout = this . createTextureLayoutFromShape ( tensor . dims . slice ( ) ) ;
9382 }
94- // graph inputs or initializers
95- td = this . createTextureData ( layout , tensor . type , tensor . numberData , tensor , Encoder . Usage . UploadOnly ) ;
83+ // if we don't find the texture data with the requested pack mode in the cache, try the other
84+ // pack mode to see if the tensor is cached in that form. If that succeeds, we can return this
85+ // texture data and later apply a pack/unpack op on it; there is no need to create a new one here.
86+ td = this . getTextureData ( tensor . dataId , ! isPacked ) ;
87+ if ( ! td ) {
88+ if ( isPacked ) {
89+ const unpackedTextureLayout = this . getOrCreateTextureLayout ( tensor , 1 , false , [ ] , true ) ;
90+ const unpackedTextureData = this . createTextureData (
91+ unpackedTextureLayout , tensor . type , tensor . numberData , tensor , Encoder . Usage . UploadOnly ) ;
92+ td = this . pack ( unpackedTextureData ) ;
93+ } else {
94+ td = this . createTextureData (
95+ layout , tensor . type , tensor . numberData , tensor , Encoder . Usage . UploadOnly , isPacked ) ;
96+ }
97+ }
9698 } else {
9799 Logger . verbose ( 'InferenceHandler' , `Retrieving TextureData from cache: [${ tensor . dims } ]` ) ;
98100 }
@@ -104,7 +106,7 @@ export class WebGLInferenceHandler implements InferenceHandler {
104106 * Usage = Encoder.Usage.Default.
105107 * @param dataType the tensor data type
106108 */
createTextureDataFromLayout(layout: TextureLayout, dataType: Tensor.DataType, isPacked = false): TextureData {
  // No data, tensor or usage is supplied by this overload, so those arguments stay undefined.
  // isPacked is forwarded instead of being dropped, matching createTextureDataFromLayoutBindTensor.
  return this.createTextureData(layout, dataType, undefined, undefined, undefined, isPacked);
}
110112
@@ -118,13 +120,14 @@ export class WebGLInferenceHandler implements InferenceHandler {
118120 * @param tensor the tensor to bind. tensor's data is ignored.
119121 */
createTextureDataFromLayoutBindTensor(
    layout: TextureLayout, dataType: Tensor.DataType, data: Tensor.NumberType, tensor: Tensor,
    isPacked = false): TextureData {
  // Textures created through this entry point always use the upload-only encoder usage.
  const usage = Encoder.Usage.UploadOnly;
  return this.createTextureData(layout, dataType, data, tensor, usage, isPacked);
}
124127
125128 private createTextureData (
126129 layout : TextureLayout , dataType : Tensor . DataType , data ?: Tensor . NumberType , tensor ?: Tensor ,
127- usage ?: Encoder . Usage ) : TextureData {
130+ usage ?: Encoder . Usage , isPacked = false ) : TextureData {
128131 Logger . verbose ( 'InferenceHandler' , `Creating TextureData: layout:[${ JSON . stringify ( layout ) } ]` ) ;
129132 const texture = this . session . textureManager . createTextureFromLayout ( dataType , layout , data , usage ) ;
130133 return this . createTextureDataFromTexture ( layout , dataType , texture , tensor ) ;
@@ -137,8 +140,9 @@ export class WebGLInferenceHandler implements InferenceHandler {
137140 * @param texture the WebGLTexture object to share
138141 * @param tensorId the tensor ID of the shared tensor data
139142 */
createSharedTextureData(
    layout: TextureLayout, dataType: Tensor.DataType, texture: WebGLTexture, tensorId?: Tensor.Id,
    isPacked = false): TextureData {
  // No tensor object is handed over; only the (optional) tensor id is forwarded.
  // NOTE(review): isPacked is accepted but not forwarded by this method - confirm that is intended.
  const sharedTextureData = this.createTextureDataFromTexture(layout, dataType, texture, undefined, tensorId);
  return sharedTextureData;
}
144148
@@ -155,29 +159,32 @@ export class WebGLInferenceHandler implements InferenceHandler {
155159 undefined , undefined , tensorId ) ,
156160 texture
157161 } ;
158- this . setTextureData ( textureData . tensor . dataId , textureData ) ;
162+ this . setTextureData ( textureData . tensor . dataId , textureData , layout . isPacked ) ;
159163 return textureData ;
160164 }
161165
/**
 * Looks up the texture data registered for a tensor id in the requested pack mode.
 * Initializer tensors are looked up through the session; all other tensors are looked up in
 * this handler's packed or unpacked cache.
 * @param tensorId the tensor data id to look up
 * @param isPacked whether the packed (true) or unpacked (false) form is wanted
 * @returns the texture data, or undefined when nothing is stored for that id and pack mode
 */
getTextureData(tensorId: Tensor.Id, isPacked = false): TextureData|undefined {
  if (this.session.isInitializer(tensorId)) {
    return this.session.getTextureData(tensorId, isPacked);
  }
  const cache = isPacked ? this.packedTextureDataCache : this.unpackedTextureDataCache;
  return cache.get(tensorId);
}
/**
 * Registers texture data for a tensor id under the given pack mode.
 * Initializer tensors are stored through the session; all other tensors go into this handler's
 * packed or unpacked cache.
 * @param tensorId the tensor data id to register
 * @param td the texture data to store
 * @param isPacked whether td is stored as the packed (true) or unpacked (false) form
 */
setTextureData(tensorId: Tensor.Id, td: TextureData, isPacked = false): void {
  if (this.session.isInitializer(tensorId)) {
    this.session.setTextureData(tensorId, td, isPacked);
    return;
  }
  const cache = isPacked ? this.packedTextureDataCache : this.unpackedTextureDataCache;
  cache.set(tensorId, td);
}
173-
/**
 * Tells whether getTextureData finds an entry for the tensor's data id in the given pack mode.
 * @param tensor the tensor whose dataId is looked up
 * @param isPacked the pack mode to look for
 */
isTextureLayoutCached(tensor: Tensor, isPacked = false): boolean {
  const cached = this.getTextureData(tensor.dataId, isPacked);
  return Boolean(cached);
}
174181 /**
175182 * Create a TextureLayout object from a tensor. If a related texture data is found, returns the cached texture layout.
176183 */
177184 getOrCreateTextureLayout (
178185 tensor : Tensor , channels : 1 | 4 = 1 , isPacked = false , unpackedShape ?: ReadonlyArray < number > ,
179186 reverseWH = false ) : TextureLayout {
180- const td = this . getTextureData ( tensor . dataId ) ;
187+ const td = this . getTextureData ( tensor . dataId , isPacked ) ;
181188 if ( td ) {
182189 return td ;
183190 }
@@ -229,14 +236,17 @@ export class WebGLInferenceHandler implements InferenceHandler {
229236 isPacked,
230237 shape : inferredDims ,
231238 strides : ShapeUtil . computeStrides ( inferredDims ) ,
232- unpackedShape
239+ unpackedShape,
240+ reversedWH : ( prefs && prefs . reverseWH )
233241 } ;
234242 }
235243
/**
 * Releases the textures held by this handler and resets all of its caches.
 * Active textures are cleared on the session's texture manager first, then every texture data
 * entry of the packed and unpacked caches is released and both caches are emptied.
 */
dispose(): void {
  this.session.textureManager.clearActiveTextures();
  this.packedTextureDataCache.forEach(td => this.session.textureManager.releaseTexture(td));
  this.packedTextureDataCache = new Map();
  this.unpackedTextureDataCache.forEach(td => this.session.textureManager.releaseTexture(td));
  this.unpackedTextureDataCache = new Map();
  // The id maps point into the caches emptied above. Reset them as well, otherwise pack()/unpack()
  // would find a stale id and return a missing cache entry behind a non-null assertion.
  this.pack2unpackMap = new Map();
  this.unpack2packMap = new Map();
}
241251
242252 readTexture ( textureData : TextureData ) : Tensor . NumberType {
@@ -252,6 +262,10 @@ export class WebGLInferenceHandler implements InferenceHandler {
252262 }
253263
254264 pack ( input : TextureData ) : TextureData {
265+ const cachedId = this . unpack2packMap . get ( input . tensor . dataId ) ;
266+ if ( cachedId ) {
267+ return this . packedTextureDataCache . get ( cachedId ) ! ;
268+ }
255269 const key = `${ input . shape } ` ;
256270 let op = this . session . packOpCache . get ( key ) ;
257271 if ( ! op ) {
@@ -266,10 +280,15 @@ export class WebGLInferenceHandler implements InferenceHandler {
266280 }
267281 const runData = op . createRunData ( this , artifact . programInfo , [ input . tensor ] ) ;
268282 this . runProgram ( artifact , runData ) ;
283+ this . unpack2packMap . set ( input . tensor . dataId , runData . outputTextureData . tensor . dataId ) ;
269284 return runData . outputTextureData ;
270285 }
271286
272287 unpack ( input : TextureData ) : TextureData {
288+ const cachedId = this . pack2unpackMap . get ( input . tensor . dataId ) ;
289+ if ( cachedId ) {
290+ return this . unpackedTextureDataCache . get ( cachedId ) ! ;
291+ }
273292 // For unpacked kernel, cache it by using input's unpackedShape as cache key.
274293 // Note that we need to use input.unpackedShape instead of input.shape here,
275294 // as the shape infers the packed texture shape. Different unpackedShape can have the
@@ -290,6 +309,7 @@ export class WebGLInferenceHandler implements InferenceHandler {
290309 }
291310 const runData = op . createRunData ( this , artifact . programInfo , [ input . tensor ] ) ;
292311 this . runProgram ( artifact , runData ) ;
312+ this . pack2unpackMap . set ( input . tensor . dataId , runData . outputTextureData . tensor . dataId ) ;
293313 return runData . outputTextureData ;
294314 }
295315}
0 commit comments