@@ -14,6 +14,8 @@ import androidx.compose.runtime.mutableStateListOf
1414import androidx.lifecycle.ViewModel
1515import androidx.lifecycle.viewModelScope
1616import com.argmaxinc.whisperkit.ExperimentalWhisperKit
17+ import com.argmaxinc.whisperkit.TranscriptionResult
18+ import com.argmaxinc.whisperkit.TranscriptionSegment
1719import com.argmaxinc.whisperkit.WhisperKit
1820import com.argmaxinc.whisperkit.WhisperKit.TextOutputCallback
1921import com.argmaxinc.whisperkit.WhisperKitException
@@ -33,22 +35,13 @@ import java.text.SimpleDateFormat
3335import java.util.Date
3436import java.util.Locale
3537
36- data class TranscriptionSegment (
37- val text : String ,
38- val start : Float ,
39- val end : Float ,
40- val tokens : List <Int > = emptyList(),
41- )
42-
43- data class TranscriptionResult (
44- val text : String = " " ,
45- val segments : List <TranscriptionSegment > = emptyList(),
46- )
47-
4838@OptIn(ExperimentalWhisperKit ::class )
4939class WhisperViewModel : ViewModel () {
5040 companion object {
5141 const val TAG = " WhisperViewModel"
42+
43+ // Models currently supporting NPU backend, don't enable NPU for other models
44+ val MODELS_SUPPORTING_NPU = listOf (WhisperKit .Builder .QUALCOMM_TINY_EN , WhisperKit .Builder .QUALCOMM_BASE_EN )
5245 }
5346
5447 private lateinit var appContext: Context
@@ -190,25 +183,25 @@ class WhisperViewModel : ViewModel() {
190183 cacheDir = context.cacheDir.absolutePath
191184 }
192185
193- fun onTextOutput (what : Int , timestamp : Float , msg : String ) {
186+ fun onTextOutput (what : Int , result : TranscriptionResult ) {
187+ val segments = result.segments
194188 when (what) {
195189 TextOutputCallback .MSG_INIT -> {
196- Log .i(MainActivity .TAG , " TFLite initialized: $msg " )
190+ Log .i(MainActivity .TAG , " TFLite initialized: ${result.text} " )
197191 startTime = System .currentTimeMillis()
198192 _pipelineStart .value = startTime.toDouble() / 1000.0
199193 _isInitializing .value = false
200194 }
201195
202196 TextOutputCallback .MSG_TEXT_OUT -> {
203197 Log .i(MainActivity .TAG , " TEXT OUT THREAD" )
204- if (msg .isNotEmpty()) {
198+ if (segments .isNotEmpty()) {
205199 if (! firstTokenReceived) {
206200 firstTokenReceived = true
207201 firstTokenTimestamp = System .currentTimeMillis()
208202 _firstTokenTime .value = (firstTokenTimestamp - startTime).toDouble() / 1000.0
209203 }
210-
211- val newTokens = msg.length / 4
204+ val newTokens = segments.joinToString(" " ) { it.text }.length / 4
212205 totalTokens + = newTokens
213206
214207 val currentTime = System .currentTimeMillis()
@@ -220,14 +213,14 @@ class WhisperViewModel : ViewModel() {
220213 }
221214
222215 lastTokenTimestamp = currentTime
223- updateTranscript(msg )
216+ updateTranscript(segments )
224217 }
225218 }
226219
227220 TextOutputCallback .MSG_CLOSE -> {
228221 Log .i(MainActivity .TAG , " Transcription completed." )
229- if (msg .isNotEmpty()) {
230- val newTokens = msg .length / 4
222+ if (segments .isNotEmpty()) {
223+ val newTokens = segments.joinToString( " " ) { it.text } .length / 4
231224 totalTokens + = newTokens
232225
233226 val totalTime = (System .currentTimeMillis() - startTime).toDouble() / 1000.0
@@ -236,8 +229,7 @@ class WhisperViewModel : ViewModel() {
236229
237230 updateRealtimeMetrics(totalTime)
238231 }
239-
240- updateTranscript(msg)
232+ updateTranscript(segments)
241233 }
242234 }
243235
@@ -247,25 +239,8 @@ class WhisperViewModel : ViewModel() {
247239 }
248240 }
249241
250- private fun updateTranscript (chunkText : String , withTimestamps : Boolean = false) {
251- var processedText = chunkText
252-
253- val timestamps = if (withTimestamps) {
254- val timestampPattern = " <\\ |(\\ d+\\ .\\ d+)\\ |>" .toRegex()
255- val timestampMatches = timestampPattern.findAll(chunkText).toList()
256- timestampMatches.map { it.groupValues[1 ].toFloat() }
257- } else {
258- emptyList()
259- }
260-
261- if (! withTimestamps) {
262- processedText = processedText
263- .replace(" <\\ |[^>]*\\ |>" .toRegex(), " " )
264- .trim()
265- } else {
266- processedText = processedText.trim()
267- }
268-
242+ private fun updateTranscript (segments : List <TranscriptionSegment >) {
243+ val processedText = segments.joinToString(" " ) { it.text }
269244 if (processedText.isNotEmpty()) {
270245 if (allText.isNotEmpty()) {
271246 allText.append(" \n " )
@@ -284,13 +259,12 @@ class WhisperViewModel : ViewModel() {
284259 fun listModels () {
285260 viewModelScope.launch {
286261 val modelDirs = listOf (
287- // TODO: enable when models are ready
288- // WhisperKit.Builder.OPENAI_TINY_EN,
289- // WhisperKit.Builder.OPENAI_BASE_EN,
290- // WhisperKit.Builder.OPENAI_SMALL_EN,
291262 WhisperKit .Builder .QUALCOMM_TINY_EN ,
292263 WhisperKit .Builder .QUALCOMM_BASE_EN ,
293- // WhisperKit.Builder.QUALCOMM_SMALL_EN
264+ WhisperKit .Builder .OPENAI_TINY_EN ,
265+ WhisperKit .Builder .OPENAI_BASE_EN ,
266+ WhisperKit .Builder .OPENAI_TINY ,
267+ WhisperKit .Builder .OPENAI_BASE ,
294268 )
295269 availableModels.clear()
296270 availableModels.addAll(modelDirs)
@@ -364,6 +338,21 @@ class WhisperViewModel : ViewModel() {
364338
365339 fun selectModel (model : String ) {
366340 _selectedModel .value = model
341+ if (model in MODELS_SUPPORTING_NPU ) {
342+ _encoderComputeUnits .update {
343+ ComputeUnits .CPU_AND_NPU
344+ }
345+ _decoderComputeUnits .update {
346+ ComputeUnits .CPU_AND_NPU
347+ }
348+ } else {
349+ _encoderComputeUnits .update {
350+ ComputeUnits .CPU_ONLY
351+ }
352+ _decoderComputeUnits .update {
353+ ComputeUnits .CPU_ONLY
354+ }
355+ }
367356 _modelState .value = ModelState .UNLOADED
368357 _encoderState .value = ModelState .UNLOADED
369358 _decoderState .value = ModelState .UNLOADED
0 commit comments