Fix lattice length for rnnt decode #1237

Open · wants to merge 4 commits into master
9 changes: 5 additions & 4 deletions k2/csrc/intersect_dense.cu
@@ -778,7 +778,7 @@ class MultiGraphDenseIntersect {
void DoStep(int32_t t) {
NVTX_RANGE(K2_FUNC);
Step &step = steps_[t], &prev_step = steps_[t - 1];
int32_t scores_num_cols = b_fsas_.scores.Dim1();
int32_t scores_num_cols = b_fsas_.scores.Dim1();
const float minus_inf = -std::numeric_limits<float>::infinity();

// Divide by two because each arc is repeated twice in arc_scores (once for
@@ -814,9 +814,10 @@ class MultiGraphDenseIntersect {
backward_dest_prob =
prev_state_scores_data[dest_state_scores_index_backward];

// Assign negative infinity (-inf) to both the forward and backward scores,
// if the label on the carc is out-of-range, i.e., the label in the decoding
// graph (a_fsas) does not exist in the neural-net output (b_fsas).
// Assign negative infinity (-inf) to both the forward and backward
// scores, if the label on the carc is out-of-range, i.e., the label
// in the decoding graph (a_fsas) does not exist in the neural-net
// output (b_fsas).
float b_score_forward;
float b_score_backward;
if (carc.label_plus_one <= scores_num_cols) {
192 changes: 183 additions & 9 deletions k2/csrc/rnnt_decode.cu
@@ -248,7 +248,7 @@ RaggedShape RnntDecodingStreams::ExpandArcs() {
return unpruned_arcs_shape;
}

Renumbering RnntDecodingStreams::DoFisrtPassPruning(
Renumbering RnntDecodingStreams::DoFirstPassPruning(
RaggedShape &unpruned_arcs_shape, const Array2<float> &logprobs) {
NVTX_RANGE(K2_FUNC);
K2_CHECK_EQ(unpruned_arcs_shape.NumAxes(), 4);
@@ -439,7 +439,7 @@ void RnntDecodingStreams::Advance(const Array2<float> &logprobs) {
auto unpruned_arcs_shape = ExpandArcs();

// (2) Do initial pruning.
auto pass1_renumbering = DoFisrtPassPruning(unpruned_arcs_shape, logprobs);
auto pass1_renumbering = DoFirstPassPruning(unpruned_arcs_shape, logprobs);

// pass1_arcs_shape has a shape of [stream][context][state][arc]
auto pass1_arcs_shape =
@@ -489,6 +489,7 @@ void RnntDecodingStreams::Advance(const Array2<float> &logprobs) {
const auto logprobs_acc = logprobs.Accessor();
const Arc *const *graphs_arcs_data = graphs_.values.Data();


K2_EVAL(
c_, cur_num_arcs, lambda_populate_arcs_states_scores, (int32_t arc_idx) {
// Init renumber_arcs to 0, place here to save one kernel.
@@ -508,6 +509,7 @@
idx01 = uas_row_ids2_data[idx012], idx0 = uas_row_ids1_data[idx01],
num_graph_states = num_graph_states_data[idx0];
int64_t this_state = this_states_values_data[idx012];
int32_t this_graph_state = this_state % num_graph_states;
double this_score = this_scores_data[idx012];

// handle the implicit epsilon self-loop
@@ -516,7 +518,21 @@
// we assume termination symbol to be 0 here.
scores_data[arc_idx] = this_score + logprobs_acc(idx01, 0);
ArcInfo ai;
ai.graph_arc_idx01 = -1;
/*
Track the source state index for implicit self-loop arcs.
Type int32_t has the range [-2147483648, 2147483647], i.e. there is one
more negative value than positive values, so every state index can be
mapped to a distinct negative number:
state (0) --> graph_arc_idx01 (-1)
state (1) --> graph_arc_idx01 (-2)
state (2) --> graph_arc_idx01 (-3)
state (2147483647) --> graph_arc_idx01 (-2147483648)

Moreover, the super final state has no self-loop, so there are
definitely enough negative values to represent the positive state
indexes.
*/
ai.graph_arc_idx01 = -(this_graph_state + 1);
K2_CHECK_LT(ai.graph_arc_idx01, 0);
ai.score = logprobs_acc(idx01, 0);
ai.label = 0;
arcs_data[arc_idx] = ai;
@@ -527,8 +543,7 @@
const int32_t *graph_row_split1_data = graph_row_splits1_ptr_data[idx0];

int64_t this_context_state = this_state / num_graph_states;
int32_t this_graph_state = this_state % num_graph_states,
graph_idx0x = graph_row_split1_data[this_graph_state],
int32_t graph_idx0x = graph_row_split1_data[this_graph_state],
graph_idx01 = graph_idx0x + idx3 - 1; // minus 1 here as
// epsilon self-loop
// takes the position 0.
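A minimal standalone sketch (not part of the diff itself) of the -(state + 1) encoding described in the comment above; the helper names EncodeSelfLoop and DecodeSelfLoop are hypothetical:

#include <cassert>
#include <cstdint>

// Encode a non-negative graph state index as the strictly negative
// graph_arc_idx01 used for an implicit self-loop arc.
inline int32_t EncodeSelfLoop(int32_t graph_state) {
  return -graph_state - 1;  // 0 -> -1, 1 -> -2, ..., INT32_MAX -> INT32_MIN
}

// Recover the graph state index from a self-loop graph_arc_idx01.
inline int32_t DecodeSelfLoop(int32_t graph_arc_idx01) {
  assert(graph_arc_idx01 < 0);
  return -(graph_arc_idx01 + 1);  // -1 -> 0, -2 -> 1, ...
}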
@@ -715,6 +730,162 @@ void RnntDecodingStreams::GatherPrevFrames(
}
}

void RnntDecodingStreams::GetFinalArcs() {
NVTX_RANGE(K2_FUNC);
/*
This function handles the last two steps of the generated lattice.
The relationship between the variables in these two steps is:

arcs: last frame arcs final arcs
states: {last frame state} ---------------> {final states} ---------> {super final state} # noqa

The super final state has no leaving arcs.
*/

int32_t frames = prev_frames_.size();

// with shape [stream][context][state][arc]
auto last_frame_shape = prev_frames_[frames - 1]->shape;

// Note: last_frame_arc_data is non-const.
// The original "dest_state" attribute of each element in last_frame_arc_data
// is the state index processed by the function GroupStatesByContexts.
// In this function, the source states in the last frame are expanded again,
// and the expanded destination states are NOT grouped (to save time),
// so "dest_state" should be re-assigned to a new value.
ArcInfo *last_frame_arc_data = prev_frames_[frames - 1]->values.Data();
const int32_t *lfs_row_ids3_data = last_frame_shape.RowIds(3).Data(),
*lfs_row_ids2_data = last_frame_shape.RowIds(2).Data(),
*lfs_row_ids1_data = last_frame_shape.RowIds(1).Data(),
*lfs_row_splits3_data = last_frame_shape.RowSplits(3).Data(),
*lfs_row_splits2_data = last_frame_shape.RowSplits(2).Data(),
*lfs_row_splits1_data = last_frame_shape.RowSplits(1).Data();

const int32_t *num_graph_states_data = num_graph_states_.Data();
const int32_t *const *graph_row_splits1_ptr_data = graphs_.shape.RowSplits(1);
const Arc *const *graphs_arcs_data = graphs_.values.Data();

// Meaning of the name final_graph_states:
// "final_" means it is for "final states".
// "_graph_states" means it stores state indexes in the decoding graph.
// Though this variable could be calculated in both
// lambda_get_final_arcs_shape and lambda_populate_final_arcs,
// to save time it is calculated and cached in the former and
// used in the latter.
Array1<int32_t> final_graph_states(c_, last_frame_shape.NumElements());
int32_t* final_graph_states_data = final_graph_states.Data();

// Calculate num_arcs for each final state.
Array1<int32_t> num_final_arcs(c_, last_frame_shape.NumElements() + 1);
int32_t *num_final_arcs_data = num_final_arcs.Data();

K2_EVAL(
c_, last_frame_shape.NumElements(), lambda_get_final_arcs_shape,
(int32_t idx0123) {
// Initialize num_final_arcs to 0 here to save one kernel.
num_final_arcs_data[idx0123] = 0;

int32_t idx012 = lfs_row_ids3_data[idx0123], // state_idx012
idx01 = lfs_row_ids2_data[idx012], // context_idx01
idx0 = lfs_row_ids1_data[idx01], // stream_idx0
arc_idx01x = lfs_row_splits2_data[idx01],
arc_idx01xx = lfs_row_splits3_data[arc_idx01x],
arc_idx23 = idx0123 - arc_idx01xx;

ArcInfo& ai = last_frame_arc_data[idx0123];

// Re-assign dest_state to a new value.
// See the detailed comment at the definition of last_frame_arc_data above.
ai.dest_state = arc_idx23;

if (ai.label == -1) {
// Use -(num_graph_states_data[idx0]) to mark a state that is not expandable.
final_graph_states_data[idx0123] = -(num_graph_states_data[idx0]);
return;
}
int32_t dest_state = -1;
const int32_t *graph_row_split1_data = graph_row_splits1_ptr_data[idx0];
const Arc *graph_arcs_data = graphs_arcs_data[idx0];
if (ai.graph_arc_idx01 < 0) {
// For implicit self-loop arcs.
dest_state = -ai.graph_arc_idx01 - 1;
K2_CHECK_GE(dest_state, 0);
K2_CHECK_LE(dest_state, num_graph_states_data[idx0]);
} else {
// For other arcs shown in the decoding graph.
dest_state = graph_arcs_data[ai.graph_arc_idx01].dest_state;
}
K2_CHECK_GE(dest_state, 0);

final_graph_states_data[idx0123] = dest_state;
// Plus one for the implicit epsilon self-loop.
num_final_arcs_data[idx0123] = graph_row_split1_data[dest_state + 1] -
graph_row_split1_data[dest_state] + 1;
});


ExclusiveSum(num_final_arcs, &num_final_arcs);

auto final_arcs_shape = RaggedShape2(&num_final_arcs, nullptr, -1);
final_arcs_shape = ComposeRaggedShapes(last_frame_shape, final_arcs_shape);
// [stream][context][state][arc][arc] --> [stream][context][arc][arc],
// which could be viewed as [stream][context][final state][arc].
final_arcs_shape = RemoveAxis(final_arcs_shape, 2);
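// Illustrative example of the shape construction above: suppose the last
// frame has 3 elements and the kernel filled num_final_arcs = [2, 4, 0, _];
// ExclusiveSum turns it (in place) into row_splits = [0, 2, 6, 6], so
// RaggedShape2() describes 3 rows with 6 arcs in total, and
// ComposeRaggedShapes() + RemoveAxis() re-index those arcs as
// [stream][context][final state][arc].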
const int32_t *fas_row_ids1_data = final_arcs_shape.RowIds(1).Data(),
*fas_row_ids2_data = final_arcs_shape.RowIds(2).Data(),
*fas_row_ids3_data = final_arcs_shape.RowIds(3).Data(),
*fas_row_splits3_data = final_arcs_shape.RowSplits(3).Data();

auto final_arcs = Ragged<ArcInfo>(final_arcs_shape);
ArcInfo *final_arcs_data = final_arcs.values.Data();

K2_EVAL(
c_, final_arcs_shape.NumElements(), lambda_populate_final_arcs,
(int32_t idx0123) {
const int32_t idx012 = fas_row_ids3_data[idx0123], // state
idx01 = fas_row_ids2_data[idx012], // context
idx0 = fas_row_ids1_data[idx01], // stream
idx012x = fas_row_splits3_data[idx012],
arc_idx3 = idx0123 - idx012x;

const Arc *graph_arcs_data = graphs_arcs_data[idx0];
const int32_t *graph_row_split1_data = graph_row_splits1_ptr_data[idx0];
int32_t graph_state_idx0 = final_graph_states_data[idx012];

int32_t ai_graph_arc_idx01 = 0;
int32_t ai_arc_label = 0;
if (graph_state_idx0 < 0) {
/*
Could be one of following two cases:
case 1: not expandable if graph_state_idx0 == -(num_graph_states_data[idx0]) # noqa
case 2: implicit self-loop if graph_state_idx0 > -(num_graph_states_data[idx0]) # noqa
*/
K2_DCHECK_GT(graph_state_idx0, -(num_graph_states_data[idx0]));
ai_arc_label = 0;
ai_graph_arc_idx01 = -1;
} else {
// For arcs shown in decoding graph.
int32_t graph_arc_idx0x = graph_row_split1_data[graph_state_idx0];
// arc_idx3 could be viewed as graph_arc_idx1,
// since final_arcs_shape has 4 axes where arc_idx3 is calculated,
// while the decoding graph only has 2 axes where it is used.
ai_graph_arc_idx01 = graph_arc_idx0x + arc_idx3;
auto graph_arc = graph_arcs_data[ai_graph_arc_idx01];
ai_arc_label = graph_arc.label;
}
ArcInfo ai;
// ai.dest_state will be overwritten by FormatOutput,
// so just initialize it to -1 here.
ai.dest_state = -1;
ai.graph_arc_idx01 = ai_graph_arc_idx01;
ai.score = 0.0;
ai.label = ai_arc_label;
final_arcs_data[idx0123] = ai;
});

prev_frames_.emplace_back(std::make_shared<Ragged<ArcInfo>>(final_arcs));
}

void RnntDecodingStreams::FormatOutput(const std::vector<int32_t> &num_frames,
bool allow_partial, FsaVec *ofsa,
Array1<int32_t> *out_map) {
@@ -736,6 +907,8 @@ void RnntDecodingStreams::FormatOutput(const std::vector<int32_t> &num_frames,

GatherPrevFrames(num_frames);

GetFinalArcs();

int32_t frames = prev_frames_.size();
auto last_frame_shape = prev_frames_[frames - 1]->shape;

@@ -888,7 +1061,7 @@ void RnntDecodingStreams::FormatOutput(const std::vector<int32_t> &num_frames,
K2_EVAL(
c_, num_streams_, lambda_set_start_offset, (int32_t stream_idx) {
num_padded_frames_data[stream_idx] =
frames - num_padded_frames_data[stream_idx];
frames - num_padded_frames_data[stream_idx] - 1;
K2_CHECK_LE(0, num_padded_frames_data[stream_idx]);
});
}
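Presumably the extra "- 1" above compensates for the frame of final arcs that GetFinalArcs() now appends to prev_frames_: `frames` is one larger than the number of decoded frames, so a stream that contributed, say, 10 of 12 decoded frames still gets a start offset of (12 + 1) - 10 - 1 = 2, exactly as it did before this change.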
@@ -946,17 +1119,18 @@ void RnntDecodingStreams::FormatOutput(const std::vector<int32_t> &num_frames,
int32_t dest_state_idx012 = oarc_idx01xx_next + arc_info.dest_state;
arc.dest_state = dest_state_idx012 - oarc_idx0xxx;

// graph_arc_idx01 == -1 means this is a implicit epsilon self-loop
// graph_arc_idx01 < 0 means this is an implicit epsilon self-loop
// arc_info.label == -1 means this is the final arc before the last
// frame; it is a non-accessible arc, so we set its label to 0 here to
// make the generated lattice a valid k2 fsa.
if (arc_info.graph_arc_idx01 == -1 || arc_info.label == -1) {
if (arc_info.graph_arc_idx01 <= -1 || arc_info.label == -1) {
arc.label = 0;
out_map_data[oarc_idx01234] = -1;
} else {
arc.label = graph_arcs_data[arc_info.graph_arc_idx01].label;
out_map_data[oarc_idx01234] = arc_info.graph_arc_idx01;
}
arc.score = arc_info.score;
out_map_data[oarc_idx01234] = arc_info.graph_arc_idx01;
}
arcs_out_data[oarc_idx01234] = arc;
if (arc_map_b != nullptr) {
75 changes: 72 additions & 3 deletions k2/csrc/rnnt_decode.h
@@ -94,10 +94,25 @@ struct RnntDecodingConfig {

struct ArcInfo {
// The arc-index within the RnntDecodingStream::graph that corresponds to this
// arc, or -1 if this arc is a "termination symbol" (these do not appear in
// the graph).
// arc if non-negative.
// There is an implicit self-loop arc for each state; such arcs are
// represented by -(state_index + 1), see the comments on
// dest_state_in_graph below.
int32_t graph_arc_idx01;

// Note:
// 1. To save memory, the value of this variable is calculated from
// graph_arc_idx01 instead of being stored.
// 2. It is different from the variable dest_state.
// dest_state_in_graph is the destination state index in the decoding graph,
// while dest_state below is the state index in the "generated lattice".
// There are two kinds of arcs in the decoding graph:
// 1. Implicit self-loop arcs; their destination state in the graph is
// calculated as -(graph_arc_idx01 + 1).
// (Note: graph_arc_idx01 is negative for these arcs.)
// 2. Other arcs present in the decoding graph; their destination state is
// graph_arcs_data[ai.graph_arc_idx01].dest_state.
// int32_t dest_state_in_graph;
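// A hypothetical helper (for illustration only, not part of this struct)
// showing how the two cases above recover the destination state in the
// decoding graph:
//
//   inline int32_t DestStateInGraph(const ArcInfo &ai,
//                                   const Arc *graph_arcs_data) {
//     if (ai.graph_arc_idx01 < 0)          // implicit self-loop arc
//       return -(ai.graph_arc_idx01 + 1);  // undo the -(state + 1) encoding
//     return graph_arcs_data[ai.graph_arc_idx01].dest_state;
//   }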

// The score on the arc; contains both the graph score (if any) and the score
// from the RNN-T joiner.
float score;
@@ -220,6 +235,38 @@ class RnntDecodingStreams {
void FormatOutput(const std::vector<int32_t> &num_frames, bool allow_partial,
FsaVec *ofsa, Array1<int32_t> *out_map);

/*
Generate the lattice.
Note: Almost the same as the previous overloaded version,
except for the extra `is_final` argument.

Note: prev_frames_ only contains the frames decoded by the current object; in
order to generate the lattice we will first gather all the previous frames
from the individual streams.

@param [in] num_frames A vector containing the number of frames we want
to gather for each stream (note: the frames we have
ever received).
It MUST satisfy `num_frames.size() == num_streams_`, and
`num_frames[i] <= srcs_[i].prev_frames.size()`.
@param [in] allow_partial If true and there is no final state active,
we will treat all the states on the last frame
as final states. If false, we only care about
the real final states in the decoding graph on
the last frame when generating the lattice.
@param [in] is_final If true, function GetFinalArcs() will be called.
If false, it behaves the same as the previous
overloaded version.
@param [out] ofsa The output lattice will be written here; its num_axes
equals 3, and it will be re-allocated.
@param [out] out_map An Array1 with Dim() equal to ofsa.NumElements(),
containing the idx01 into the graph of each
individual stream, mapping each arc in ofsa to the
original decoding graphs. It may contain -1, which
means the arc is a "termination symbol".
*/
void FormatOutput(const std::vector<int32_t> &num_frames, bool allow_partial,
bool is_final, FsaVec *ofsa, Array1<int32_t> *out_map);
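// A hypothetical call site (for illustration only); `streams`, `num_frames`
// and the surrounding decoding loop are assumed to already exist:
//
//   FsaVec lattice;
//   Array1<int32_t> out_map;
//   streams.FormatOutput(num_frames, /*allow_partial*/ true,
//                        /*is_final*/ true, &lattice, &out_map);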

/*
Terminate the decoding process of current RnntDecodingStreams object, it
will update the states & scores of each individual stream and split &
@@ -282,8 +329,30 @@ class RnntDecodingStreams {

@return Return the renumbering object indicating which arc will be kept.
*/
Renumbering DoFisrtPassPruning(RaggedShape &unprund_arcs_shape,
Renumbering DoFirstPassPruning(RaggedShape &unprund_arcs_shape,
const Array2<float> &logprobs);

/*
Get final arcs when the last frame is received, i.e. when is_final=true is
passed to function `FormatOutput`.
Compared with OpenFst, a valid fsa in k2 needs arcs with label == -1
pointing to a super final state; this function creates those arcs.
See details of the problem solved by this function at
https://github.com/k2-fsa/k2/pull/1089

If we name the variables for the last two steps of a lattice as:
arcs: last frame arcs final arcs
states: {last frame state} ---------------> {final states} ---------> {super final state}

This function mainly does the following steps:
1. get last_frame from prev_frames_
2. expand last frame and get final states
3. re-assign dest state of last frame arcs to final states
4. populate final arcs
5. append final arcs to prev_frames_
*/
void GetFinalArcs();

/*
Group states by contexts.
