Skip to content

Commit 0b978dc

Browse files
committed
Update coreml.js (#1203)
1 parent 060aa4b commit 0b978dc

File tree

1 file changed

+37
-24
lines changed

1 file changed

+37
-24
lines changed

source/coreml.js

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -223,11 +223,13 @@ coreml.Graph = class {
223223
this.type = context.type;
224224
this.groups = context.groups;
225225
for (const value of context.values.values()) {
226-
const name = value.name || '';
227-
const type = value.type || null;
228-
const description = value.description || null;
229-
const initializer = value.initializer || null;
230-
value.obj = new coreml.Value(name, type, description, initializer);
226+
const name = value.name;
227+
const type = value.type;
228+
const description = value.description;
229+
const initializer = value.initializer;
230+
if (!value.obj) {
231+
value.obj = new coreml.Value(name, type, description, initializer);
232+
}
231233
}
232234
this.inputs = context.inputs.map((argument) => {
233235
const values = argument.value.map((value) => value.obj);
@@ -237,14 +239,21 @@ coreml.Graph = class {
237239
const values = argument.value.map((value) => value.obj);
238240
return new coreml.Argument(argument.name, argument.visible, values);
239241
});
240-
/*
241242
for (const obj of context.nodes) {
242-
if (obj.type === 'loop') {
243-
obj.attributes.conditionNetwork = new coreml.Graph(obj.attributes.conditionNetwork);
244-
obj.attributes.bodyNetwork = new coreml.Graph(obj.attributes.bodyNetwork);
243+
const attributes = obj.attributes;
244+
switch (obj.type) {
245+
case 'loop':
246+
attributes.conditionNetwork = new coreml.Graph(attributes.conditionNetwork);
247+
attributes.bodyNetwork = new coreml.Graph(attributes.bodyNetwork);
248+
break;
249+
case 'branch':
250+
attributes.ifBranch = new coreml.Graph(attributes.ifBranch);
251+
attributes.elseBranch = new coreml.Graph(attributes.elseBranch);
252+
break;
253+
default:
254+
break;
245255
}
246256
}
247-
*/
248257
this.nodes = context.nodes.map((obj) => new coreml.Node(context, obj));
249258
}
250259
};
@@ -554,12 +563,14 @@ coreml.Context = class {
554563
}
555564

556565
network(obj) {
566+
const context = this.context();
557567
for (const layer of obj.layers) {
558568
const type = layer.layer;
559-
this.node(this.groups, type, layer.name, '', layer[type], layer.input, layer.output, layer.inputTensor, layer.outputTensor);
569+
context.node(context.groups, type, layer.name, '', layer[type], layer.input, layer.output, layer.inputTensor, layer.outputTensor);
560570
}
561-
this.updatePreprocessing('', obj.preprocessing, null);
562-
this.type = 'Neural Network';
571+
context.updatePreprocessing('', obj.preprocessing, null);
572+
context.type = 'Neural Network';
573+
return context;
563574
}
564575

565576
input(name) {
@@ -777,23 +788,25 @@ coreml.Context = class {
777788
}
778789
};
779790
if (data) {
791+
const attributes = obj.attributes;
780792
const map = weights(type, data, initializers);
781793
for (const [name, value] of Object.entries(data)) {
782794
if (!map[name]) {
783-
obj.attributes[name] = value;
795+
attributes[name] = value;
784796
}
785797
}
786-
/*
787-
if (obj.type === 'loop') {
788-
const network = (context, obj) => {
789-
context = context.context();
790-
context.network(obj);
791-
return context;
792-
};
793-
obj.attributes.bodyNetwork = network(this, obj.attributes.bodyNetwork);
794-
obj.attributes.conditionNetwork = network(this, obj.attributes.conditionNetwork);
798+
switch (obj.type) {
799+
case 'loop':
800+
attributes.bodyNetwork = this.network(attributes.bodyNetwork);
801+
attributes.conditionNetwork = this.network(attributes.conditionNetwork);
802+
break;
803+
case 'branch':
804+
attributes.ifBranch = this.network(attributes.ifBranch);
805+
attributes.elseBranch = this.network(attributes.elseBranch);
806+
break;
807+
default:
808+
break;
795809
}
796-
*/
797810
}
798811
const metadata = this.metadata.type(type);
799812
for (let i = 0; i < inputs.length;) {

0 commit comments

Comments
 (0)