@@ -14,6 +14,8 @@ import androidx.compose.runtime.mutableStateListOf
14
14
import androidx.lifecycle.ViewModel
15
15
import androidx.lifecycle.viewModelScope
16
16
import com.argmaxinc.whisperkit.ExperimentalWhisperKit
17
+ import com.argmaxinc.whisperkit.TranscriptionResult
18
+ import com.argmaxinc.whisperkit.TranscriptionSegment
17
19
import com.argmaxinc.whisperkit.WhisperKit
18
20
import com.argmaxinc.whisperkit.WhisperKit.TextOutputCallback
19
21
import com.argmaxinc.whisperkit.WhisperKitException
@@ -33,22 +35,13 @@ import java.text.SimpleDateFormat
33
35
import java.util.Date
34
36
import java.util.Locale
35
37
36
- data class TranscriptionSegment (
37
- val text : String ,
38
- val start : Float ,
39
- val end : Float ,
40
- val tokens : List <Int > = emptyList(),
41
- )
42
-
43
- data class TranscriptionResult (
44
- val text : String = " " ,
45
- val segments : List <TranscriptionSegment > = emptyList(),
46
- )
47
-
48
38
@OptIn(ExperimentalWhisperKit ::class )
49
39
class WhisperViewModel : ViewModel () {
50
40
companion object {
51
41
const val TAG = " WhisperViewModel"
42
+
43
+ // Models that currently support the NPU backend; do not enable the NPU for any other model.
44
+ val MODELS_SUPPORTING_NPU = listOf (WhisperKit .Builder .QUALCOMM_TINY_EN , WhisperKit .Builder .QUALCOMM_BASE_EN )
52
45
}
53
46
54
47
private lateinit var appContext: Context
@@ -190,25 +183,25 @@ class WhisperViewModel : ViewModel() {
190
183
cacheDir = context.cacheDir.absolutePath
191
184
}
192
185
193
- fun onTextOutput (what : Int , timestamp : Float , msg : String ) {
186
+ fun onTextOutput (what : Int , result : TranscriptionResult ) {
187
+ val segments = result.segments
194
188
when (what) {
195
189
TextOutputCallback .MSG_INIT -> {
196
- Log .i(MainActivity .TAG , " TFLite initialized: $msg " )
190
+ Log .i(MainActivity .TAG , " TFLite initialized: ${result.text} " )
197
191
startTime = System .currentTimeMillis()
198
192
_pipelineStart .value = startTime.toDouble() / 1000.0
199
193
_isInitializing .value = false
200
194
}
201
195
202
196
TextOutputCallback .MSG_TEXT_OUT -> {
203
197
Log .i(MainActivity .TAG , " TEXT OUT THREAD" )
204
- if (msg .isNotEmpty()) {
198
+ if (segments .isNotEmpty()) {
205
199
if (! firstTokenReceived) {
206
200
firstTokenReceived = true
207
201
firstTokenTimestamp = System .currentTimeMillis()
208
202
_firstTokenTime .value = (firstTokenTimestamp - startTime).toDouble() / 1000.0
209
203
}
210
-
211
- val newTokens = msg.length / 4
204
+ val newTokens = segments.joinToString(" " ) { it.text }.length / 4
212
205
totalTokens + = newTokens
213
206
214
207
val currentTime = System .currentTimeMillis()
@@ -220,14 +213,14 @@ class WhisperViewModel : ViewModel() {
220
213
}
221
214
222
215
lastTokenTimestamp = currentTime
223
- updateTranscript(msg )
216
+ updateTranscript(segments )
224
217
}
225
218
}
226
219
227
220
TextOutputCallback .MSG_CLOSE -> {
228
221
Log .i(MainActivity .TAG , " Transcription completed." )
229
- if (msg .isNotEmpty()) {
230
- val newTokens = msg .length / 4
222
+ if (segments .isNotEmpty()) {
223
+ val newTokens = segments.joinToString( " " ) { it.text } .length / 4
231
224
totalTokens + = newTokens
232
225
233
226
val totalTime = (System .currentTimeMillis() - startTime).toDouble() / 1000.0
@@ -236,8 +229,7 @@ class WhisperViewModel : ViewModel() {
236
229
237
230
updateRealtimeMetrics(totalTime)
238
231
}
239
-
240
- updateTranscript(msg)
232
+ updateTranscript(segments)
241
233
}
242
234
}
243
235
@@ -247,25 +239,8 @@ class WhisperViewModel : ViewModel() {
247
239
}
248
240
}
249
241
250
- private fun updateTranscript (chunkText : String , withTimestamps : Boolean = false) {
251
- var processedText = chunkText
252
-
253
- val timestamps = if (withTimestamps) {
254
- val timestampPattern = " <\\ |(\\ d+\\ .\\ d+)\\ |>" .toRegex()
255
- val timestampMatches = timestampPattern.findAll(chunkText).toList()
256
- timestampMatches.map { it.groupValues[1 ].toFloat() }
257
- } else {
258
- emptyList()
259
- }
260
-
261
- if (! withTimestamps) {
262
- processedText = processedText
263
- .replace(" <\\ |[^>]*\\ |>" .toRegex(), " " )
264
- .trim()
265
- } else {
266
- processedText = processedText.trim()
267
- }
268
-
242
+ private fun updateTranscript (segments : List <TranscriptionSegment >) {
243
+ val processedText = segments.joinToString(" " ) { it.text }
269
244
if (processedText.isNotEmpty()) {
270
245
if (allText.isNotEmpty()) {
271
246
allText.append(" \n " )
@@ -284,13 +259,12 @@ class WhisperViewModel : ViewModel() {
284
259
fun listModels () {
285
260
viewModelScope.launch {
286
261
val modelDirs = listOf (
287
- // TODO: enable when models are ready
288
- // WhisperKit.Builder.OPENAI_TINY_EN,
289
- // WhisperKit.Builder.OPENAI_BASE_EN,
290
- // WhisperKit.Builder.OPENAI_SMALL_EN,
291
262
WhisperKit .Builder .QUALCOMM_TINY_EN ,
292
263
WhisperKit .Builder .QUALCOMM_BASE_EN ,
293
- // WhisperKit.Builder.QUALCOMM_SMALL_EN
264
+ WhisperKit .Builder .OPENAI_TINY_EN ,
265
+ WhisperKit .Builder .OPENAI_BASE_EN ,
266
+ WhisperKit .Builder .OPENAI_TINY ,
267
+ WhisperKit .Builder .OPENAI_BASE ,
294
268
)
295
269
availableModels.clear()
296
270
availableModels.addAll(modelDirs)
@@ -364,6 +338,21 @@ class WhisperViewModel : ViewModel() {
364
338
365
339
fun selectModel (model : String ) {
366
340
_selectedModel .value = model
341
+ if (model in MODELS_SUPPORTING_NPU ) {
342
+ _encoderComputeUnits .update {
343
+ ComputeUnits .CPU_AND_NPU
344
+ }
345
+ _decoderComputeUnits .update {
346
+ ComputeUnits .CPU_AND_NPU
347
+ }
348
+ } else {
349
+ _encoderComputeUnits .update {
350
+ ComputeUnits .CPU_ONLY
351
+ }
352
+ _decoderComputeUnits .update {
353
+ ComputeUnits .CPU_ONLY
354
+ }
355
+ }
367
356
_modelState .value = ModelState .UNLOADED
368
357
_encoderState .value = ModelState .UNLOADED
369
358
_decoderState .value = ModelState .UNLOADED
0 commit comments