Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Not for merge][Test] For blank skipping in sherpa-onnx test #437

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ endif()
option(SHERPA_ONNX_ENABLE_PYTHON "Whether to build Python" OFF)
option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF)
option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" OFF)
option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
option(BUILD_SHARED_LIBS "Whether to build shared libraries" ON)
option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON)
option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
Expand Down
4 changes: 2 additions & 2 deletions sherpa-onnx/csrc/offline-model-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct OfflineModelConfig {

std::string tokens;
int32_t num_threads = 2;
bool debug = false;
bool debug = true;
std::string provider = "cpu";

// With the help of this field, we only need to load the model once
Expand All @@ -37,7 +37,7 @@ struct OfflineModelConfig {
// - nemo_ctc. It is a NeMo CTC model.
//
// All other values are invalid and lead to loading the model twice.
std::string model_type;
std::string model_type = "transducer";

OfflineModelConfig() = default;
OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
Expand Down
9 changes: 8 additions & 1 deletion sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,15 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
-23.025850929940457f);

auto t = model_->RunEncoder(std::move(x), std::move(x_length));

// Ort::Value encoder_out = t.first;
// Run the CTC model on a copy of the encoder output: t.first itself is moved
// into RunFrameReducer on the next statement, so it must not be consumed here.
auto ctc_out = model_->RunCTC(Clone(model_->Allocator(), &(t.first)));
// Pass the encoder output, its lengths, and the CTC output to the frame
// reducer. t.first and t.second are moved-from after this call; the returned
// pair (.first / .second) is what gets handed to the decoder.
// NOTE(review): presumably the reducer drops frames that the CTC output marks
// as blank -- confirm against the frame_reducer model.
// NOTE(review): both calls are unconditional, while the ctc / frame_reducer
// config paths default to empty -- confirm the model copes when they are unset.
auto frame_reducer_out = model_->RunFrameReducer(std::move(t.first), std::move(t.second), std::move(ctc_out));

// auto results =
// decoder_->Decode(std::move(t.first), std::move(t.second), ss, n);
auto results =
decoder_->Decode(std::move(t.first), std::move(t.second), ss, n);
decoder_->Decode(std::move(frame_reducer_out.first), std::move(frame_reducer_out.second), ss, n);

int32_t frame_shift_ms = 10;
for (int32_t i = 0; i != n; ++i) {
Expand Down
10 changes: 8 additions & 2 deletions sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
PackedSequence packed_encoder_out = PackPaddedSequence(
model_->Allocator(), &encoder_out, &encoder_out_length);

// Project the whole packed encoder output once, before the per-step loop.
// A clone of packed_encoder_out.data is passed, so packed_encoder_out keeps
// its own tensor; its sorted_indexes and batch_sizes are still read below.
// NOTE(review): if packed_encoder_out.data is not read again after this point,
// the clone could be replaced by a move -- verify against the hidden code.
auto projected_encoder_out = model_->RunEncoderProj(Clone(model_->Allocator(), &(packed_encoder_out.data)));

int32_t batch_size =
static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());

Expand All @@ -38,11 +40,15 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
auto decoder_input = model_->BuildDecoderInput(ans, ans.size());
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));

// Project the initial decoder output; decoder_out is moved-from after this.
// NOTE(review): any later code that reassigns decoder_out (e.g. after a new
// token is emitted) must also refresh projected_decoder_out, otherwise the
// joiner keeps seeing the initial projection -- the rest of the loop is not
// visible here, verify.
auto projected_decoder_out = model_->RunDecoderProj(std::move(decoder_out));

int32_t start = 0;
int32_t t = 0;
for (auto n : packed_encoder_out.batch_sizes) {
Ort::Value cur_encoder_out = packed_encoder_out.Get(start, n);
Ort::Value cur_decoder_out = Slice(model_->Allocator(), &decoder_out, 0, n);
// Ort::Value cur_encoder_out = packed_encoder_out.Get(start, n);
Ort::Value cur_encoder_out = Slice(model_->Allocator(), &projected_encoder_out, start, start + n);
// Ort::Value cur_decoder_out = Slice(model_->Allocator(), &decoder_out, 0, n);
Ort::Value cur_decoder_out = Slice(model_->Allocator(), &projected_decoder_out, 0, n);
start += n;
Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out),
std::move(cur_decoder_out));
Expand Down
36 changes: 35 additions & 1 deletion sherpa-onnx/csrc/offline-transducer-model-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) {
po->Register("encoder", &encoder_filename, "Path to encoder.onnx");
po->Register("decoder", &decoder_filename, "Path to decoder.onnx");
po->Register("joiner", &joiner_filename, "Path to joiner.onnx");
po->Register("ctc", &ctc, "Path to ctc.onnx");
po->Register("frame_reducer", &frame_reducer, "Path to frame_reducer.onnx");
po->Register("encoder_proj", &encoder_proj, "Path to encoder_proj.onnx");
po->Register("decoder_proj", &decoder_proj, "Path to decoder_proj.onnx");
}

bool OfflineTransducerModelConfig::Validate() const {
Expand All @@ -35,6 +39,32 @@ bool OfflineTransducerModelConfig::Validate() const {
return false;
}

// ctc and frame_reducer are validated as a pair: the checks run only when a
// ctc path was given, and then both files must exist.
// NOTE(review): a frame_reducer path supplied without ctc is never checked
// here -- confirm that is intended.
if (!ctc.empty())
{
if (!FileExists(ctc)) {
SHERPA_ONNX_LOGE("ctc: %s does not exist", ctc.c_str());
return false;
}

// NOTE(review): when frame_reducer is empty this relies on FileExists("")
// returning false, and the message is logged with an empty path -- confirm.
if (!FileExists(frame_reducer)) {
SHERPA_ONNX_LOGE("frame_reducer: %s does not exist", frame_reducer.c_str());
return false;
}
}

// encoder_proj and decoder_proj are validated as a pair: the checks run only
// when an encoder_proj path was given, and then both files must exist.
// NOTE(review): a decoder_proj path supplied without encoder_proj is never
// checked here -- confirm that is intended.
if (!encoder_proj.empty())
{
if (!FileExists(encoder_proj)) {
SHERPA_ONNX_LOGE("encoder_proj: %s does not exist", encoder_proj.c_str());
return false;
}

// NOTE(review): when decoder_proj is empty this relies on FileExists("")
// returning false, and the message is logged with an empty path -- confirm.
if (!FileExists(decoder_proj)) {
SHERPA_ONNX_LOGE("decoder_proj: %s does not exist", decoder_proj.c_str());
return false;
}
}

return true;
}

Expand All @@ -44,7 +74,11 @@ std::string OfflineTransducerModelConfig::ToString() const {
os << "OfflineTransducerModelConfig(";
os << "encoder_filename=\"" << encoder_filename << "\", ";
os << "decoder_filename=\"" << decoder_filename << "\", ";
os << "joiner_filename=\"" << joiner_filename << "\")";
os << "joiner_filename=\"" << joiner_filename << "\", ";
os << "ctc=\"" << ctc << "\", ";
os << "frame_reducer=\"" << frame_reducer << "\", ";
os << "encoder_proj=\"" << encoder_proj << "\", ";
os << "decoder_proj=\"" << decoder_proj << "\")";

return os.str();
}
Expand Down
16 changes: 14 additions & 2 deletions sherpa-onnx/csrc/offline-transducer-model-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,26 @@ struct OfflineTransducerModelConfig {
std::string encoder_filename;
std::string decoder_filename;
std::string joiner_filename;
// Path to ctc.onnx. May be empty; Validate() only checks ctc and
// frame_reducer when this is non-empty.
std::string ctc;
// Path to frame_reducer.onnx. Must exist whenever ctc is set.
std::string frame_reducer;
// Path to encoder_proj.onnx. May be empty; Validate() only checks
// encoder_proj and decoder_proj when this is non-empty.
std::string encoder_proj;
// Path to decoder_proj.onnx. Must exist whenever encoder_proj is set.
std::string decoder_proj;

OfflineTransducerModelConfig() = default;
OfflineTransducerModelConfig(const std::string &encoder_filename,
const std::string &decoder_filename,
const std::string &joiner_filename)
const std::string &joiner_filename,
const std::string &ctc="",
const std::string &frame_reducer="",
const std::string &encoder_proj="",
const std::string &decoder_proj="")
: encoder_filename(encoder_filename),
decoder_filename(decoder_filename),
joiner_filename(joiner_filename) {}
joiner_filename(joiner_filename),
ctc(ctc),
frame_reducer(frame_reducer),
encoder_proj(encoder_proj),
decoder_proj(decoder_proj) {}

void Register(ParseOptions *po);
bool Validate() const;
Expand Down
Loading
Loading