Add public callbacks to help expose internal state a little more (#240)
* SegmentDiscovery callback

* ModelState callback

* FractionCompleted callback

* TranscriptionPhaseCallback callback

* Updates for review

* Formatting

* Remove remaining callback from init

---------

Co-authored-by: ZachNagengast <[email protected]>
iandundas and ZachNagengast authored Nov 5, 2024
1 parent dd2eb73 commit 03f0bb4
Showing 4 changed files with 113 additions and 4 deletions.
41 changes: 41 additions & 0 deletions Sources/WhisperKit/Core/Models.swift
@@ -629,6 +629,47 @@ public struct TranscriptionProgress {
}
}

// Callbacks to receive state updates during transcription.

/// A callback that provides transcription segments as they are discovered.
/// - Parameters:
/// - segments: An array of `TranscriptionSegment` objects representing the transcribed segments
public typealias SegmentDiscoveryCallback = (_ segments: [TranscriptionSegment]) -> Void

/// A callback that reports changes in the model's state.
/// - Parameters:
/// - oldState: The previous state of the model, if any
/// - newState: The current state of the model
public typealias ModelStateCallback = (_ oldState: ModelState?, _ newState: ModelState) -> Void

/// A callback that reports changes in the transcription process.
/// - Parameter state: The current `TranscriptionState` of the transcription process
public typealias TranscriptionStateCallback = (_ state: TranscriptionState) -> Void

/// Represents the different states of the transcription process.
public enum TranscriptionState: CustomStringConvertible {
/// The audio is being converted to the required format for transcription
case convertingAudio

/// The audio is actively being transcribed to text
case transcribing

/// The transcription process has completed
case finished

/// A human-readable description of the transcription state
public var description: String {
switch self {
case .convertingAudio:
return "Converting Audio"
case .transcribing:
return "Transcribing"
case .finished:
return "Finished"
}
}
}

/// Callback to receive progress updates during transcription.
///
/// - Parameters:
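The typealiases and the TranscriptionState enum above make up the entire new callback surface in Models.swift. As a rough usage sketch (not part of the diff, and assuming TranscriptionSegment exposes start, end, and text fields as it does elsewhere in WhisperKit), closures matching the three signatures could look like this:

    // Hypothetical client-side handlers; print stands in for real handling.
    let segmentDiscovery: SegmentDiscoveryCallback = { segments in
        // Fires with each batch of newly decoded segments.
        for segment in segments {
            print("[\(segment.start)s - \(segment.end)s]\(segment.text)")
        }
    }

    let modelStateChange: ModelStateCallback = { oldState, newState in
        // oldState is optional: it is nil the first time a state is reported.
        print("Model state: \(String(describing: oldState)) -> \(newState)")
    }

    let transcriptionStateChange: TranscriptionStateCallback = { state in
        // Uses the CustomStringConvertible conformance declared above, so this
        // prints "Converting Audio", "Transcribing", or "Finished".
        print("Transcription phase: \(state)")
    }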
4 changes: 4 additions & 0 deletions Sources/WhisperKit/Core/TranscribeTask.swift
@@ -15,6 +15,8 @@ final class TranscribeTask {
private let textDecoder: any TextDecoding
private let tokenizer: any WhisperTokenizer

public var segmentDiscoveryCallback: SegmentDiscoveryCallback?

init(
currentTimings: TranscriptionTimings,
progress: Progress?,
@@ -230,6 +232,8 @@
}
}

segmentDiscoveryCallback?(currentSegments)

// add them to the `allSegments` list
allSegments.append(contentsOf: currentSegments)
let allCurrentTokens = currentSegments.flatMap { $0.tokens }
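As the hunk above shows, TranscribeTask now invokes segmentDiscoveryCallback with each window's currentSegments immediately before appending them to allSegments, so a consumer receives segments incrementally rather than only in the final result. A minimal accumulating consumer might look like the following sketch (assuming a configured whisperKit instance; the public property that forwards the callback into the task is added in WhisperKit.swift below):

    var liveTranscript: [TranscriptionSegment] = []
    whisperKit.segmentDiscoveryCallback = { segments in
        // Each call carries only the current window's segments, so append.
        liveTranscript.append(contentsOf: segments)
        print("Transcript so far:", liveTranscript.map(\.text).joined())
    }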
35 changes: 31 additions & 4 deletions Sources/WhisperKit/Core/WhisperKit.swift
@@ -13,7 +13,12 @@ import Tokenizers
open class WhisperKit {
/// Models
public private(set) var modelVariant: ModelVariant = .tiny
public private(set) var modelState: ModelState = .unloaded
public private(set) var modelState: ModelState = .unloaded {
didSet {
modelStateCallback?(oldValue, modelState)
}
}

public var modelCompute: ModelComputeOptions
public var tokenizer: WhisperTokenizer?

@@ -42,6 +47,11 @@ open class WhisperKit
public var tokenizerFolder: URL?
public private(set) var useBackgroundDownloadSession: Bool

/// Callbacks
public var segmentDiscoveryCallback: SegmentDiscoveryCallback?
public var modelStateCallback: ModelStateCallback?
public var transcriptionStateCallback: TranscriptionStateCallback?

public init(_ config: WhisperKitConfig = WhisperKitConfig()) async throws {
modelCompute = config.computeOptions ?? ModelComputeOptions()
audioProcessor = config.audioProcessor ?? AudioProcessor()
@@ -365,7 +375,7 @@
} else {
currentTimings.decoderLoadTime = CFAbsoluteTimeGetCurrent() - decoderLoadStart
}

Logging.debug("Loaded text decoder in \(String(format: "%.2f", currentTimings.decoderLoadTime))s")
}

@@ -378,13 +388,13 @@
computeUnits: modelCompute.audioEncoderCompute,
prewarmMode: prewarmMode
)

if prewarmMode {
currentTimings.encoderSpecializationTime = CFAbsoluteTimeGetCurrent() - encoderLoadStart
} else {
currentTimings.encoderLoadTime = CFAbsoluteTimeGetCurrent() - encoderLoadStart
}

Logging.debug("Loaded audio encoder in \(String(format: "%.2f", currentTimings.encoderLoadTime))s")
}

@@ -549,6 +559,8 @@
decodeOptions: DecodingOptions? = nil,
callback: TranscriptionCallback = nil
) async -> [Result<[TranscriptionResult], Swift.Error>] {
transcriptionStateCallback?(.convertingAudio)

// Start timing the audio loading and conversion process
let loadAudioStart = Date()

@@ -561,6 +573,11 @@
currentTimings.audioLoading = loadAndConvertTime
Logging.debug("Total Audio Loading and Converting Time: \(loadAndConvertTime)")

transcriptionStateCallback?(.transcribing)
defer {
transcriptionStateCallback?(.finished)
}

// Transcribe the loaded audio arrays
let transcribeResults = await transcribeWithResults(
audioArrays: audioArrays,
@@ -733,6 +750,8 @@
decodeOptions: DecodingOptions? = nil,
callback: TranscriptionCallback = nil
) async throws -> [TranscriptionResult] {
transcriptionStateCallback?(.convertingAudio)

// Process input audio file into audio samples
let audioArray = try await withThrowingTaskGroup(of: [Float].self) { group -> [Float] in
let convertAudioStart = Date()
@@ -746,6 +765,12 @@
return try AudioProcessor.loadAudioAsFloatArray(fromPath: audioPath)
}

transcriptionStateCallback?(.transcribing)
defer {
transcriptionStateCallback?(.finished)
}

// Send converted samples to be transcribed
let transcribeResults: [TranscriptionResult] = try await transcribe(
audioArray: audioArray,
decodeOptions: decodeOptions,
@@ -872,6 +897,8 @@
tokenizer: tokenizer
)

transcribeTask.segmentDiscoveryCallback = self.segmentDiscoveryCallback

let transcribeTaskResult = try await transcribeTask.run(
audioArray: audioArray,
decodeOptions: decodeOptions,
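Taken together, the WhisperKit.swift changes mean all three callbacks can be assigned right after init: modelStateCallback fires from the didSet observer as loadModels() advances, transcriptionStateCallback brackets each transcribe call via the defer blocks above, and segmentDiscoveryCallback is handed to the underlying TranscribeTask. A minimal end-to-end sketch, mirroring the unit test below (the model folder and audio path are placeholders):

    func demoCallbacks() async throws {
        let config = WhisperKitConfig(modelFolder: "path/to/tiny", load: false)
        let whisperKit = try await WhisperKit(config)

        // Assign callbacks before loading, so model state changes are observed.
        whisperKit.modelStateCallback = { oldState, newState in
            print("Model: \(String(describing: oldState)) -> \(newState)")
        }
        whisperKit.transcriptionStateCallback = { state in
            print("Phase: \(state)") // convertingAudio -> transcribing -> finished
        }
        whisperKit.segmentDiscoveryCallback = { segments in
            segments.forEach { print("Segment:", $0.text) }
        }

        try await whisperKit.loadModels()
        let results = try await whisperKit.transcribe(audioPath: "path/to/audio.wav")
        print(results.map(\.text).joined())
    }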
37 changes: 37 additions & 0 deletions Tests/WhisperKitTests/UnitTests.swift
@@ -1067,6 +1067,43 @@ final class UnitTests: XCTestCase {
XCTAssertEqual(result.segments.first?.text, " and so my fellow americans ask not what your country can do for you ask what you can do for your country.")
}

func testCallbacks() async throws {
let config = try WhisperKitConfig(
modelFolder: tinyModelPath(),
verbose: true,
logLevel: .debug,
load: false
)
let whisperKit = try await WhisperKit(config)
let modelStateExpectation = XCTestExpectation(description: "Model state callback expectation")
whisperKit.modelStateCallback = { (oldState: ModelState?, newState: ModelState) in
Logging.debug("Model state: \(newState)")
modelStateExpectation.fulfill()
}

let segmentDiscoveryExpectation = XCTestExpectation(description: "Segment discovery callback expectation")
whisperKit.segmentDiscoveryCallback = { (segments: [TranscriptionSegment]) in
Logging.debug("Segments discovered: \(segments)")
segmentDiscoveryExpectation.fulfill()
}

let transcriptionStateExpectation = XCTestExpectation(description: "Transcription state callback expectation")
whisperKit.transcriptionStateCallback = { (state: TranscriptionState) in
Logging.debug("Transcription state: \(state)")
transcriptionStateExpectation.fulfill()
}

// Run the full pipeline
try await whisperKit.loadModels()
let audioFilePath = try XCTUnwrap(
Bundle.current.path(forResource: "jfk", ofType: "wav"),
"Audio file not found"
)
let _ = try await whisperKit.transcribe(audioPath: audioFilePath)

await fulfillment(of: [modelStateExpectation, segmentDiscoveryExpectation, transcriptionStateExpectation], timeout: 1)
}

// MARK: - Utils Tests

func testFillIndexesWithValue() throws {
