From 92af8eab3880da03d5659175e0a645c05a9bc698 Mon Sep 17 00:00:00 2001 From: Aaron Taylor Date: Mon, 15 May 2023 14:46:59 -0400 Subject: [PATCH] Fix unmanaged self retain missing corrisponding release (caused memory leak) --- Sources/SwiftWhisper/Whisper.swift | 49 +++++++++++++++++++----------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/Sources/SwiftWhisper/Whisper.swift b/Sources/SwiftWhisper/Whisper.swift index 48f6752..520c8b9 100644 --- a/Sources/SwiftWhisper/Whisper.swift +++ b/Sources/SwiftWhisper/Whisper.swift @@ -3,6 +3,7 @@ import whisper_cpp public class Whisper { private let whisperContext: OpaquePointer + private var unmanagedSelf: Unmanaged? public var delegate: WhisperDelegate? public var params: WhisperParams @@ -14,8 +15,6 @@ public class Whisper { public init(fromFileURL fileURL: URL, withParams params: WhisperParams = .default) { self.whisperContext = fileURL.relativePath.withCString { whisper_init_from_file($0) } self.params = params - - prepareCallbacks() } public init(fromData data: Data, withParams params: WhisperParams = .default) { @@ -23,8 +22,6 @@ public class Whisper { self.whisperContext = copy.withUnsafeMutableBytes { whisper_init_from_buffer($0.baseAddress!, data.count) } self.params = params - - prepareCallbacks() } deinit { @@ -38,8 +35,11 @@ public class Whisper { We can unwrap that and obtain a copy of self inside the callback. */ - params.new_segment_callback_user_data = Unmanaged.passRetained(self).toOpaque() - params.encoder_begin_callback_user_data = Unmanaged.passRetained(self).toOpaque() + cleanupCallbacks() + let unmanagedSelf = Unmanaged.passRetained(self) + self.unmanagedSelf = unmanagedSelf + params.new_segment_callback_user_data = unmanagedSelf.toOpaque() + params.encoder_begin_callback_user_data = unmanagedSelf.toOpaque() // swiftlint:disable line_length params.new_segment_callback = { (ctx: OpaquePointer?, _: OpaquePointer?, newSegmentCount: Int32, userData: UnsafeMutableRawPointer?) in @@ -94,32 +94,45 @@ public class Whisper { } } + private func cleanupCallbacks() { + guard let unmanagedSelf else { return } + + unmanagedSelf.release() + self.unmanagedSelf = nil + } + public func transcribe(audioFrames: [Float], completionHandler: @escaping (Result<[Segment], Error>) -> Void) { + prepareCallbacks() + + let wrappedCompletionHandler: (Result<[Segment], Error>) -> Void = { result in + self.cleanupCallbacks() + completionHandler(result) + } + guard !inProgress else { - completionHandler(.failure(WhisperError.instanceBusy)) + wrappedCompletionHandler(.failure(WhisperError.instanceBusy)) return } guard audioFrames.count > 0 else { - completionHandler(.failure(WhisperError.invalidFrames)) + wrappedCompletionHandler(.failure(WhisperError.invalidFrames)) return } inProgress = true frameCount = audioFrames.count - DispatchQueue.global(qos: .userInitiated).async { [unowned self] in - - whisper_full(whisperContext, params.whisperParams, audioFrames, Int32(audioFrames.count)) + DispatchQueue.global(qos: .userInitiated).async { + whisper_full(self.whisperContext, self.params.whisperParams, audioFrames, Int32(audioFrames.count)) - let segmentCount = whisper_full_n_segments(whisperContext) + let segmentCount = whisper_full_n_segments(self.whisperContext) var segments: [Segment] = [] segments.reserveCapacity(Int(segmentCount)) for index in 0..