From f2eef103d7c39e536c6e749dae1904fe749b45f7 Mon Sep 17 00:00:00 2001 From: Max Desiatov Date: Wed, 15 Nov 2023 10:10:02 +0000 Subject: [PATCH] Support reentrant function call between host/guest Since `Runtime` manages the single execution state, it wrongly executes upper frame instructions if the call stack is `guest -> host -> guest`. This change creates a fresh execution state for each exported guest function invocation and adds test coverage to verify that it fixes the issue. Co-authored-by: Yuta Saito --- .../Execution/Runtime/ExecutionState.swift | 2 +- .../WasmKit/Execution/Runtime/Function.swift | 3 +- Sources/WasmKit/Execution/Runtime/Store.swift | 5 +- .../Execution/HostModuleTests.swift | 60 +++++++++++++++++++ 4 files changed, 67 insertions(+), 3 deletions(-) diff --git a/Sources/WasmKit/Execution/Runtime/ExecutionState.swift b/Sources/WasmKit/Execution/Runtime/ExecutionState.swift index eb40c39c..18f89f6e 100644 --- a/Sources/WasmKit/Execution/Runtime/ExecutionState.swift +++ b/Sources/WasmKit/Execution/Runtime/ExecutionState.swift @@ -84,7 +84,7 @@ extension ExecutionState { switch try runtime.store.function(at: address) { case let .host(function): let parameters = try stack.popValues(count: function.type.parameters.count) - let caller = Caller(store: runtime.store, instance: stack.currentFrame.module) + let caller = Caller(runtime: runtime, instance: stack.currentFrame.module) stack.push(values: try function.implementation(caller, parameters)) programCounter += 1 diff --git a/Sources/WasmKit/Execution/Runtime/Function.swift b/Sources/WasmKit/Execution/Runtime/Function.swift index 57edea41..1fc3dbce 100644 --- a/Sources/WasmKit/Execution/Runtime/Function.swift +++ b/Sources/WasmKit/Execution/Runtime/Function.swift @@ -1,3 +1,4 @@ +/// A WebAssembly guest function or host function public struct Function: Equatable { internal let address: FunctionAddress @@ -16,7 +17,7 @@ public struct Function: Equatable { let parameters = try execution.stack.popValues(count: function.type.parameters.count) - let caller = Caller(store: runtime.store, instance: execution.stack.currentFrame.module) + let caller = Caller(runtime: runtime, instance: execution.stack.currentFrame.module) let results = try function.implementation(caller, parameters) try check(functionType: function.type, results: results) execution.stack.push(values: results) diff --git a/Sources/WasmKit/Execution/Runtime/Store.swift b/Sources/WasmKit/Execution/Runtime/Store.swift index 7ca72236..25d76491 100644 --- a/Sources/WasmKit/Execution/Runtime/Store.swift +++ b/Sources/WasmKit/Execution/Runtime/Store.swift @@ -60,8 +60,11 @@ public final class Store { /// A caller context passed to host functions public struct Caller { - public let store: Store + public let runtime: Runtime public let instance: ModuleInstance + public var store: Store { + runtime.store + } } /// A host-defined function which can be imported by a WebAssembly module instance. diff --git a/Tests/WasmKitTests/Execution/HostModuleTests.swift b/Tests/WasmKitTests/Execution/HostModuleTests.swift index 1d936a25..456de681 100644 --- a/Tests/WasmKitTests/Execution/HostModuleTests.swift +++ b/Tests/WasmKitTests/Execution/HostModuleTests.swift @@ -18,4 +18,64 @@ final class HostModuleTests: XCTestCase { // Ensure the allocated address is valid _ = runtime.store.memory(at: memoryAddr) } + + func testReentrancy() throws { + let runtime = Runtime() + let voidSignature = FunctionType(parameters: [], results: []) + let module = Module( + types: [voidSignature], + functions: [ + // [0] (import "env" "bar" func) + // [1] (import "env" "qux" func) + // [2] "foo" + GuestFunction( + type: 0, locals: [], + body: [ + .control(.call(functionIndex: 0)), + .control(.call(functionIndex: 0)), + .control(.call(functionIndex: 0)), + ]), + // [3] "bar" + GuestFunction( + type: 0, locals: [], + body: [ + .control(.call(functionIndex: 1)) + ]), + ], + imports: [ + Import(module: "env", name: "bar", descriptor: .function(0)), + Import(module: "env", name: "qux", descriptor: .function(0)), + ], + exports: [ + Export(name: "foo", descriptor: .function(2)), + Export(name: "baz", descriptor: .function(3)), + ] + ) + + var isExecutingFoo = false + var isQuxCalled = false + let hostModule = HostModule( + functions: [ + "bar": HostFunction(type: voidSignature) { caller, _ in + // Ensure "invoke" executes instructions under the current call + XCTAssertFalse(isExecutingFoo, "bar should not be called recursively") + isExecutingFoo = true + defer { isExecutingFoo = false } + let foo = try XCTUnwrap(caller.instance.exportedFunction(name: "baz")) + _ = try foo.invoke([], runtime: caller.runtime) + return [] + }, + "qux": HostFunction(type: voidSignature) { _, _ in + XCTAssertTrue(isExecutingFoo) + isQuxCalled = true + return [] + }, + ] + ) + try runtime.store.register(hostModule, as: "env") + let instance = try runtime.instantiate(module: module) + // Check foo(wasm) -> bar(host) -> baz(wasm) -> qux(host) + _ = try runtime.invoke(instance, function: "foo") + XCTAssertTrue(isQuxCalled) + } }