From 668e06ffe35254434f35c4d501e97246cd221a12 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Thu, 3 Nov 2022 21:33:11 +0800 Subject: [PATCH] fix lattice length of rnnt_decode --- k2/csrc/rnnt_decode.cu | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/k2/csrc/rnnt_decode.cu b/k2/csrc/rnnt_decode.cu index 1224d7b58..c3996756b 100644 --- a/k2/csrc/rnnt_decode.cu +++ b/k2/csrc/rnnt_decode.cu @@ -81,6 +81,14 @@ void RnntDecodingStreams::TerminateAndFlushToStreams() { NVTX_RANGE(K2_FUNC); // return directly if already detached or no frames decoded. if (!attached_ || prev_frames_.empty()) return; + + // We do this extra Advance to get arcs point to the super-final state. + const Array2 dummy_logprobs(c_, + states_.TotSize(1), + config_.vocab_size, + 0); + Advance(dummy_logprobs); + std::vector> states; std::vector> scores; Unstack(states_, 0, &states); @@ -682,16 +690,19 @@ void RnntDecodingStreams::GatherPrevFrames( Array1 stream2t_row_splits(GetCpuContext(), num_frames.size() + 1); for (size_t i = 0; i < num_frames.size(); ++i) { - stream2t_row_splits.Data()[i] = num_frames[i]; - K2_CHECK_LE(num_frames[i], + // + 1 for the last dummy_logprobs. + stream2t_row_splits.Data()[i] = num_frames[i] + 1; + K2_CHECK_LE(num_frames[i] + 1, static_cast(srcs_[i]->prev_frames.size())); - for (int32_t j = 0; j < num_frames[i]; ++j) { + + // + 1 for the last dummy_logprobs. + for (int32_t j = 0; j < num_frames[i] + 1; ++j) { frames_ptr.push_back(srcs_[i]->prev_frames[j].get()); } } // frames has a shape of [t][state][arc], - // its Dim0() equals std::sum(num_frames) + // its Dim0() equals std::sum(num_frames) + num_frames.size() auto frames = Stack(0, frames_ptr.size(), frames_ptr.data()); stream2t_row_splits = stream2t_row_splits.To(c_);