Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] enable WebGPU EP in WebAssembly build #23697

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,13 @@ if (onnxruntime_USE_WEBGPU)
set(DAWN_BUILD_TESTS OFF CACHE BOOL "" FORCE)
if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
set(DAWN_EMSCRIPTEN_TOOLCHAIN "${REPO_ROOT}/cmake/external/emsdk/upstream/emscripten" CACHE STRING "" FORCE)

# Update a few files in Emscripten
#
# The following files should be updated in Emscripten. We are waiting for the next Emscripten release to include
# these changes. For now, we apply the changes manually.
# - ${DAWN_EMSCRIPTEN_TOOLCHAIN}/src/closure-externs/webgpu-externs.js
execute_process(COMMAND ${CMAKE_COMMAND} -E copy "${PROJECT_SOURCE_DIR}/patches/emscripten/webgpu-externs.js" "${DAWN_EMSCRIPTEN_TOOLCHAIN}/src/closure-externs/webgpu-externs.js")
else()
if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY)
set(DAWN_BUILD_MONOLITHIC_LIBRARY ON CACHE BOOL "" FORCE)
Expand Down
40 changes: 30 additions & 10 deletions cmake/onnxruntime_webassembly.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,14 @@ else()
target_link_libraries(onnxruntime_webassembly PRIVATE tensorboard)
endif()

set(onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/pre.js")

set(EXPORTED_FUNCTIONS "_malloc,_free")
if (onnxruntime_USE_JSEP)
set(EXPORTED_FUNCTIONS "_malloc,_free,_JsepOutput,_JsepGetNodeName")
else()
set(EXPORTED_FUNCTIONS "_malloc,_free")
string(APPEND EXPORTED_FUNCTIONS ",_JsepOutput,_JsepGetNodeName")
endif()
if (onnxruntime_USE_WEBGPU)
string(APPEND EXPORTED_FUNCTIONS ",_wgpuBufferRelease,_wgpuCreateInstance")
endif()

if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
Expand Down Expand Up @@ -312,13 +316,15 @@ else()
target_compile_options(noexcep_operators PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
endif()
target_link_options(onnxruntime_webassembly PRIVATE
--post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js"
"SHELL:--post-js \"${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js\""
)
list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js")
else ()
set(MAXIMUM_MEMORY "4294967296")
target_link_options(onnxruntime_webassembly PRIVATE
--post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js"
"SHELL:--post-js \"${ONNXRUNTIME_ROOT}/wasm/js_post_js.js\""
)
list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js")
endif ()

target_link_options(onnxruntime_webassembly PRIVATE
Expand Down Expand Up @@ -372,7 +378,6 @@ jsepDownload:_pp_")
"SHELL:-s SIGNATURE_CONVERSIONS='${SIGNATURE_CONVERSIONS}'"
)
endif ()
set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js)

if (onnxruntime_USE_JSEP)
# NOTE: "-s ASYNCIFY=1" is required for JSEP to work with WebGPU
Expand All @@ -382,10 +387,8 @@ jsepDownload:_pp_")
target_compile_definitions(onnxruntime_webassembly PRIVATE USE_JSEP=1)
target_link_options(onnxruntime_webassembly PRIVATE
"SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\""
"SHELL:-s ASYNCIFY=1"
"SHELL:-s ASYNCIFY_STACK_SIZE=65536"
)
set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js)
list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js")

if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
target_link_options(onnxruntime_webassembly PRIVATE
Expand All @@ -397,6 +400,20 @@ jsepDownload:_pp_")

if (onnxruntime_USE_WEBGPU)
target_compile_definitions(onnxruntime_webassembly PRIVATE USE_WEBGPU=1)
target_link_options(onnxruntime_webassembly PRIVATE
"SHELL:--post-js \"${ONNXRUNTIME_ROOT}/wasm/post-webgpu.js\""
)
list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/post-webgpu.js")
endif()

if (onnxruntime_USE_JSEP OR onnxruntime_USE_WEBGPU OR onnxruntime_USE_WEBNN)
# if any of the above is enabled, we need to use the asyncify library
target_link_options(onnxruntime_webassembly PRIVATE
"SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-async.js\""
"SHELL:-s ASYNCIFY=1"
"SHELL:-s ASYNCIFY_STACK_SIZE=65536"
)
list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/pre-async.js")
endif()

if (onnxruntime_EMSCRIPTEN_SETTINGS)
Expand Down Expand Up @@ -426,7 +443,8 @@ jsepDownload:_pp_")
"SHELL:-s ASSERTIONS=0"
"SHELL:-s SAFE_HEAP=0"
"SHELL:-s STACK_OVERFLOW_CHECK=0"
--closure 1
## comment out closure compiler so that it's easier to debug
# --closure 1
)
endif()

Expand Down Expand Up @@ -458,6 +476,8 @@ jsepDownload:_pp_")
)
endif()

set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS "${onnxruntime_webassembly_script_deps}")

set(target_name_list ort)

if (onnxruntime_ENABLE_TRAINING_APIS)
Expand Down
162 changes: 162 additions & 0 deletions cmake/patches/dawn/dawn.patch
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,165 @@ index 6e8ae37593..633af91eef 100644
-q
"${EM_BUILD_GEN_DIR}/struct_info_webgpu.json"
"-I=${EM_BUILD_GEN_DIR}/include"
diff --git a/src/emdawnwebgpu/README.md b/src/emdawnwebgpu/README.md
index efd6491cd6..8ebc5d28b6 100644
--- a/src/emdawnwebgpu/README.md
+++ b/src/emdawnwebgpu/README.md
@@ -56,7 +56,7 @@ Set up the build directory using emcmake
mkdir out/cmake-wasm
cd out/cmake-wasm

-# Make sure the path is to the source checkout of Emscripten, not emsdk's release.
+# If using Emscripten v4.0.2 or lower, make sure the path is to the source checkout of Emscripten, not emsdk's release.
emcmake cmake -GNinja -DDAWN_EMSCRIPTEN_TOOLCHAIN="path/to/emscripten" ../..

ninja
diff --git a/third_party/emdawnwebgpu/library_webgpu.js b/third_party/emdawnwebgpu/library_webgpu.js
index 5862ce4045..45df259bb7 100644
--- a/third_party/emdawnwebgpu/library_webgpu.js
+++ b/third_party/emdawnwebgpu/library_webgpu.js
@@ -811,6 +811,61 @@ var LibraryWebGPU = {
{{{ runtimeKeepalivePush() }}}
WebGPU.Internals.futureInsert(futureId, adapter.requestDevice(desc).then((device) => {
{{{ runtimeKeepalivePop() }}}
+
+ if (globalThis["WEBGPU_STAT"]) {
+ // a set that caches all active buffers
+ const buffers = WebGPU.Internals.buffers ??= new Set();
+ // key is buffer usage, value is total size of buffers with that usage
+ const buffersTotalSize = WebGPU.Internals.buffersTotalSize ??= new Map();
+
+ WebGPU.Internals.buffersCreated ??= 0;
+ WebGPU.Internals.buffersDestroyed ??= 0;
+ WebGPU.Internals.buffersUploads ??= 0;
+ WebGPU.Internals.buffersExternalUploads ??= 0;
+ WebGPU.Internals.buffersDownloads ??= 0;
+ WebGPU.Internals.buffersExternalDownloads ??= 0;
+
+ // create a proxy so that we can monitor buffer usages
+ device = new Proxy(device, {
+ // when call device.createBuffer(), the returned buffer should be added into buffers
+ get: (target, prop, _receiver) => {
+ if (prop === 'createBuffer') {
+ return (desc) => {
+ const buffer = target.createBuffer(desc);
+ const originalDestroy = buffer.destroy.bind(buffer);
+ buffer.destroy = () => {
+ const previousTotal = buffersTotalSize.get(buffer.usage);
+ buffersTotalSize.set(buffer.usage, previousTotal - buffer.size);
+ buffers.delete(buffer);
+ WebGPU.Internals.buffersDestroyed++;
+ originalDestroy();
+ };
+
+ if (buffer.usage === (GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC)) {
+ WebGPU.Internals.buffersUploads++;
+ }
+ if (buffer.usage === (GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ)) {
+ WebGPU.Internals.buffersDownloads++;
+ }
+
+ buffers.add(buffer);
+ const previousTotal = buffersTotalSize.get(buffer.usage) ?? 0;
+ buffersTotalSize.set(buffer.usage, previousTotal + buffer.size);
+ WebGPU.Internals.buffersCreated++;
+ return buffer;
+ };
+ }
+ const propertyValue = Reflect.get(target, prop);
+ if (typeof propertyValue === 'function') {
+ return propertyValue.bind(target);
+ } else {
+ return propertyValue;
+ }
+ },
+ set: (target, prop, value, _receiver) => Reflect.set(target, prop, value),
+ });
+ }
+
WebGPU.Internals.jsObjectInsert(queuePtr, device.queue);
WebGPU.Internals.jsObjectInsert(devicePtr, device);

diff --git a/third_party/emdawnwebgpu/webgpu.cpp b/third_party/emdawnwebgpu/webgpu.cpp
index ca52b1237b..a30ca583c3 100644
--- a/third_party/emdawnwebgpu/webgpu.cpp
+++ b/third_party/emdawnwebgpu/webgpu.cpp
@@ -131,7 +131,6 @@ class RefCounted : NonMovable {
bool Release() {
if (mRefCount.fetch_sub(1u, std::memory_order_release) == 1u) {
std::atomic_thread_fence(std::memory_order_acquire);
- emwgpuDelete(this);
return true;
}
return false;
@@ -234,6 +233,7 @@ class Ref {
static void Release(T value) {
if (value != nullptr && value->RefCounted::Release()) {
delete value;
+ emwgpuDelete(value);
}
}

@@ -642,6 +642,7 @@ struct WGPUBufferImpl final : public EventSource,
public RefCountedWithExternalCount {
public:
WGPUBufferImpl(const EventSource* source, bool mappedAtCreation);
+ ~WGPUBufferImpl();

void Destroy();
const void* GetConstMappedRange(size_t offset, size_t size);
@@ -1164,11 +1165,17 @@ WGPUAdapter emwgpuCreateAdapter(const EventSource* source) {

WGPUBuffer emwgpuCreateBuffer(const EventSource* source,
bool mappedAtCreation = false) {
- return new WGPUBufferImpl(source, mappedAtCreation);
+ auto x = new WGPUBufferImpl(source, mappedAtCreation);
+ // printf(" #C++: emwgpuCreateBuffer %p\n", x);
+ return x;
}

WGPUDevice emwgpuCreateDevice(const EventSource* source, WGPUQueue queue) {
- return new WGPUDeviceImpl(source, queue);
+ // This function is only called from JS via `importJsDevice()`, which
+ // needs to increment the external ref count to fix the behavior.
+ WGPUDeviceImpl* device = new WGPUDeviceImpl(source, queue);
+ device->AddExternalRef();
+ return device;
}

WGPUQueue emwgpuCreateQueue(const EventSource* source) {
@@ -1284,6 +1291,10 @@ WGPUBufferImpl::WGPUBufferImpl(const EventSource* source, bool mappedAtCreation)
}
}

+WGPUBufferImpl::~WGPUBufferImpl() {
+ Destroy();
+}
+
void WGPUBufferImpl::Destroy() {
emwgpuBufferDestroy(this);
AbortPendingMap("Buffer was destroyed before mapping was resolved.");
@@ -1504,6 +1515,7 @@ WGPUFuture WGPUShaderModuleImpl::GetCompilationInfo(
void wgpu##Name##Release(WGPU##Name o) { \
if (o->Release()) { \
delete o; \
+ emwgpuDelete(o); \
} \
}
WGPU_OBJECTS(DEFINE_WGPU_DEFAULT_ADDREF_RELEASE)
@@ -1587,6 +1599,7 @@ WGPUFuture wgpuAdapterRequestDevice(
// ----------------------------------------------------------------------------

void wgpuBufferDestroy(WGPUBuffer buffer) {
+ // printf(" #C++: wgpuBufferDestroy %p\n", buffer);
buffer->Destroy();
}

@@ -1639,6 +1652,7 @@ void wgpuBufferUnmap(WGPUBuffer buffer) {
WGPUBuffer wgpuDeviceCreateBuffer(WGPUDevice device,
const WGPUBufferDescriptor* descriptor) {
WGPUBuffer buffer = new WGPUBufferImpl(device, descriptor->mappedAtCreation);
+ // printf(" #C++: wgpuDeviceCreateBuffer %p\n", buffer);
emwgpuDeviceCreateBuffer(device, descriptor, buffer);
return buffer;
}
Loading
Loading