Skip to content

Commit

Permalink
fix: segfault when accessing bindings from C++
Browse files Browse the repository at this point in the history
Move implementation from C++ to Python.

Signed-off-by: Aaron <[email protected]>
  • Loading branch information
aarnphm committed Mar 1, 2023
1 parent 735ca92 commit 1e64d43
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 259 deletions.
30 changes: 22 additions & 8 deletions src/whispercpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from .utils import download_model

if t.TYPE_CHECKING:
import numpy as np
from numpy.typing import NDArray

from . import api
else:
api = LazyLoader("api", globals(), "whispercpp.api")
Expand All @@ -27,20 +30,31 @@ def __init__(self, *args: t.Any, **kwargs: t.Any):
context: api.Context
params: api.Params

@classmethod
def from_pretrained(cls, model_name: str):
@staticmethod
def from_pretrained(model_name: str):
if model_name not in MODELS_URL:
raise RuntimeError(
f"'{model_name}' is not a valid preconverted model. Choose one of {list(MODELS_URL)}"
)
_ref = object.__new__(cls)
_cpp_binding = api.WhisperPreTrainedModel(download_model(model_name))
context = _cpp_binding.context
params = _cpp_binding.params
transcribe = _cpp_binding.transcribe
del cls, _cpp_binding
_ref = object.__new__(Whisper)
context = api.Context.from_file(download_model(model_name))
params = api.Params.from_sampling_strategy(
api.SamplingStrategies.from_strategy_type(api.SAMPLING_GREEDY)
)
params.print_progress = False
params.print_realtime = False
context.reset_timings()
_ref.__dict__.update(locals())
return _ref

def transcribe(self, data: NDArray[np.float32], num_proc: int = 1):
self.context.full_parallel(self.params, data, num_proc)
return "".join(
[
self.context.full_get_segment_text(i)
for i in range(self.context.full_n_segments())
]
)


__all__ = ["Whisper", "api"]
20 changes: 8 additions & 12 deletions src/whispercpp/api.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ from __future__ import annotations
import enum
import typing as t

import numpy as np
from numpy.typing import NDArray

SAMPLE_RATE: int = ...
Expand All @@ -19,14 +18,19 @@ class SamplingBeamSearchStrategy:
beam_size: int
patience: float

SAMPLING_GREEDY: StrategyType = ...
SAMPLING_BEAM_SEARCH: StrategyType = ...

class StrategyType(enum.Enum):
GREEDY = ...
BEAM_SEARCH = ...
SAMPLING_GREEDY = ...
SAMPLING_BEAM_SEARCH = ...

class SamplingStrategies:
type: StrategyType
greedy: SamplingGreedyStrategy
beam_search: SamplingBeamSearchStrategy
@staticmethod
def from_strategy_type(strategy_type: StrategyType) -> SamplingStrategies: ...

# annotate the type of whisper_full_params
_CppFullParams = t.Any
Expand Down Expand Up @@ -82,12 +86,4 @@ class Context:
def full_get_segment_text(self, segment: int) -> str: ...
def full_n_segments(self) -> int: ...
def free(self) -> None: ...

class WhisperPreTrainedModel:
context: Context
params: Params
@t.overload
def __init__(self) -> None: ...
@t.overload
def __init__(self, path: str | bytes) -> None: ...
def transcribe(self, arr: NDArray[np.float32], num_proc: int) -> str: ...
def reset_timings(self) -> None: ...
181 changes: 114 additions & 67 deletions src/whispercpp/api_export.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,66 +5,6 @@
#include <sstream>

namespace whisper {

struct new_segment_callback_data {
std::vector<std::string> *results;
};

class Whisper {
public:
~Whisper() = default;
Whisper(const char *model_path) {
Context context = Context::from_file(model_path);
this->context = context;

// Set default params to recommended.
FullParams params = FullParams::from_sampling_strategy(
SamplingStrategies::from_strategy_type(SamplingStrategies::GREEDY));
// disable printing progress
params.set_print_progress(false);
// disable realtime print, using callback
params.set_print_realtime(false);
// invoke new_segment_callback for faster transcription.
params.set_new_segment_callback([](struct whisper_context *ctx, int n_new,
void *user_data) {
const auto &results = ((new_segment_callback_data *)user_data)->results;

const int n_segments = whisper_full_n_segments(ctx);

// print the last n_new segments
const int s0 = n_segments - n_new;

for (int i = s0; i < n_segments; i++) {
const char *text = whisper_full_get_segment_text(ctx, i);
results->push_back(std::move(text));
};
});
this->params = params;
}

std::string transcribe(std::vector<float> data, int num_proc) {
std::vector<std::string> results;
results.reserve(data.size());
new_segment_callback_data user_data = {&results};
params.set_new_segment_callback_user_data(&user_data);
if (context.full_parallel(params, data, num_proc) != 0) {
throw std::runtime_error("transcribe failed");
}

const char *const delim = "";
// We are allocating a new string for every element in the vector.
// This is not efficient, for larger files.
return std::accumulate(results.begin(), results.end(), std::string(delim));
};

Context get_context() { return context; };
FullParams get_params() { return params; };

private:
Context context;
FullParams params;
};

PYBIND11_MODULE(api, m) {
m.doc() = "Python interface for whisper.cpp";

Expand All @@ -79,12 +19,119 @@ PYBIND11_MODULE(api, m) {
ExportContextApi(m);

// NOTE: export Params API
ExportParamsApi(m);

py::class_<Whisper>(m, "WhisperPreTrainedModel")
.def(py::init<const char *>())
.def_property_readonly("context", &Whisper::get_context)
.def_property_readonly("params", &Whisper::get_params)
.def("transcribe", &Whisper::transcribe, "data"_a, "num_proc"_a = 1);
py::enum_<SamplingStrategies::StrategyType>(m, "StrategyType")
.value("SAMPLING_GREEDY", SamplingStrategies::GREEDY)
.value("SAMPLING_BEAM_SEARCH", SamplingStrategies::BEAM_SEARCH)
.export_values();

py::class_<SamplingGreedy>(m, "SamplingGreedyStrategy")
.def(py::init<>())
.def_property(
"best_of", [](SamplingGreedy &self) { return self.best_of; },
[](SamplingGreedy &self, int best_of) { self.best_of = best_of; });

py::class_<SamplingBeamSearch>(m, "SamplingBeamSearchStrategy")
.def(py::init<>())
.def_property(
"beam_size", [](SamplingBeamSearch &self) { return self.beam_size; },
[](SamplingBeamSearch &self, int beam_size) {
self.beam_size = beam_size;
})
.def_property(
"patience", [](SamplingBeamSearch &self) { return self.patience; },
[](SamplingBeamSearch &self, float patience) {
self.patience = patience;
});

py::class_<SamplingStrategies>(m, "SamplingStrategies",
"Available sampling strategy for whisper")
.def_static("from_strategy_type", &SamplingStrategies::from_strategy_type,
"strategy"_a)
.def_property(
"type", [](SamplingStrategies &self) { return self.type; },
[](SamplingStrategies &self, SamplingStrategies::StrategyType type) {
self.type = type;
})
.def_property(
"greedy", [](SamplingStrategies &self) { return self.greedy; },
[](SamplingStrategies &self, SamplingGreedy greedy) {
self.greedy = greedy;
})
.def_property(
"beam_search",
[](SamplingStrategies &self) { return self.beam_search; },
[](SamplingStrategies &self, SamplingBeamSearch beam_search) {
self.beam_search = beam_search;
});

py::class_<FullParams>(m, "Params", "Whisper parameters container")
.def_static("from_sampling_strategy", &FullParams::from_sampling_strategy,
"sampling_strategy"_a)
.def_property("num_threads", &FullParams::get_n_threads,
&FullParams::set_n_threads)
.def_property("num_max_text_ctx", &FullParams::get_n_max_text_ctx,
&FullParams::set_n_max_text_ctx)
.def_property("offset_ms", &FullParams::get_offset_ms,
&FullParams::set_offset_ms)
.def_property("duration_ms", &FullParams::get_duration_ms,
&FullParams::set_duration_ms)
.def_property("translate", &FullParams::get_translate,
&FullParams::set_translate)
.def_property("no_context", &FullParams::get_no_context,
&FullParams::set_no_context)
.def_property("single_segment", &FullParams::get_single_segment,
&FullParams::set_single_segment)
.def_property("print_special", &FullParams::get_print_special,
&FullParams::set_print_special)
.def_property("print_progress", &FullParams::get_print_progress,
&FullParams::set_print_progress)
.def_property("print_realtime", &FullParams::get_print_realtime,
&FullParams::set_print_realtime)
.def_property("print_timestamps", &FullParams::get_print_timestamps,
&FullParams::set_print_timestamps)
.def_property("token_timestamps", &FullParams::get_token_timestamps,
&FullParams::set_token_timestamps)
.def_property("timestamp_token_probability_threshold",
&FullParams::get_thold_pt, &FullParams::set_thold_pt)
.def_property("timestamp_token_sum_probability_threshold",
&FullParams::get_thold_ptsum, &FullParams::set_thold_ptsum)
.def_property("max_segment_length", &FullParams::get_max_len,
&FullParams::set_max_len)
.def_property("split_on_word", &FullParams::get_split_on_word,
&FullParams::set_split_on_word)
.def_property("max_tokens", &FullParams::get_max_tokens,
&FullParams::set_max_tokens)
.def_property("speed_up", &FullParams::get_speed_up,
&FullParams::set_speed_up)
.def_property("audio_ctx", &FullParams::get_audio_ctx,
&FullParams::set_audio_ctx)
.def("set_tokens", &FullParams::set_tokens, "tokens"_a)
.def_property_readonly("prompt_tokens", &FullParams::get_prompt_tokens)
.def_property_readonly("prompt_num_tokens",
&FullParams::get_prompt_n_tokens)
.def_property("language", &FullParams::get_language,
&FullParams::set_language)
.def_property("suppress_blank", &FullParams::get_suppress_blank,
&FullParams::set_suppress_blank)
.def_property("suppress_none_speech_tokens",
&FullParams::get_suppress_none_speech_tokens,
&FullParams::set_suppress_none_speech_tokens)
.def_property("temperature", &FullParams::get_temperature,
&FullParams::set_temperature)
.def_property("max_intial_timestamps", &FullParams::get_max_intial_ts,
&FullParams::set_max_intial_ts)
.def_property("length_penalty", &FullParams::get_length_penalty,
&FullParams::set_length_penalty)
.def_property("temperature_inc", &FullParams::get_temperature_inc,
&FullParams::set_temperature_inc)
.def_property("entropy_threshold", &FullParams::get_entropy_thold,
&FullParams::set_entropy_thold)
.def_property("logprob_threshold", &FullParams::get_logprob_thold,
&FullParams::set_logprob_thold)
.def_property("no_speech_threshold", &FullParams::get_no_speech_thold,
&FullParams::set_no_speech_thold);
// TODO: idk what to do with setting all the callbacks for FullParams. API are
// there, but need more time investingating conversion from Python callback to
// C++ callback
}
}; // namespace whisper
Loading

0 comments on commit 1e64d43

Please sign in to comment.