Skip to content

Commit

Permalink
Update onnx.js (#1387)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Nov 9, 2024
1 parent 7b22246 commit f49812a
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 58 deletions.
83 changes: 58 additions & 25 deletions source/onnx.js
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,8 @@ onnx.Tensor = class {
const offset = parseInt(external_data.offset, 10);
const length = parseInt(external_data.length, 10);
if (Number.isInteger(offset) && Number.isInteger(length)) {
this._data = context.location(external_data.location, offset, length);
const location = context.location(external_data.location);
this._request = { location, offset, length };
this._encoding = '<';
}
}
Expand All @@ -661,6 +662,20 @@ onnx.Tensor = class {
}
}

peek() {
return !this._request;
}

async read() {
if (this._request) {
const location = this._request.location;
const offset = this._request.offset;
const length = this._request.length;
this._data = await location.read(offset, length);
delete this._request;
}
}

get name() {
return this._name;
}
Expand All @@ -686,6 +701,9 @@ onnx.Tensor = class {
}

get values() {
if (this._request) {
throw new onnx.Error('Tensor data not loaded.');
}
switch (this.type.layout) {
case 'sparse': {
return this._values;
Expand Down Expand Up @@ -980,20 +998,9 @@ onnx.Context.Model = class {
return this._graph;
}

location(name, offset, length) {
location(name) {
if (this._locations.has(name)) {
const stream = this._locations.get(name);
if (offset >= 0 && (offset + length) <= stream.length) {
try {
const position = stream.position;
stream.seek(offset);
const value = stream.stream(length);
stream.seek(position);
return value;
} catch {
// continue regardless of error
}
}
return this._locations.get(name);
}
return null;
}
Expand Down Expand Up @@ -1263,8 +1270,8 @@ onnx.Context.Graph = class {
return this._tensors.get(name);
}

location(name, offset, length) {
return this._context.location(name, offset, length);
location(name) {
return this._context.location(name);
}

group(name) {
Expand Down Expand Up @@ -1706,16 +1713,12 @@ onnx.ProtoReader = class {
}
}
}
for (const key of this.locations.keys()) {
locations.delete(key);
for (const name of this.locations.keys()) {
locations.delete(name);
}
const keys = Array.from(locations);
const promises = keys.map((location) => this.context.fetch(location));
const streams = await Promise.all(promises.map((promise) => promise.then((context) => context.stream).catch(() => null)));
for (let i = 0; i < keys.length; i++) {
if (streams[i] !== null) {
this.locations.set(keys[i], streams[i]);
}
for (const name of locations) {
const location = new onnx.Location(this.context, name);
this.locations.set(name, location);
}
}
};
Expand Down Expand Up @@ -2979,6 +2982,36 @@ onnx.DataReader = class {
}
};

onnx.Location = class {

constructor(context, name) {
this.context = context;
this.name = name;
this.content = new Map();
}

async read(offset, length) {
const key = `${offset}:${length}`;
if (this.content.has(key)) {
return this.content.get(key);
}
if (!this.promise) {
this.promise = this.context.fetch(this.name);
}
return this.promise.then((content) => {
const stream = content.stream;
const position = stream.position;
stream.seek(offset);
content = stream.stream(length);
stream.seek(position);
this.content.set(key, content);
return content;
}).catch(() => {
return null;
});
}
};

onnx.Error = class extends Error {

constructor(message) {
Expand Down
60 changes: 35 additions & 25 deletions source/view.js
Original file line number Diff line number Diff line change
Expand Up @@ -3259,24 +3259,30 @@ view.TensorView = class extends view.Expander {
content.innerHTML = `Tensor encoding '${tensor.layout}' is not implemented.`;
} else if (tensor.layout && (tensor.layout !== 'sparse' && tensor.layout !== 'sparse.coo')) {
content.innerHTML = `Tensor layout '${tensor.layout}' is not implemented.`;
} else if (tensor.empty) {
content.innerHTML = 'Tensor data is empty.';
} else if (tensor.type && tensor.type.dataType === '?') {
content.innerHTML = 'Tensor data type is not defined.';
} else if (tensor.type && !tensor.type.shape) {
content.innerHTML = 'Tensor shape is not defined.';
} else {
content.innerHTML = tensor.toString();
if (this._host.save && value.type.shape && value.type.shape.dimensions && value.type.shape.dimensions.length > 0) {
this._saveButton = this.createElement('div', 'sidebar-item-value-button');
this._saveButton.classList.add('sidebar-item-value-button-context');
this._saveButton.setAttribute('style', 'float: right;');
this._saveButton.innerHTML = '&#x1F4BE;';
this._saveButton.addEventListener('click', async () => {
await this.export();
});
content.insertBefore(this._saveButton, content.firstChild);
}
content.innerHTML = '&#x23F3';
const promise = value.peek && !value.peek() ? value.read() : Promise.resolve();
promise.then(() => {
if (tensor.empty) {
content.innerHTML = 'Tensor data is empty.';
} else {
content.innerHTML = tensor.toString();
if (this._host.save && value.type.shape && value.type.shape.dimensions && value.type.shape.dimensions.length > 0) {
this._saveButton = this.createElement('div', 'sidebar-item-value-button');
this._saveButton.classList.add('sidebar-item-value-button-context');
this._saveButton.setAttribute('style', 'float: right;');
this._saveButton.innerHTML = '&#x1F4BE;';
this._saveButton.addEventListener('click', async () => {
await this.export();
});
content.insertBefore(this._saveButton, content.firstChild);
}
}
});
}
return content;
}
Expand Down Expand Up @@ -3457,7 +3463,6 @@ view.TensorSidebar = class extends view.ObjectSidebar {
constructor(context, value) {
super(context);
this._value = value;
this._tensor = new base.Tensor(value.value[0].initializer);
}

get identifier() {
Expand Down Expand Up @@ -3515,18 +3520,23 @@ view.TensorSidebar = class extends view.ObjectSidebar {
}
// Metrics
if (value.initializer) {
if (!this._tensor.empty) {
if (!this._metrics) {
this._metrics = new metrics.Tensor(this._tensor);
}
if (this._metrics.metrics.length > 0) {
this.addHeader('Metrics');
for (const metric of this._metrics.metrics) {
const value = metric.type === 'percentage' ? `${(metric.value * 100).toFixed(1)}%` : metric.value;
this.addProperty(metric.name, [value]);
const tensor = value.initializer;
const promise = tensor.peek && !tensor.peek() ? tensor.read() : Promise.resolve();
promise.then(() => {
this._tensor = new base.Tensor(tensor);
if (!this._tensor.empty) {
if (!this._metrics) {
this._metrics = new metrics.Tensor(this._tensor);
}
if (this._metrics.metrics.length > 0) {
this.addHeader('Metrics');
for (const metric of this._metrics.metrics) {
const value = metric.type === 'percentage' ? `${(metric.value * 100).toFixed(1)}%` : metric.value;
this.addProperty(metric.name, [value]);
}
}
}
}
});
}
}

Expand Down Expand Up @@ -4449,7 +4459,7 @@ view.Formatter = class {
static tensor(value) {
const type = value.type;
if (type && type.shape && type.shape.dimensions && Array.isArray(type.shape.dimensions)) {
if (type.shape.dimensions.length === 0) {
if (type.shape.dimensions.length === 0 && (!value.peek || value.peek() === true)) {
const tensor = new base.Tensor(value);
const encoding = tensor.encoding;
if ((encoding === '<' || encoding === '>' || encoding === '|') && !tensor.empty && tensor.type.dataType !== '?') {
Expand Down
31 changes: 23 additions & 8 deletions test/worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -622,9 +622,9 @@ export class Target {
// continue
}
/* eslint-disable no-unused-expressions */
const validateGraph = (graph) => {
const validateGraph = async (graph) => {
const values = new Map();
const validateValue = (value) => {
const validateValue = async (value) => {
if (value === null) {
return;
}
Expand All @@ -643,6 +643,9 @@ export class Target {
}
if (value.initializer) {
value.initializer.type.toString();
if (value.initializer && value.initializer.peek && !value.initializer.peek()) {
await value.initializer.read();
}
const tensor = new base.Tensor(value.initializer);
if (!this.tags.has('skip-tensor-value')) {
if (tensor.encoding !== '<' && tensor.encoding !== '>' && tensor.encoding !== '|') {
Expand Down Expand Up @@ -700,14 +703,18 @@ export class Target {
input.name.toString();
input.name.length;
for (const value of input.value) {
validateValue(value);
/* eslint-disable no-await-in-loop */
await validateValue(value);
/* eslint-enable no-await-in-loop */
}
}
for (const output of signature.outputs) {
output.name.toString();
output.name.length;
for (const value of output.value) {
validateValue(value);
/* eslint-disable no-await-in-loop */
await validateValue(value);
/* eslint-enable no-await-in-loop */
}
}
}
Expand All @@ -720,7 +727,9 @@ export class Target {
throw new Error(`Invalid node type '${JSON.stringify(node.type)}'.`);
}
if (Array.isArray(type.nodes)) {
validateGraph(type);
/* eslint-disable no-await-in-loop */
await validateGraph(type);
/* eslint-enable no-await-in-loop */
}
view.Documentation.open(type);
node.name.toString();
Expand All @@ -736,7 +745,9 @@ export class Target {
const type = attribute.type;
const value = attribute.value;
if ((type === 'graph' || type === 'function') && value && Array.isArray(value.nodes)) {
validateGraph(value);
/* eslint-disable no-await-in-loop */
await validateGraph(value);
/* eslint-enable no-await-in-loop */
} else {
let text = new view.Formatter(attribute.value, attribute.type).toString();
if (text && text.length > 1000) {
Expand All @@ -753,7 +764,9 @@ export class Target {
input.name.length;
if (!input.type || input.type.endsWith('*')) {
for (const value of input.value) {
validateValue(value);
/* eslint-disable no-await-in-loop */
await validateValue(value);
/* eslint-enable no-await-in-loop */
}
if (this.tags.has('validation')) {
if (input.value.length === 1 && input.value[0].initializer) {
Expand All @@ -771,7 +784,9 @@ export class Target {
output.name.length;
if (!output.type || output.type.endsWith('*')) {
for (const value of output.value) {
validateValue(value);
/* eslint-disable no-await-in-loop */
await validateValue(value);
/* eslint-enable no-await-in-loop */
}
}
}
Expand Down

0 comments on commit f49812a

Please sign in to comment.