@@ -223,11 +223,13 @@ coreml.Graph = class {
223
223
this . type = context . type ;
224
224
this . groups = context . groups ;
225
225
for ( const value of context . values . values ( ) ) {
226
- const name = value . name || '' ;
227
- const type = value . type || null ;
228
- const description = value . description || null ;
229
- const initializer = value . initializer || null ;
230
- value . obj = new coreml . Value ( name , type , description , initializer ) ;
226
+ const name = value . name ;
227
+ const type = value . type ;
228
+ const description = value . description ;
229
+ const initializer = value . initializer ;
230
+ if ( ! value . obj ) {
231
+ value . obj = new coreml . Value ( name , type , description , initializer ) ;
232
+ }
231
233
}
232
234
this . inputs = context . inputs . map ( ( argument ) => {
233
235
const values = argument . value . map ( ( value ) => value . obj ) ;
@@ -237,14 +239,21 @@ coreml.Graph = class {
237
239
const values = argument . value . map ( ( value ) => value . obj ) ;
238
240
return new coreml . Argument ( argument . name , argument . visible , values ) ;
239
241
} ) ;
240
- /*
241
242
for ( const obj of context . nodes ) {
242
- if (obj.type === 'loop') {
243
- obj.attributes.conditionNetwork = new coreml.Graph(obj.attributes.conditionNetwork);
244
- obj.attributes.bodyNetwork = new coreml.Graph(obj.attributes.bodyNetwork);
243
+ const attributes = obj . attributes ;
244
+ switch ( obj . type ) {
245
+ case 'loop' :
246
+ attributes . conditionNetwork = new coreml . Graph ( attributes . conditionNetwork ) ;
247
+ attributes . bodyNetwork = new coreml . Graph ( attributes . bodyNetwork ) ;
248
+ break ;
249
+ case 'branch' :
250
+ attributes . ifBranch = new coreml . Graph ( attributes . ifBranch ) ;
251
+ attributes . elseBranch = new coreml . Graph ( attributes . elseBranch ) ;
252
+ break ;
253
+ default :
254
+ break ;
245
255
}
246
256
}
247
- */
248
257
this . nodes = context . nodes . map ( ( obj ) => new coreml . Node ( context , obj ) ) ;
249
258
}
250
259
} ;
@@ -554,12 +563,14 @@ coreml.Context = class {
554
563
}
555
564
556
565
network ( obj ) {
566
+ const context = this . context ( ) ;
557
567
for ( const layer of obj . layers ) {
558
568
const type = layer . layer ;
559
- this . node ( this . groups , type , layer . name , '' , layer [ type ] , layer . input , layer . output , layer . inputTensor , layer . outputTensor ) ;
569
+ context . node ( context . groups , type , layer . name , '' , layer [ type ] , layer . input , layer . output , layer . inputTensor , layer . outputTensor ) ;
560
570
}
561
- this . updatePreprocessing ( '' , obj . preprocessing , null ) ;
562
- this . type = 'Neural Network' ;
571
+ context . updatePreprocessing ( '' , obj . preprocessing , null ) ;
572
+ context . type = 'Neural Network' ;
573
+ return context ;
563
574
}
564
575
565
576
input ( name ) {
@@ -777,23 +788,25 @@ coreml.Context = class {
777
788
}
778
789
} ;
779
790
if ( data ) {
791
+ const attributes = obj . attributes ;
780
792
const map = weights ( type , data , initializers ) ;
781
793
for ( const [ name , value ] of Object . entries ( data ) ) {
782
794
if ( ! map [ name ] ) {
783
- obj . attributes [ name ] = value ;
795
+ attributes [ name ] = value ;
784
796
}
785
797
}
786
- /*
787
- if (obj.type === 'loop') {
788
- const network = (context, obj) => {
789
- context = context.context();
790
- context.network(obj);
791
- return context;
792
- };
793
- obj.attributes.bodyNetwork = network(this, obj.attributes.bodyNetwork);
794
- obj.attributes.conditionNetwork = network(this, obj.attributes.conditionNetwork);
798
+ switch ( obj . type ) {
799
+ case 'loop' :
800
+ attributes . bodyNetwork = this . network ( attributes . bodyNetwork ) ;
801
+ attributes . conditionNetwork = this . network ( attributes . conditionNetwork ) ;
802
+ break ;
803
+ case 'branch' :
804
+ attributes . ifBranch = this . network ( attributes . ifBranch ) ;
805
+ attributes . elseBranch = this . network ( attributes . elseBranch ) ;
806
+ break ;
807
+ default :
808
+ break ;
795
809
}
796
- */
797
810
}
798
811
const metadata = this . metadata . type ( type ) ;
799
812
for ( let i = 0 ; i < inputs . length ; ) {
0 commit comments