Skip to content

Commit 3f07980

Browse files
committed
Merge branch 'main' into transcribe-data-method
2 parents b9bc0e0 + 45afdc7 commit 3f07980

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+3175
-1736
lines changed

android/src/main/java/com/rnwhisper/RNWhisper.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,15 @@ protected void onPostExecute(Void result) {
235235
tasks.put(task, "abortTranscribe-" + id);
236236
}
237237

238+
public void bench(double id, double nThreads, Promise promise) {
239+
final WhisperContext context = contexts.get((int) id);
240+
if (context == null) {
241+
promise.reject("Context not found");
242+
return;
243+
}
244+
promise.resolve(context.bench((int) nThreads));
245+
}
246+
238247
public void releaseContext(double id, Promise promise) {
239248
final int contextId = (int) id;
240249
AsyncTask task = new AsyncTask<Void, Void, Void>() {

android/src/main/java/com/rnwhisper/WhisperContext.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,10 @@ public void stopCurrentTranscribe() {
423423
stopTranscribe(this.jobId);
424424
}
425425

426+
/**
 * Runs the native benchmark on this instance's native context handle.
 *
 * @param n_threads thread count passed through to the native benchmark
 * @return the string produced by the native {@code bench(long, int)} call
 */
public String bench(int n_threads) {
427+
return bench(context, n_threads);
428+
}
429+
426430
public void release() {
427431
stopCurrentTranscribe();
428432
freeContext(context);
@@ -527,4 +531,5 @@ protected static native int fullWithJob(
527531
int slice_index,
528532
int n_samples
529533
);
534+
// Native benchmark entry point; the JNI side is Java_com_rnwhisper_WhisperContext_bench in jni.cpp.
// `context` is the native whisper_context pointer carried as a long.
protected static native String bench(long context, int n_threads);
530535
}

android/src/main/jni.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,4 +508,17 @@ Java_com_rnwhisper_WhisperContext_getTextSegmentSpeakerTurnNext(
508508
return whisper_full_get_segment_speaker_turn_next(context, index);
509509
}
510510

511+
// JNI bridge for WhisperContext.bench(long, int): reinterprets the Java-side
// handle as a whisper_context pointer, runs rnwhisper::bench on it and hands
// the resulting text back to Java as a new string.
JNIEXPORT jstring JNICALL
Java_com_rnwhisper_WhisperContext_bench(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jint n_threads
) {
    UNUSED(thiz);

    auto *ctx = reinterpret_cast<struct whisper_context *>(context_ptr);
    const std::string report = rnwhisper::bench(ctx, n_threads);

    return env->NewStringUTF(report.c_str());
}
523+
511524
} // extern "C"

android/src/newarch/java/com/rnwhisper/RNWhisperModule.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ public void abortTranscribe(double contextId, double jobId, Promise promise) {
5757
rnwhisper.abortTranscribe(contextId, jobId, promise);
5858
}
5959

60+
/**
 * React Native bridge method (new architecture module): forwards the benchmark
 * request unchanged to the shared RNWhisper implementation, which settles the promise.
 */
@ReactMethod
61+
public void bench(double id, double nThreads, Promise promise) {
62+
rnwhisper.bench(id, nThreads, promise);
63+
}
64+
6065
@ReactMethod
6166
public void releaseContext(double id, Promise promise) {
6267
rnwhisper.releaseContext(id, promise);

android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ public void abortTranscribe(double contextId, double jobId, Promise promise) {
5757
rnwhisper.abortTranscribe(contextId, jobId, promise);
5858
}
5959

60+
/**
 * React Native bridge method (old architecture module): forwards the benchmark
 * request unchanged to the shared RNWhisper implementation, which settles the promise.
 */
@ReactMethod
61+
public void bench(double id, double nThreads, Promise promise) {
62+
rnwhisper.bench(id, nThreads, promise);
63+
}
64+
6065
@ReactMethod
6166
public void releaseContext(double id, Promise promise) {
6267
rnwhisper.releaseContext(id, promise);

cpp/ggml-metal.m

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ static void wsp_ggml_backend_metal_device_rel(struct wsp_ggml_backend_metal_devi
7676
ctx->mtl_device_ref_count--;
7777

7878
if (ctx->mtl_device_ref_count == 0) {
79+
[ctx->mtl_device release];
7980
ctx->mtl_device = nil;
8081
}
8182
}
@@ -520,8 +521,10 @@ @implementation WSPGGMLMetalClass
520521
struct wsp_ggml_metal_kernel * kernel = &ctx->kernels[e]; \
521522
id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
522523
kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
524+
[metal_function release]; \
523525
if (error) { \
524526
WSP_GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
527+
[metal_library release]; \
525528
return NULL; \
526529
} \
527530
} else { \
@@ -723,12 +726,24 @@ @implementation WSPGGMLMetalClass
723726
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
724727
}
725728

729+
[metal_library release];
730+
726731
return ctx;
727732
}
728733

729734
// Tears down a Metal backend context: releases the objects the context holds
// (kernel pipelines, the encode_async block, the command queue, the dispatch
// queue) with explicit release calls, then frees the context struct itself.
static void wsp_ggml_metal_free(struct wsp_ggml_backend_metal_context * ctx) {
730735
WSP_GGML_LOG_INFO("%s: deallocating\n", __func__);
731736

737+
// Release every kernel slot's pipeline; slots that were never loaded are
// presumably nil, in which case the message send is a no-op.
for (int i = 0; i < WSP_GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
738+
[ctx->kernels[i].pipeline release];
739+
}
740+
741+
// Balances the Block_copy performed when encode_async is assigned in
// wsp_ggml_backend_metal_set_n_cb.
// NOTE(review): confirm encode_async is always assigned before this runs.
Block_release(ctx->encode_async);
742+
743+
[ctx->queue release];
744+
745+
dispatch_release(ctx->d_queue);
746+
732747
free(ctx);
733748
}
734749

@@ -3241,6 +3256,9 @@ static enum wsp_ggml_status wsp_ggml_metal_graph_compute(
32413256
static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
32423257
struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context;
32433258

3259+
for (int i = 0; i < ctx->n_buffers; i++) {
3260+
[ctx->buffers[i].metal release];
3261+
}
32443262
wsp_ggml_backend_metal_device_rel(buffer->buft->device->context);
32453263

32463264
if (ctx->owned) {
@@ -3534,7 +3552,11 @@ static void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb
35343552
}
35353553
}
35363554

3537-
ctx->encode_async = ^(size_t iter) {
3555+
if (ctx->encode_async) {
3556+
Block_release(ctx->encode_async);
3557+
}
3558+
3559+
ctx->encode_async = Block_copy(^(size_t iter) {
35383560
const int cb_idx = iter;
35393561
const int n_cb_l = ctx->n_cb;
35403562

@@ -3573,7 +3595,7 @@ static void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb
35733595
if (cb_idx < 2 || ctx->abort_callback == NULL) {
35743596
[command_buffer commit];
35753597
}
3576-
};
3598+
});
35773599
}
35783600

35793601
static struct wsp_ggml_backend_i wsp_ggml_backend_metal_i = {

cpp/rn-whisper.cpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,97 @@
88

99
namespace rnwhisper {
1010

11+
const char * system_info(void) {
12+
static std::string s;
13+
s = "";
14+
if (wsp_ggml_cpu_has_avx() == 1) s += "AVX ";
15+
if (wsp_ggml_cpu_has_avx2() == 1) s += "AVX2 ";
16+
if (wsp_ggml_cpu_has_avx512() == 1) s += "AVX512 ";
17+
if (wsp_ggml_cpu_has_fma() == 1) s += "FMA ";
18+
if (wsp_ggml_cpu_has_neon() == 1) s += "NEON ";
19+
if (wsp_ggml_cpu_has_arm_fma() == 1) s += "ARM_FMA ";
20+
if (wsp_ggml_cpu_has_metal() == 1) s += "METAL ";
21+
if (wsp_ggml_cpu_has_f16c() == 1) s += "F16C ";
22+
if (wsp_ggml_cpu_has_fp16_va() == 1) s += "FP16_VA ";
23+
if (wsp_ggml_cpu_has_blas() == 1) s += "BLAS ";
24+
if (wsp_ggml_cpu_has_sse3() == 1) s += "SSE3 ";
25+
if (wsp_ggml_cpu_has_ssse3() == 1) s += "SSSE3 ";
26+
if (wsp_ggml_cpu_has_vsx() == 1) s += "VSX ";
27+
#ifdef WHISPER_USE_COREML
28+
s += "COREML ";
29+
#endif
30+
s.erase(s.find_last_not_of(" ") + 1);
31+
return s.c_str();
32+
}
33+
34+
std::string bench(struct whisper_context * ctx, int n_threads) {
35+
const int n_mels = whisper_model_n_mels(ctx);
36+
37+
if (int ret = whisper_set_mel(ctx, nullptr, 0, n_mels)) {
38+
return "error: failed to set mel: " + std::to_string(ret);
39+
}
40+
// heat encoder
41+
if (int ret = whisper_encode(ctx, 0, n_threads) != 0) {
42+
return "error: failed to encode: " + std::to_string(ret);
43+
}
44+
45+
whisper_token tokens[512];
46+
memset(tokens, 0, sizeof(tokens));
47+
48+
// prompt heat
49+
if (int ret = whisper_decode(ctx, tokens, 256, 0, n_threads) != 0) {
50+
return "error: failed to decode: " + std::to_string(ret);
51+
}
52+
53+
// text-generation heat
54+
if (int ret = whisper_decode(ctx, tokens, 1, 256, n_threads) != 0) {
55+
return "error: failed to decode: " + std::to_string(ret);
56+
}
57+
58+
whisper_reset_timings(ctx);
59+
60+
// actual run
61+
if (int ret = whisper_encode(ctx, 0, n_threads) != 0) {
62+
return "error: failed to encode: " + std::to_string(ret);
63+
}
64+
65+
// text-generation
66+
for (int i = 0; i < 256; i++) {
67+
if (int ret = whisper_decode(ctx, tokens, 1, i, n_threads) != 0) {
68+
return "error: failed to decode: " + std::to_string(ret);
69+
}
70+
}
71+
72+
// batched decoding
73+
for (int i = 0; i < 64; i++) {
74+
if (int ret = whisper_decode(ctx, tokens, 5, 0, n_threads) != 0) {
75+
return "error: failed to decode: " + std::to_string(ret);
76+
}
77+
}
78+
79+
// prompt processing
80+
for (int i = 0; i < 16; i++) {
81+
if (int ret = whisper_decode(ctx, tokens, 256, 0, n_threads) != 0) {
82+
return "error: failed to decode: " + std::to_string(ret);
83+
}
84+
}
85+
86+
const struct whisper_timings * timings = whisper_get_timings(ctx);
87+
88+
const int32_t n_encode = std::max(1, timings->n_encode);
89+
const int32_t n_decode = std::max(1, timings->n_decode);
90+
const int32_t n_batchd = std::max(1, timings->n_batchd);
91+
const int32_t n_prompt = std::max(1, timings->n_prompt);
92+
93+
return std::string("[") +
94+
"\"" + system_info() + "\"," +
95+
std::to_string(n_threads) + "," +
96+
std::to_string(1e-3f * timings->t_encode_us / n_encode) + "," +
97+
std::to_string(1e-3f * timings->t_decode_us / n_decode) + "," +
98+
std::to_string(1e-3f * timings->t_batchd_us / n_batchd) + "," +
99+
std::to_string(1e-3f * timings->t_prompt_us / n_prompt) + "]";
100+
}
101+
11102
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
12103
const float rc = 1.0f / (2.0f * M_PI * cutoff);
13104
const float dt = 1.0f / sample_rate;

cpp/rn-whisper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
namespace rnwhisper {
1111

12+
// Runs encoder/decoder benchmark passes on `ctx` with `n_threads` threads.
// Returns `["<system info>",n_threads,encode_ms,decode_ms,batchd_ms,prompt_ms]`
// as a string, or a string starting with "error: " when a pass fails.
std::string bench(whisper_context * ctx, int n_threads);
13+
1214
struct vad_params {
1315
bool use_vad = false;
1416
float vad_thold = 0.6f;

cpp/whisper.cpp

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4190,28 +4190,51 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
41904190
return ctx->vocab.token_transcribe;
41914191
}
41924192

4193+
struct whisper_timings * whisper_get_timings(struct whisper_context * ctx) {
4194+
if (ctx->state == nullptr) {
4195+
return nullptr;
4196+
}
4197+
return new whisper_timings {
4198+
.load_us = ctx->t_load_us,
4199+
.t_start_us = ctx->t_start_us,
4200+
.fail_p = ctx->state->n_fail_p,
4201+
.fail_h = ctx->state->n_fail_h,
4202+
.t_mel_us = ctx->state->t_mel_us,
4203+
.n_sample = ctx->state->n_sample,
4204+
.n_encode = ctx->state->n_encode,
4205+
.n_decode = ctx->state->n_decode,
4206+
.n_batchd = ctx->state->n_batchd,
4207+
.n_prompt = ctx->state->n_prompt,
4208+
.t_sample_us = ctx->state->t_sample_us,
4209+
.t_encode_us = ctx->state->t_encode_us,
4210+
.t_decode_us = ctx->state->t_decode_us,
4211+
.t_batchd_us = ctx->state->t_batchd_us,
4212+
.t_prompt_us = ctx->state->t_prompt_us,
4213+
};
4214+
}
4215+
41934216
void whisper_print_timings(struct whisper_context * ctx) {
41944217
const int64_t t_end_us = wsp_ggml_time_us();
4218+
const struct whisper_timings * timings = whisper_get_timings(ctx);
41954219

41964220
WHISPER_LOG_INFO("\n");
4197-
WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
4221+
WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, timings->load_us / 1000.0f);
41984222
if (ctx->state != nullptr) {
4199-
42004223
const int32_t n_sample = std::max(1, ctx->state->n_sample);
42014224
const int32_t n_encode = std::max(1, ctx->state->n_encode);
42024225
const int32_t n_decode = std::max(1, ctx->state->n_decode);
42034226
const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
42044227
const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
42054228

4206-
WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
4207-
WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
4208-
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
4209-
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
4210-
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
4211-
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
4212-
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
4229+
WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, timings->fail_p, timings->fail_h);
4230+
WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, timings->t_mel_us/1000.0f);
4231+
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_sample_us, n_sample, 1e-3f * timings->t_sample_us / n_sample);
4232+
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_encode_us, n_encode, 1e-3f * timings->t_encode_us / n_encode);
4233+
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_decode_us, n_decode, 1e-3f * timings->t_decode_us / n_decode);
4234+
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_batchd_us, n_batchd, 1e-3f * timings->t_batchd_us / n_batchd);
4235+
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_prompt_us, n_prompt, 1e-3f * timings->t_prompt_us / n_prompt);
42134236
}
4214-
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
4237+
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - timings->t_start_us)/1000.0f);
42154238
}
42164239

42174240
void whisper_reset_timings(struct whisper_context * ctx) {

cpp/whisper.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,24 @@ extern "C" {
424424
WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx);
425425

426426
// Performance information from the default state.
427+
// Snapshot of the timing counters of a context and its default state, as
// filled in by whisper_get_timings(). Fields ending in `_us` are divided by
// 1000 and labelled "ms" when printed by whisper_print_timings, i.e. they are
// microsecond totals; the `n_*` fields are the matching run counts.
struct whisper_timings {
428+
// Copied from the context: load time and start timestamp.
int64_t load_us;
429+
int64_t t_start_us;
430+
// Fallback counters, printed as "fallbacks = <p> p / <h> h".
int32_t fail_p;
431+
int32_t fail_h;
432+
int64_t t_mel_us;
433+
// Number of runs per stage.
int32_t n_sample;
434+
int32_t n_encode;
435+
int32_t n_decode;
436+
int32_t n_batchd;
437+
int32_t n_prompt;
438+
// Accumulated time per stage.
int64_t t_sample_us;
439+
int64_t t_encode_us;
440+
int64_t t_decode_us;
441+
int64_t t_batchd_us;
442+
int64_t t_prompt_us;
443+
};
444+
// Returns NULL when the context has no state; otherwise a newly allocated
// snapshot (created with C++ `new` in whisper.cpp) that the caller owns.
WHISPER_API struct whisper_timings * whisper_get_timings(struct whisper_context * ctx);
427445
WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
428446
WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
429447

0 commit comments

Comments
 (0)