diff --git a/src/nnet2/nnet-compute-online.cc b/src/nnet2/nnet-compute-online.cc index 768f6a245af..18fc48b6c78 100644 --- a/src/nnet2/nnet-compute-online.cc +++ b/src/nnet2/nnet-compute-online.cc @@ -42,6 +42,12 @@ void NnetOnlineComputer::Compute(const CuMatrixBase &input, if (input.NumRows() == 0) { output->Resize(0, 0); return; + } else { + // store the last frame as it might be needed for padding when Flush() is + // called. + if (last_seen_input_frame_.Dim() != input.NumCols()) + last_seen_input_frame_.Resize(input.NumCols()); + last_seen_input_frame_.CopyFromVec(input.Row(input.NumRows() - 1)); } // Checking if feature dimension matches that required by the neural network. @@ -100,8 +106,6 @@ void NnetOnlineComputer::Compute(const CuMatrixBase &input, nnet_.LeftContext() + nnet_.RightContext() + 1) { // we have sufficient frames to compute at least one nnet output nnet_.ComputeChunkInfo(num_effective_input_rows, 1, &chunk_info_); - // store the last frame as it might be needed for padding - last_seen_input_frame_ = input_data.Row(input_data.NumRows() - 1); Propagate(); *output = data_.back(); } else { @@ -110,6 +114,7 @@ void NnetOnlineComputer::Compute(const CuMatrixBase &input, // not enough input context so just return an empty array output->Resize(0, 0); } + } void NnetOnlineComputer::Flush(CuMatrix *output) { @@ -182,7 +187,7 @@ void NnetOnlineComputer::Propagate() { // Hence we manipulate the chunk_info objects to reflect the state of the // actual chunk, each component is computing, in the current Propagate. // As before we always assume the chunks are contiguous. - + // modifying the input chunk_info int32 chunk_size_assumed = chunk_info_[c].ChunkSize(); int32 last_offset = chunk_info_[c].GetOffset(chunk_size_assumed - 1); @@ -194,7 +199,7 @@ void NnetOnlineComputer::Propagate() { // modifying the output chunk_info chunk_size_assumed = chunk_info_[c + 1].ChunkSize(); last_offset = chunk_info_[c + 1].GetOffset(chunk_size_assumed - 1); - first_offset = last_offset - (input_data.NumRows() - + first_offset = last_offset - (input_data.NumRows() - (component.Context().back() - component.Context().front())) + 1; ChunkInfo output_chunk_info(chunk_info_[c + 1].NumCols(),