Add "load_by_name" API at wasi-nn #4267
@@ -30,7 +30,7 @@ load(graph_builder_array *builder, graph_encoding encoding,
     __attribute__((import_module("wasi_nn")));

 wasi_nn_error
-load_by_name(const char *name, graph *g)
+load_by_name(char *name, uint32_t name_len, graph *g)
is this a bug fix?
wasi_nn.h is a header for WebAssembly applications written in the C language. Is there a specific reason that we need to change it?
Yes, it's a bug fix.
Since we have two sets of APIs for historical reasons, we might remove one in another PR. For now, let's ensure both are functional.
- I suggest we use WASM_ENABLE_WASI_EPHEMERAL_NN on the wasm side.
- With this flag, we declare two sets of APIs in wasi_nn.h. Please align with the content of native_symbols_wasi_nn in wasi_nn.c.
@lum1n0us Where can I find the function prototypes for wasi_ephemeral_nn?
For functions like get_output, the signatures are different, so simply replacing wasi_nn with wasi_ephemeral_nn mechanically doesn't seem to work.
wasm-micro-runtime/core/iwasm/libraries/wasi-nn/src/wasi_nn.c
Lines 690 to 704 in c018b8a
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
    REG_NATIVE_FUNC(load, "(*iii*)i"),
    REG_NATIVE_FUNC(load_by_name, "(*i*)i"),
    REG_NATIVE_FUNC(load_by_name_with_config, "(*i*i*)i"),
    REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
    REG_NATIVE_FUNC(set_input, "(ii*)i"),
    REG_NATIVE_FUNC(compute, "(i)i"),
    REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
    REG_NATIVE_FUNC(load, "(*ii*)i"),
    REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
    REG_NATIVE_FUNC(set_input, "(ii*)i"),
    REG_NATIVE_FUNC(compute, "(i)i"),
    REG_NATIVE_FUNC(get_output, "(ii**)i"),
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
Even when referring to the wasi-nn specification, the signatures declared there don’t appear to match what is used for wasi_ephemeral_nn
https://github.com/WebAssembly/wasi-nn/blob/71320d95b8c6d43f9af7f44e18b1839db85d89b4/wasi-nn.witx#L59-L86
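For readers comparing the two registration tables, a minimal sketch of how the signature strings can be read may help. It assumes the character meanings described in WAMR's native API documentation as I understand it ('i' for i32, '*' for a pointer into wasm memory, and so on); count_sig_params is a hypothetical helper written for illustration and is not part of WAMR.

```c
#include <assert.h>
#include <stddef.h>

/*
 * Count the parameters in a WAMR-style signature string such as "(ii*i*)i".
 * Returns -1 when the string is malformed.
 * count_sig_params is a hypothetical helper, not part of WAMR.
 */
static int
count_sig_params(const char *sig)
{
    int count = 0;

    assert(sig != NULL);

    if (*sig != '(')
        return -1;

    for (sig++; *sig != ')'; sig++) {
        switch (*sig) {
            case 'i': /* i32 */
            case 'I': /* i64 */
            case 'f': /* f32 */
            case 'F': /* f64 */
            case '*': /* pointer into wasm linear memory */
            case '~': /* byte length of the preceding '*' buffer */
            case '$': /* NUL-terminated string */
                count++;
                break;
            default: /* includes '\0', i.e. a missing ')' */
                return -1;
        }
    }
    return count;
}
```

Under these assumptions, "(ii*i*)i" describes five parameters and "(ii**)i" describes four, so the two get_output registrations cannot share one C prototype.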
@HongxiaWangSSSS I suggest we use WASM_ENABLE_WASI_EPHEMERAL_NN on the wasm side, so that the header declares two sets of APIs, one for wasi_ephemeral_nn and one for wasi_nn. wasi_nn will be deprecated in the future.
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
load(graph_builder_array *builder, graph_encoding encoding,
execution_target target, graph *g)
__attribute__((import_module("wasi_ephemeral_nn")));
wasi_nn_error
load_by_name(const char *name, uint32_t len, graph *g)
__attribute__((import_module("wasi_ephemeral_nn")));
wasi_nn_error
load_by_name_with_config(const char *name, uint32_t name_len, void *config,
                         uint32_t config_len, graph *g)
    __attribute__((import_module("wasi_ephemeral_nn")));
wasi_nn_error
init_execution_context(graph g, graph_execution_context *exec_ctx)
    __attribute__((import_module("wasi_ephemeral_nn")));
wasi_nn_error
set_input(graph_execution_context ctx, uint32_t index, tensor *tensor)
    __attribute__((import_module("wasi_ephemeral_nn")));
wasi_nn_error
compute(graph_execution_context ctx)
    __attribute__((import_module("wasi_ephemeral_nn")));
wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t *output_tensor_size)
    __attribute__((import_module("wasi_ephemeral_nn")));
#else
wasi_nn_error
load(graph_builder_array *builder, graph_encoding encoding,
execution_target target, graph *g)
__attribute__((import_module("wasi_nn")));
wasi_nn_error
init_execution_context(graph g, graph_execution_context *ctx)
__attribute__((import_module("wasi_nn")));
wasi_nn_error
set_input(graph_execution_context ctx, uint32_t index, tensor *tensor)
__attribute__((import_module("wasi_nn")));
wasi_nn_error
compute(graph_execution_context ctx) __attribute__((import_module("wasi_nn")));
wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
tensor_data output_tensor, uint32_t *output_tensor_size)
__attribute__((import_module("wasi_nn")));
#endif
Is there a mismatch in the function signature?
REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
and
wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t *output_tensor_size)
    __attribute__((import_module("wasi_ephemeral_nn")));
Yes. There should be two versions, one for wasi_ephemeral_nn and another for wasi_nn. Please refer to:
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_wasm *builder,
uint32_t builder_wasm_size, graph_encoding encoding,
execution_target target, graph *g)
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
wasi_nn_error
wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
graph_encoding encoding, execution_target target, graph *g)
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
uint32_t index, tensor_data output_tensor,
uint32_t output_tensor_len, uint32_t *output_tensor_size)
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
wasi_nn_error
wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
uint32_t index, tensor_data output_tensor,
uint32_t *output_tensor_size)
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
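On the application side, the difference between the two get_output variants can be hidden behind one wrapper. The following is only a sketch: the types are stand-ins for those in wasi_nn_types.h, both get_output bodies are stubs that replace the real imports, app_get_output is a hypothetical name, and the legacy variant is assumed to read the buffer capacity from *output_tensor_size on entry.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

#ifndef WASM_ENABLE_WASI_EPHEMERAL_NN
#define WASM_ENABLE_WASI_EPHEMERAL_NN 1
#endif

/* Stand-ins for the real definitions in wasi_nn_types.h. */
typedef uint32_t graph_execution_context;
typedef uint8_t *tensor_data;
typedef enum { success = 0, invalid_argument } wasi_nn_error;

#define STUB_OUTPUT_SIZE 4u /* the stubs always "produce" 4 bytes */

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
/* Stub with the five-parameter shape registered as "(ii*i*)i". */
static wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index, tensor_data out,
           uint32_t out_len, uint32_t *out_size)
{
    (void)ctx;
    (void)index;
    if (out_len < STUB_OUTPUT_SIZE)
        return invalid_argument;
    memset(out, 0xAB, STUB_OUTPUT_SIZE);
    *out_size = STUB_OUTPUT_SIZE;
    return success;
}
#else
/* Stub with the four-parameter shape registered as "(ii**)i". */
static wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index, tensor_data out,
           uint32_t *out_size)
{
    (void)ctx;
    (void)index;
    if (*out_size < STUB_OUTPUT_SIZE) /* capacity passed in (assumption) */
        return invalid_argument;
    memset(out, 0xAB, STUB_OUTPUT_SIZE);
    *out_size = STUB_OUTPUT_SIZE;
    return success;
}
#endif

/* Hypothetical wrapper: one call shape for the application, either import. */
static wasi_nn_error
app_get_output(graph_execution_context ctx, uint32_t index, tensor_data buf,
               uint32_t buf_len, uint32_t *out_size)
{
    assert(buf != NULL && out_size != NULL);
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
    return get_output(ctx, index, buf, buf_len, out_size);
#else
    *out_size = buf_len;
    return get_output(ctx, index, buf, out_size);
#endif
}
```

With a wrapper of this kind, application code always passes the buffer length explicitly, and only the header decides which import module and argument list are used.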
If so, we need to declare it in wasi_nn.h:
wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
tensor_data output_tensor, uint32_t output_tensor_len, uint32_t *output_tensor_size) __attribute__((import_module("wasi_ephemeral_nn")));
But the backend definition doesn't look like it matches (signature might be ok)
wasm-micro-runtime/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp
Lines 366 to 368 in c018b8a
__attribute__((visibility("default"))) wasi_nn_error
get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t *output_tensor_size)
Yes. I am suggesting this #4267 (comment) as the new content of wasi_nn.h. Plus, #4267 (comment).
@@ -697,6 +697,7 @@ static NativeSymbol native_symbols_wasi_nn[] = {
     REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
 #else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
     REG_NATIVE_FUNC(load, "(*ii*)i"),
+    REG_NATIVE_FUNC(load_by_name, "(*i*)i"),
Why not use -DWASM_ENABLE_WASI_EPHEMERAL_NN=1?
At first, I just built with -DWASM_ENABLE_WASI_NN=1.
If you use -DWAMR_BUILD_WASI_EPHEMERAL_NN=1 during compilation, you will be able to use the set of APIs, including load_by_name(). There is no need to change this line.
If -DWAMR_BUILD_WASI_EPHEMERAL_NN=1 must be added, I think there is no need to add this line.
Currently, it is possible to use the default wasi-nn instead of wasi_ephemeral_nn.
Yes. Please do it.
So, does that mean that the ephemeral version is meant to be compatible with Rust (especially WasmEdge), whereas the non-ephemeral one doesn't need to be?
If that's the case, wouldn't it make sense to add load_by_name to the wasi_nn.h header, which is a straightforward C interpretation of the witx specification?
Yes. load_by_name() needs a new signature.
Eventually, load_by_name() and wasm_load_by_name() should only exist when WASM_ENABLE_WASI_EPHEMERAL_NN is set to 1. If you build the wasm module with the flag WASM_ENABLE_WASI_EPHEMERAL_NN=1, you will get a module that imports from wasi_ephemeral_nn, which offers better performance.
#4267 (comment) said this is for the C language, so why not keep it for non-wasi_ephemeral_nn to provide the same performance?
In my mind, wasi_ephemeral_nn is legacy and should be deprecated. If there are C APIs, they should follow the Rust API's design to avoid unnecessary changes for the runtime.
@@ -85,14 +85,11 @@ is_valid_graph(TFLiteContext *tfl_ctx, graph g)
         NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST);
         return runtime_error;
     }
-    if (tfl_ctx->models[g].model_pointer == NULL) {
+    if (tfl_ctx->models[g].model_pointer == NULL
+        && tfl_ctx->models[g].model == NULL) {
The original version can output different information based on various invalid argument cases. Is there a specific reason we need to merge them?
When load_by_name() is called, there is no need to save the tflite buffer to model_pointer.
https://github.com/HongxiaWangSSSS/wasm-micro-runtime/blob/c5414fd28baf973e3c95db1318de4d26f88007d3/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp#L141C49-L141C68
I'm not sure why model_pointer is not freed after the operation below in load():
https://github.com/HongxiaWangSSSS/wasm-micro-runtime/blob/c5414fd28baf973e3c95db1318de4d26f88007d3/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp#L151
It is required to validate TFLiteContext.models[g] for both cases, using load() and load_by_name(). It will not be acceptable if the change disables one of these cases.
Whether it is load or load_by_name, the check of models[g].model_pointer does not seem to be necessary; making sure that models[g].model is not NULL may be enough.
Do you have any idea?
does not seem to be necessary.
Why is that?
After this operation, the content of models[g].model_pointer has been saved in models[g].model.
/// Builds a model based on a pre-loaded flatbuffer.
/// Caller retains ownership of the buffer and should keep it alive until
/// the returned object is destroyed. Caller also retains ownership of
/// `error_reporter` and must ensure its lifetime is longer than the
/// FlatBufferModelBase instance.
/// Returns a nullptr in case of failure.
/// NOTE: this does NOT validate the buffer so it should NOT be called on
/// invalid/untrusted input. Use VerifyAndBuildFromBuffer in that case
static std::unique_ptr<T> BuildFromBuffer(
const char* caller_owned_buffer, size_t buffer_size,
ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
error_reporter = ValidateErrorReporter(error_reporter);
std::unique_ptr<Allocation> allocation(
new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
return BuildFromAllocation(std::move(allocation), error_reporter);
}
If I understand correctly, model_pointer acts as a pre-allocated buffer, and its ownership is still held by the caller, in our case, tfl_ctx. Meanwhile, model holds the ownership of the result from tflite::FlatBufferModel::BuildFromBuffer(). Therefore, both are required.
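One possible way to reconcile the two positions in this thread is to keep the separate diagnostics while allowing a graph loaded by name to have no buffer. The sketch below is an assumption-laden illustration in plain C, not the code in the PR: model_slot, tfl_context, check_graph, the loaded_by_name flag, and the value of MAX_GRAPHS_PER_INST are all hypothetical stand-ins for the real TFLiteContext.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

#define MAX_GRAPHS_PER_INST 10 /* stand-in value for illustration */

typedef struct {
    void *model_pointer; /* caller-owned flatbuffer copy, set by load() only */
    void *model;         /* built model object, set by both load paths */
    bool loaded_by_name; /* hypothetical flag, not present in the PR */
} model_slot;

typedef struct {
    model_slot models[MAX_GRAPHS_PER_INST];
} tfl_context;

typedef enum {
    GRAPH_OK = 0,
    GRAPH_BAD_INDEX,
    GRAPH_NO_BUFFER,
    GRAPH_NO_MODEL
} graph_check;

/* Validate slot g, reporting each failure case with its own message. */
static graph_check
check_graph(const tfl_context *ctx, uint32_t g)
{
    const model_slot *slot;

    assert(ctx != NULL);

    if (g >= MAX_GRAPHS_PER_INST) {
        fprintf(stderr, "Invalid graph: %u >= %d.\n", (unsigned)g,
                MAX_GRAPHS_PER_INST);
        return GRAPH_BAD_INDEX;
    }

    slot = &ctx->models[g];

    /* The buffer is only required when the model was built from it. */
    if (!slot->loaded_by_name && slot->model_pointer == NULL) {
        fprintf(stderr, "Context (model buffer) non-initialized.\n");
        return GRAPH_NO_BUFFER;
    }
    if (slot->model == NULL) {
        fprintf(stderr, "Context (tflite model) non-initialized.\n");
        return GRAPH_NO_MODEL;
    }
    return GRAPH_OK;
}
```

The design choice here is that the buffer check is skipped only for slots explicitly marked as loaded by name, so the load() path keeps both of its original checks.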
@@ -58,7 +58,7 @@ wasm_load(char *model_name, graph *g, execution_target target)
 wasi_nn_error
 wasm_load_by_name(const char *model_name, graph *g)
 {
-    wasi_nn_error res = load_by_name(model_name, g);
+    wasi_nn_error res = load_by_name(model_name, strlen(model_name), g);
It would be better as:
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
wasm_load_by_name(const char *model_name, graph *g)
{
wasi_nn_error res = load_by_name(model_name, strlen(model_name), g);
return res;
}
#endif
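To show what the wrapper passes across the boundary, here is a self-contained sketch in which load_by_name is a local stub standing in for the real import. The types are stand-ins for wasi_nn_types.h, last_name_len exists only so the example can be checked, and the explicit range check on the length is an addition that is not in the suggestion above.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Stand-ins for the real definitions in wasi_nn_types.h. */
typedef uint32_t graph;
typedef enum { success = 0, invalid_argument } wasi_nn_error;

/* Records the length the stub received, for illustration only. */
static uint32_t last_name_len;

/* Stub replacing the imported load_by_name(); it loads nothing. */
static wasi_nn_error
load_by_name(const char *name, uint32_t name_len, graph *g)
{
    if (name == NULL || name_len == 0 || g == NULL)
        return invalid_argument;
    last_name_len = name_len;
    *g = 0; /* pretend the first graph slot was assigned */
    return success;
}

/* Wrapper: derive the length from the C string and pass it explicitly. */
static wasi_nn_error
wasm_load_by_name(const char *model_name, graph *g)
{
    size_t len;

    assert(model_name != NULL);

    len = strlen(model_name); /* length excludes the terminating NUL */
    if (len > UINT32_MAX)
        return invalid_argument;
    return load_by_name(model_name, (uint32_t)len, g);
}
```

The length passed this way does not count the terminating NUL, so a native side that needs a C string has to account for that itself.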
@@ -15,6 +15,44 @@
 #include <stdint.h>
 #include "wasi_nn_types.h"

+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
Hi @ayakoakasaka and @lum1n0us,
For wasi_nn and wasi_ephemeral_nn, different APIs need to be defined in the header file.
However, when using wasi_ephemeral_nn, I currently get an error when calling set_input: the content is inconsistent when passed from wasm to native.
So can we divide this into two PRs: first support load_by_name in wasi-nn, and then implement support for wasi_ephemeral_nn?
Sure. As a workaround.
When using WASI-NN, we want to avoid copying the AI model from the host into WASM memory by using the "load_by_name" API.
We also want to use it to improve performance while keeping safety. WASI-NN also supports this method on different backends.
I tested with three tflite models (x86_64, Ubuntu 22.04).
Both coco_ssd_mobilenet_v1 and coco_ssd_mobilenet_v3 are for detection; their file sizes and input tensor sizes differ.
mobilenet_v2 is for classification and is larger.
The time consumed does not grow linearly as the file size increases.
In most cases, load_by_name is faster, both for loading the tflite model alone and for the entire inference process (load + set input + compute + get output).