Skip to content

Commit

Permalink
feat: Create statically compiled Swift closure wrapper to avoid C-sty…
Browse files Browse the repository at this point in the history
…le function pointers (#440)

* feat: Create Swift closure wrapper

* It kinda works

* Move it into lambda and make it `mutable`

* fix: Import `NitroModules` and alias `bridge`

* fix: Also generate Promise resolver funcs

* fix: Use closureWrapper here again

* __

* fix: Fix these Promise things

* fix: Void

* fix: Fix wrapping C++ funcs to make them callable in Swift

* fix: Remove getPromise()

* perf: Make it inline

* comments
  • Loading branch information
mrousavy authored Dec 20, 2024
1 parent 3a1b7b8 commit 699e138
Show file tree
Hide file tree
Showing 20 changed files with 1,073 additions and 616 deletions.
108 changes: 40 additions & 68 deletions packages/nitrogen/src/syntax/swift/SwiftCxxBridgedType.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import { createSwiftVariant, getSwiftVariantCaseName } from './SwiftVariant.js'
import { VoidType } from '../types/VoidType.js'
import { NamedWrappingType } from '../types/NamedWrappingType.js'
import { ErrorType } from '../types/ErrorType.js'
import { createSwiftFunctionBridge } from './SwiftFunction.js'

// TODO: Remove enum bridge once Swift fixes bidirectional enums crashing the `-Swift.h` header.

Expand Down Expand Up @@ -176,10 +177,24 @@ export class SwiftCxxBridgedType implements BridgedType<'swift', 'c++'> {
files.push(extensionFile)
break
}
case 'function': {
const functionType = getTypeAs(this.type, FunctionType)
const extensionFile = createSwiftFunctionBridge(functionType)
files.push(extensionFile)
break
}
case 'promise': {
// Promise needs resolver and rejecter funcs in Swift
const promiseType = getTypeAs(this.type, PromiseType)
files.push(createSwiftFunctionBridge(promiseType.resolverFunction))
files.push(createSwiftFunctionBridge(promiseType.rejecterFunction))
break
}
case 'variant': {
const variant = getTypeAs(this.type, VariantType)
const file = createSwiftVariant(variant)
files.push(file)
break
}
}

Expand Down Expand Up @@ -345,23 +360,27 @@ export class SwiftCxxBridgedType implements BridgedType<'swift', 'c++'> {
const promise = getTypeAs(this.type, PromiseType)
switch (language) {
case 'swift': {
const bridge = this.getBridgeOrThrow()
if (promise.resultingType.kind === 'void') {
// It's void - resolve()
const resolverFunc = new FunctionType(new VoidType(), [])
const rejecterFunc = new FunctionType(new VoidType(), [
new NamedWrappingType('error', new ErrorType()),
])
const resolverFuncBridge = new SwiftCxxBridgedType(resolverFunc)
const rejecterFuncBridge = new SwiftCxxBridgedType(rejecterFunc)
return `
{ () -> ${promise.getCode('swift')} in
let __promise = ${promise.getCode('swift')}()
let __resolver = SwiftClosure { __promise.resolve(withResult: ()) }
let __resolver = { __promise.resolve(withResult: ()) }
let __rejecter = { (__error: Error) in
__promise.reject(withError: __error)
}
let __resolverCpp = __resolver.getFunctionCopy()
let __resolverCpp = ${indent(resolverFuncBridge.parseFromSwiftToCpp('__resolver', 'swift'), ' ')}
let __rejecterCpp = ${indent(rejecterFuncBridge.parseFromSwiftToCpp('__rejecter', 'swift'), ' ')}
${cppParameterName}.addOnResolvedListener(__resolverCpp)
${cppParameterName}.addOnRejectedListener(__rejecterCpp)
let __promiseHolder = bridge.wrap_${bridge.specializationName}(${cppParameterName})
__promiseHolder.addOnResolvedListener(__resolverCpp)
__promiseHolder.addOnRejectedListener(__rejecterCpp)
return __promise
}()`.trim()
} else {
Expand Down Expand Up @@ -389,8 +408,9 @@ export class SwiftCxxBridgedType implements BridgedType<'swift', 'c++'> {
}
let __resolverCpp = ${indent(resolverFuncBridge.parseFromSwiftToCpp('__resolver', 'swift'), ' ')}
let __rejecterCpp = ${indent(rejecterFuncBridge.parseFromSwiftToCpp('__rejecter', 'swift'), ' ')}
${cppParameterName}.${resolverFuncName}(__resolverCpp)
${cppParameterName}.addOnRejectedListener(__rejecterCpp)
let __promiseHolder = bridge.wrap_${bridge.specializationName}(${cppParameterName})
__promiseHolder.${resolverFuncName}(__resolverCpp)
__promiseHolder.addOnRejectedListener(__rejecterCpp)
return __promise
}()`.trim()
}
Expand Down Expand Up @@ -532,18 +552,21 @@ case ${i}:
if (funcType.returnType.kind === 'void') {
return `
{ () -> ${swiftClosureType} in
let __sharedClosure = bridge.share_${bridge.specializationName}(${cppParameterName})
let __wrappedFunction = bridge.wrap_${bridge.specializationName}(${cppParameterName})
return { ${signature} in
__sharedClosure.pointee.call(${indent(paramsForward.join(', '), ' ')})
__wrappedFunction.call(${indent(paramsForward.join(', '), ' ')})
}
}()`.trim()
} else {
const resultBridged = new SwiftCxxBridgedType(funcType.returnType)
const resultBridged = new SwiftCxxBridgedType(
funcType.returnType,
true
)
return `
{ () -> ${swiftClosureType} in
let __sharedClosure = bridge.share_${bridge.specializationName}(${cppParameterName})
let __wrappedFunction = bridge.wrap_${bridge.specializationName}(${cppParameterName})
return { ${signature} in
let __result = __sharedClosure.pointee.call(${indent(paramsForward.join(', '), ' ')})
let __result = __wrappedFunction.call(${indent(paramsForward.join(', '), ' ')})
return ${indent(resultBridged.parseFromCppToSwift('__result', 'swift'), ' ')}
}
}()`.trim()
Expand Down Expand Up @@ -661,30 +684,20 @@ case ${i}:
true
)
switch (language) {
case 'c++':
if (this.isBridgingToDirectCppTarget) {
return swiftParameterName
} else {
return `${swiftParameterName}.getPromise()`
}
case 'swift':
const arg =
promise.resultingType.kind === 'void'
? ''
: resolvingType.parseFromSwiftToCpp('__result', 'swift')
const code = `
return `
{ () -> bridge.${bridge.specializationName} in
let __promise = ${makePromise}()
let __promiseHolder = bridge.wrap_${bridge.specializationName}(__promise)
${swiftParameterName}
.then({ __result in __promise.resolve(${indent(arg, ' ')}) })
.catch({ __error in __promise.reject(__error.toCpp()) })
.then({ __result in __promiseHolder.resolve(${indent(arg, ' ')}) })
.catch({ __error in __promiseHolder.reject(__error.toCpp()) })
return __promise
}()`.trim()
if (this.isBridgingToDirectCppTarget) {
return `${code}.getPromise()`
} else {
return code
}
default:
return swiftParameterName
}
Expand Down Expand Up @@ -772,52 +785,11 @@ case ${i}:
switch (language) {
case 'swift': {
const bridge = this.getBridgeOrThrow()
const func = getTypeAs(this.type, FunctionType)
const cFuncParamsForward = func.parameters
.map((p) => {
const bridged = new SwiftCxxBridgedType(p)
return bridged.parseFromCppToSwift(
`__${p.escapedName}`,
'swift'
)
})
.join(', ')
const paramsSignature = func.parameters
.map((p) => `_ __${p.escapedName}: ${p.getCode('swift')}`)
.join(', ')
const paramsForward = func.parameters
.map((p) => `__${p.escapedName}`)
.join(', ')
const cFuncParamsSignature = [
'__closureHolder: UnsafeMutableRawPointer',
...func.parameters.map((p) => {
const bridged = new SwiftCxxBridgedType(p)
return `__${p.escapedName}: ${bridged.getTypeCode('swift')}`
}),
].join(', ')
const createFunc = `bridge.${bridge.funcName}`
return `
{ () -> bridge.${bridge.specializationName} in
final class ClosureHolder {
let closure: ${func.getCode('swift')}
init(wrappingClosure closure: @escaping ${func.getCode('swift')}) {
self.closure = closure
}
func invoke(${paramsSignature}) {
self.closure(${indent(paramsForward, ' ')})
}
}
let __closureHolder = Unmanaged.passRetained(ClosureHolder(wrappingClosure: ${swiftParameterName})).toOpaque()
func __callClosure(${cFuncParamsSignature}) -> Void {
let closure = Unmanaged<ClosureHolder>.fromOpaque(__closureHolder).takeUnretainedValue()
closure.invoke(${indent(cFuncParamsForward, ' ')})
}
func __destroyClosure(_ __closureHolder: UnsafeMutableRawPointer) -> Void {
Unmanaged<ClosureHolder>.fromOpaque(__closureHolder).release()
}
return ${createFunc}(__closureHolder, __callClosure, __destroyClosure)
let __closureWrapper = ${bridge.specializationName}(${swiftParameterName})
return ${createFunc}(__closureWrapper.toUnsafe())
}()
`.trim()
}
Expand Down
98 changes: 51 additions & 47 deletions packages/nitrogen/src/syntax/swift/SwiftCxxTypeHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -303,58 +303,47 @@ function createCxxFunctionSwiftHelper(type: FunctionType): SwiftCxxHelper {
return `${p.getCode('c++')} ${p.escapedName}`
}
})
const callCppFuncParamsSignature = type.parameters.map((p) => {
const paramsForward = type.parameters.map((p) => {
const bridge = new SwiftCxxBridgedType(p)
const cppType = bridge.getTypeCode('c++')
return `${cppType} ${p.escapedName}`
return bridge.parseFromCppToSwift(p.escapedName, 'c++')
})
const name = type.specializationName
const wrapperName = `${name}_Wrapper`
const swiftClassName = `${NitroConfig.getIosModuleName()}::${type.specializationName}`

const callParamsForward = type.parameters.map((p) => {
const bridge = new SwiftCxxBridgedType(p)
return bridge.parseFromSwiftToCpp(p.escapedName, 'c++')
})
const paramsForward = [
'sharedClosureHolder.get()',
...type.parameters.map((p) => {
const bridge = new SwiftCxxBridgedType(p)
return bridge.parseFromCppToSwift(p.escapedName, 'c++')
}),
]
const callFuncReturnType = returnBridge.getTypeCode('c++')
const callFuncParams = [
'void* _Nonnull /* closureHolder */',
...type.parameters.map((p) => {
const bridge = new SwiftCxxBridgedType(p)
return bridge.getTypeCode('c++')
}),
]
const functionPointerParam = `${callFuncReturnType}(* _Nonnull call)(${callFuncParams.join(', ')})`
const name = type.specializationName
const wrapperName = `${name}_Wrapper`

const callFuncReturnType = returnBridge.getTypeCode('c++')
const callCppFuncParamsSignature = type.parameters.map((p) => {
const bridge = new SwiftCxxBridgedType(p)
const cppType = bridge.getTypeCode('c++')
return `${cppType} ${p.escapedName}`
})
let callCppFuncBody: string
if (returnBridge.hasType) {
callCppFuncBody = `
auto __result = _function(${callParamsForward.join(', ')});
auto __result = _function->operator()(${callParamsForward.join(', ')});
return ${returnBridge.parseFromCppToSwift('__result', 'c++')};
`.trim()
} else {
callCppFuncBody = `_function(${callParamsForward.join(', ')});`
callCppFuncBody = `_function->operator()(${callParamsForward.join(', ')});`
}

let callSwiftFuncBody: string
if (returnBridge.hasType) {
callSwiftFuncBody = `
auto __result = call(${paramsForward.join(', ')});
let body: string
if (type.returnType.kind === 'void') {
body = `
swiftClosure.call(${paramsForward.join(', ')});
`.trim()
} else {
body = `
auto __result = swiftClosure.call(${paramsForward.join(', ')});
return ${returnBridge.parseFromSwiftToCpp('__result', 'c++')};
`.trim()
} else {
callSwiftFuncBody = `call(${paramsForward.join(', ')});`
}

// TODO: Remove shared_Func_void(...) function that returns a std::shared_ptr<std::function<...>>
// once Swift fixes the bug where a regular std::function cannot be captured.
// https://github.com/swiftlang/swift/issues/76143

return {
cxxType: actualType,
funcName: `create_${name}`,
Expand All @@ -370,22 +359,16 @@ using ${name} = ${actualType};
*/
class ${wrapperName} final {
public:
explicit ${wrapperName}(const ${actualType}& func): _function(func) {}
explicit ${wrapperName}(${actualType}&& func): _function(std::move(func)) {}
explicit ${wrapperName}(${actualType}&& func): _function(std::make_shared<${actualType}>(std::move(func))) {}
inline ${callFuncReturnType} call(${callCppFuncParamsSignature.join(', ')}) const {
${indent(callCppFuncBody, ' ')}
}
private:
${actualType} _function;
} SWIFT_NONCOPYABLE;
inline ${name} create_${name}(void* _Nonnull closureHolder, ${functionPointerParam}, void(* _Nonnull destroy)(void* _Nonnull)) {
std::shared_ptr<void> sharedClosureHolder(closureHolder, destroy);
return ${name}([sharedClosureHolder = std::move(sharedClosureHolder), call](${paramsSignature.join(', ')}) -> ${type.returnType.getCode('c++')} {
${indent(callSwiftFuncBody, ' ')}
});
}
inline std::shared_ptr<${wrapperName}> share_${name}(const ${name}& value) {
return std::make_shared<${wrapperName}>(value);
std::shared_ptr<${actualType}> _function;
};
${name} create_${name}(void* _Nonnull swiftClosureWrapper);
inline ${wrapperName} wrap_${name}(${name} value) {
return ${wrapperName}(std::move(value));
}
`.trim(),
requiredIncludes: [
Expand All @@ -402,6 +385,24 @@ inline std::shared_ptr<${wrapperName}> share_${name}(const ${name}& value) {
...bridgedType.getRequiredImports(),
],
},
cxxImplementation: {
code: `
${name} create_${name}(void* _Nonnull swiftClosureWrapper) {
auto swiftClosure = ${swiftClassName}::fromUnsafe(swiftClosureWrapper);
return [swiftClosure = std::move(swiftClosure)](${paramsSignature.join(', ')}) mutable -> ${type.returnType.getCode('c++')} {
${indent(body, ' ')}
};
}
`.trim(),
requiredIncludes: [
{
language: 'c++',
// Swift umbrella header
name: getUmbrellaHeaderName(),
space: 'user',
},
],
},
dependencies: [],
}
}
Expand Down Expand Up @@ -566,7 +567,7 @@ ${functions.join('\n')}
function createCxxPromiseSwiftHelper(type: PromiseType): SwiftCxxHelper {
const resultingType = type.resultingType.getCode('c++')
const bridgedType = new SwiftCxxBridgedType(type)
const actualType = `PromiseHolder<${resultingType}>`
const actualType = `std::shared_ptr<Promise<${resultingType}>>`

const resolverArgs: NamedType[] = []
if (type.resultingType.kind !== 'void') {
Expand All @@ -589,7 +590,10 @@ function createCxxPromiseSwiftHelper(type: PromiseType): SwiftCxxHelper {
*/
using ${name} = ${actualType};
inline ${actualType} create_${name}() {
return PromiseHolder<${resultingType}>::create();
return Promise<${resultingType}>::create();
}
inline PromiseHolder<${resultingType}> wrap_${name}(${actualType} promise) {
return PromiseHolder<${resultingType}>(std::move(promise));
}
`.trim(),
requiredIncludes: [
Expand Down
Loading

0 comments on commit 699e138

Please sign in to comment.