Add "load_by_name" API at wasi-nn #4267


Conversation

HongxiaWangSSSS (Contributor)

When using WASI-NN, we want to reduce copying of the AI model from the host into the Wasm memory by using the "load_by_name" API. We also want to use it to improve performance while keeping safety; WASI-NN supports this method on different backends.

I tested with 3 TFLite models (x86_64, Ubuntu 22.04).

[screenshot: load / inference timing results for the three models]

Both coco_ssd_mobilenet_v1 and coco_ssd_mobilenet_v3 are detection models; their file sizes and input tensor sizes differ.
mobilenet_v2 is a classification model and its file is larger.

The time consumed does not grow linearly with the file size.
In most cases, load_by_name is faster, both for loading the TFLite model alone and for the entire inference process (load + set input + compute + get output).
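
For reference, a minimal usage sketch (not part of the PR diff) of how a guest application could call the new API, assuming the load_by_name(name, name_len, g) signature proposed here; the model name and helper function are illustrative:

/* Illustrative sketch: the model name must be resolvable by the
 * host-side backend. */
#include <stdint.h>
#include <string.h>
#include "wasi_nn.h"

static wasi_nn_error
demo_load_by_name(graph *g)
{
    const char *model_name = "coco_ssd_mobilenet_v1.tflite"; /* hypothetical path */
    /* The host resolves the name and loads the model itself, so the model
     * bytes are not copied through the Wasm linear memory as with load(). */
    return load_by_name((char *)model_name, (uint32_t)strlen(model_name), g);
}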

@HongxiaWangSSSS marked this pull request as ready for review May 12, 2025 05:24
@@ -30,7 +30,7 @@ load(graph_builder_array *builder, graph_encoding encoding,
     __attribute__((import_module("wasi_nn")));
 
 wasi_nn_error
-load_by_name(const char *name, graph *g)
+load_by_name(char *name, uint32_t name_len, graph *g)
Collaborator

Is this a bug fix?

Collaborator

wasi_nn.h is a header for WebAssembly applications written in the C language. Is there a specific reason that we need to change it?

Contributor Author

Yes, it's a bug fix.

Collaborator

Since we have two sets of APIs for historical reasons, we might remove one in another PR. For now, let's ensure both are functional.

  • I suggest we use WASM_ENABLE_WASI_EPHEMERAL_NN on the wasm side.
  • With this flag, we declare two sets of APIs in wasi_nn.h. Please align with the content of native_symbols_wasi_nn in wasi_nn.c.

Contributor

@lum1n0us
Where can I find the function prototypes for wasi_ephemeral_nn?
For functions like get_output, the signatures are different, so simply replacing wasi_nn with wasi_ephemeral_nn mechanically doesn't seem to work.

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
REG_NATIVE_FUNC(load, "(*iii*)i"),
REG_NATIVE_FUNC(load_by_name, "(*i*)i"),
REG_NATIVE_FUNC(load_by_name_with_config, "(*i*i*)i"),
REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
REG_NATIVE_FUNC(set_input, "(ii*)i"),
REG_NATIVE_FUNC(compute, "(i)i"),
REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
REG_NATIVE_FUNC(load, "(*ii*)i"),
REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
REG_NATIVE_FUNC(set_input, "(ii*)i"),
REG_NATIVE_FUNC(compute, "(i)i"),
REG_NATIVE_FUNC(get_output, "(ii**)i"),
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */

Even when referring to the wasi-nn specification, the signatures declared there don’t appear to match what is used for wasi_ephemeral_nn
https://github.com/WebAssembly/wasi-nn/blob/71320d95b8c6d43f9af7f44e18b1839db85d89b4/wasi-nn.witx#L59-L86

Collaborator

@HongxiaWangSSSS I suggest we use WASM_ENABLE_WASI_EPHEMERAL_NN on the wasm side and there will be two sets for wasi_ephemeral_nn and wasi_nn. wasi_nn will be deprecated in the future.

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
load(graph_builder_array *builder, graph_encoding encoding,
     execution_target target, graph *g)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
load_by_name(const char *name, uint32_t len, graph *g)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
load_by_name_with_config(const char *name, uint32_t name_len, void *config,
                         uint32_t config_len, graph *g)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
init_execution_context(graph g, graph_execution_context *exec_ctx)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
set_input(graph_execution_context ctx, uint32_t index, tensor *tensor)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
compute(graph_execution_context ctx)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t *output_tensor_size)
    __attribute__((import_module("wasi_ephemeral_nn")));

#else

wasi_nn_error
load(graph_builder_array *builder, graph_encoding encoding,
     execution_target target, graph *g)
    __attribute__((import_module("wasi_nn")));

wasi_nn_error
init_execution_context(graph g, graph_execution_context *ctx)
    __attribute__((import_module("wasi_nn")));

wasi_nn_error
set_input(graph_execution_context ctx, uint32_t index, tensor *tensor)
    __attribute__((import_module("wasi_nn")));

wasi_nn_error
compute(graph_execution_context ctx) __attribute__((import_module("wasi_nn")));

wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t *output_tensor_size)
    __attribute__((import_module("wasi_nn")));
#endif

Contributor

Is there a mismatch in the function signature?

REG_NATIVE_FUNC(get_output, "(ii*i*)i"),

and

wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t *output_tensor_size)
    __attribute__((import_module("wasi_ephemeral_nn")));

Collaborator @lum1n0us May 27, 2025

Yes. There should be two versions, one for wasi_ephemeral_nn and another for wasi_nn. Please refer to:

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_wasm *builder,
             uint32_t builder_wasm_size, graph_encoding encoding,
             execution_target target, graph *g)
#else  /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
wasi_nn_error
wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
             graph_encoding encoding, execution_target target, graph *g)
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */


#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
                   uint32_t index, tensor_data output_tensor,
                   uint32_t output_tensor_len, uint32_t *output_tensor_size)
#else  /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
wasi_nn_error
wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
                   uint32_t index, tensor_data output_tensor,
                   uint32_t *output_tensor_size)
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */

Contributor

If so, we need to define it in wasi_nn.h:

wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t output_tensor_len,
           uint32_t *output_tensor_size)
    __attribute__((import_module("wasi_ephemeral_nn")));

But the backend definition doesn't look like it matches (the signature string might be OK):

__attribute__((visibility("default"))) wasi_nn_error
get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t *output_tensor_size)

Collaborator @lum1n0us May 27, 2025

Yes. I am suggesting this #4267 (comment) as the new content of wasi_nn.h. Plus, #4267 (comment).

@@ -697,6 +697,7 @@ static NativeSymbol native_symbols_wasi_nn[] = {
     REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
 #else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
     REG_NATIVE_FUNC(load, "(*ii*)i"),
+    REG_NATIVE_FUNC(load_by_name, "(*i*)i"),
Collaborator

Why not use -DWASM_ENABLE_WASI_EPHEMERAL_NN=1?

Contributor Author

At first, I just built with -DWASM_ENABLE_WASI_NN=1.

Collaborator

If you use -DWAMR_BUILD_WASI_EPHEMERAL_NN=1 during compilation, you will be able to use that set of APIs, including load_by_name(). There is no need to change this line.

Contributor Author

If -DWAMR_BUILD_WASI_EPHEMERAL_NN=1 must be added, I think there is no need to add this line.
Currently, it is possible to use the default wasi-nn instead of wasi_ephemeral_nn.

Collaborator

Yes. Please do it.

Contributor

So, does that mean that the ephemeral version is meant to be compatible with Rust (especially WasmEdge), whereas the non-ephemeral one doesn't need to be?
If that's the case, wouldn't it make sense to add load_by_name to the wasi_nn.h header, which is a straightforward C interpretation of the witx specification?

Collaborator

Yes. load_by_name() needs a new signature.

Collaborator

Eventually, load_by_name() and wasm_load_by_name() should only exist when WASM_ENABLE_WASI_EPHEMERAL_NN is set to 1. If you build the wasm module with WASM_ENABLE_WASI_EPHEMERAL_NN=1, you will get a wasm module whose imports come from wasi_ephemeral_nn, which offers better performance.
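
A minimal sketch of what that guarded declaration could look like in wasi_nn.h (based on the signatures discussed in this thread, not a final form):

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
/* Only declared for the ephemeral flavor; module name as discussed above. */
wasi_nn_error
load_by_name(const char *name, uint32_t name_len, graph *g)
    __attribute__((import_module("wasi_ephemeral_nn")));
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */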

Contributor

#4267 (comment)
said this is for the C language, so why not keep it for non-wasi_ephemeral_nn to provide the same performance?

Collaborator

In my mind, the non-ephemeral wasi_nn is legacy and should be deprecated. If there are C APIs, they should follow the Rust API's design to avoid unnecessary changes for the runtime.

@@ -85,14 +85,11 @@ is_valid_graph(TFLiteContext *tfl_ctx, graph g)
         NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST);
         return runtime_error;
     }
-    if (tfl_ctx->models[g].model_pointer == NULL) {
+    if (tfl_ctx->models[g].model_pointer == NULL
+        && tfl_ctx->models[g].model == NULL) {
Collaborator

The original version can output different information based on various invalid argument cases. Is there a specific reason we need to merge them?

Collaborator

It is required to validate TFLiteContext.models[g] for both cases, using load() and load_by_name(). It will not be acceptable if the change disables one of these cases.

Contributor Author

Whether it is load or load_by_name, the check of models[g].model_pointer does not seem to be necessary; just making sure that models[g].model is not NULL may be enough.
What do you think?

Collaborator

> does not seem to be necessary.

Why is that?

Contributor Author

After this operation, the content of models[g].model_pointer has been saved in models[g].model.

Collaborator

https://github.com/tensorflow/tensorflow/blob/02896e880298894beedc71f8666f6949ad0e174f/tensorflow/compiler/mlir/lite/core/model_builder_base.h#L169-#L184

  /// Builds a model based on a pre-loaded flatbuffer.
  /// Caller retains ownership of the buffer and should keep it alive until
  /// the returned object is destroyed. Caller also retains ownership of
  /// `error_reporter` and must ensure its lifetime is longer than the
  /// FlatBufferModelBase instance.
  /// Returns a nullptr in case of failure.
  /// NOTE: this does NOT validate the buffer so it should NOT be called on
  /// invalid/untrusted input. Use VerifyAndBuildFromBuffer in that case
  static std::unique_ptr<T> BuildFromBuffer(
      const char* caller_owned_buffer, size_t buffer_size,
      ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
    error_reporter = ValidateErrorReporter(error_reporter);
    std::unique_ptr<Allocation> allocation(
        new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
    return BuildFromAllocation(std::move(allocation), error_reporter);
  }

If I understand correctly, model_pointer acts as a pre-allocated buffer, and its ownership is still held by the caller, in our case, tfl_ctx. Meanwhile, model holds the ownership of the result from tflite::FlatBufferModel::BuildFromBuffer(). Therefore, both are required.
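
A sketch of what keeping both checks could look like in is_valid_graph(), with a separate diagnostic for each failure mode (error messages here are illustrative, not the exact strings in the file):

    /* caller-owned flatbuffer in Wasm/host memory */
    if (tfl_ctx->models[g].model_pointer == NULL) {
        NN_ERR_PRINTF("Context (model buffer) non-initialized.");
        return runtime_error;
    }
    /* FlatBufferModel built from that buffer */
    if (tfl_ctx->models[g].model == NULL) {
        NN_ERR_PRINTF("Context (tflite model) non-initialized.");
        return runtime_error;
    }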

@@ -58,7 +58,7 @@ wasm_load(char *model_name, graph *g, execution_target target)
 wasi_nn_error
 wasm_load_by_name(const char *model_name, graph *g)
 {
-    wasi_nn_error res = load_by_name(model_name, g);
+    wasi_nn_error res = load_by_name(model_name, strlen(model_name), g);
Collaborator

better be

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
wasm_load_by_name(const char *model_name, graph *g)
{
    wasi_nn_error res = load_by_name(model_name, strlen(model_name), g);
    return res;
}
#endif 

@@ -15,6 +15,44 @@
 #include <stdint.h>
 #include "wasi_nn_types.h"
 
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
Contributor Author

Hi @ayakoakasaka and @lum1n0us,
For wasi_nn and wasi_ephemeral_nn, different APIs need to be defined in the header file.
However, when currently using wasi_ephemeral_nn, I get an error when calling set_input: the content is inconsistent when passed from wasm to native.
So can we divide this into two PRs: first support load_by_name in wasi-nn, and then implement support for wasi_ephemeral_nn?

Collaborator

Sure. As a workaround.
