@@ -56,17 +56,23 @@ class bidirectional_layer : public layer
         forward_state_h_(stateful ? tensor(tensor_shape(n_units), static_cast<float_type>(0)) : fplus::nothing<tensor>()),
         forward_state_c_(stateful && wrapped_layer_type_has_state_c(wrapped_layer_type) ? tensor(tensor_shape(n_units), static_cast<float_type>(0)) : fplus::nothing<tensor>()),
         backward_state_h_(stateful ? tensor(tensor_shape(n_units), static_cast<float_type>(0)) : fplus::nothing<tensor>()),
-        backward_state_c_(stateful && wrapped_layer_type_has_state_c(wrapped_layer_type) ? tensor(tensor_shape(n_units), static_cast<float_type>(0)) : fplus::nothing<tensor>())
+        backward_state_c_(stateful && wrapped_layer_type_has_state_c(wrapped_layer_type) ? tensor(tensor_shape(n_units), static_cast<float_type>(0)) : fplus::nothing<tensor>()),
+        use_avail_input_state_for_stateful_(true)
+
     {
     }

     void reset_states() override
     {
+        // TF 2.1 bug: reset_states() does nothing in TF 2.1.
+        // The implementation below is how TF 2.1 should behave.
+        // To match TF 2.1, just comment out the code below.
         if (is_stateful()) {
             forward_state_h_ = tensor(tensor_shape(n_units_), static_cast<float_type>(0));
             forward_state_c_ = tensor(tensor_shape(n_units_), static_cast<float_type>(0));
             backward_state_h_ = tensor(tensor_shape(n_units_), static_cast<float_type>(0));
             backward_state_c_ = tensor(tensor_shape(n_units_), static_cast<float_type>(0));
+            use_avail_input_state_for_stateful_ = true;
         }
     }
@@ -110,29 +116,26 @@ class bidirectional_layer : public layer
             assertion(inputs.size() == 1 || inputs.size() == 5,
                 "Invalid number of input tensors.");

-            tensor forward_state_h = inputs.size() == 5
-                ? inputs[1]
-                : is_stateful()
-                    ? forward_state_h_.unsafe_get_just()
-                    : tensor(tensor_shape(n_units_), static_cast<float_type>(0));
-
-            tensor forward_state_c = inputs.size() == 5
-                ? inputs[2]
-                : is_stateful()
-                    ? forward_state_c_.unsafe_get_just()
-                    : tensor(tensor_shape(n_units_), static_cast<float_type>(0));
-
-            tensor backward_state_h = inputs.size() == 5
-                ? inputs[3]
-                : is_stateful()
-                    ? backward_state_h_.unsafe_get_just()
-                    : tensor(tensor_shape(n_units_), static_cast<float_type>(0));
-
-            tensor backward_state_c = inputs.size() == 5
-                ? inputs[4]
-                : is_stateful()
-                    ? backward_state_c_.unsafe_get_just()
-                    : tensor(tensor_shape(n_units_), static_cast<float_type>(0));
+            bool initial_state_provided = inputs.size() == 5;
+            bool use_last_state_for_initial_state = is_stateful() && !use_avail_input_state_for_stateful_;
+            bool use_input_initial_state = initial_state_provided && !use_last_state_for_initial_state;
+            // bool use_zero_initial_state = !use_input_initial_state && !use_last_state_for_initial_state;
+
+            tensor forward_state_h = use_input_initial_state ? inputs[1] :
+                use_last_state_for_initial_state ? forward_state_h_.unsafe_get_just() :
+                tensor(tensor_shape(n_units_), static_cast<float_type>(0)); // use_zero_initial_state
+
+            tensor forward_state_c = use_input_initial_state ? inputs[2] :
+                use_last_state_for_initial_state ? forward_state_c_.unsafe_get_just() :
+                tensor(tensor_shape(n_units_), static_cast<float_type>(0)); // use_zero_initial_state
+
+            tensor backward_state_h = use_input_initial_state ? inputs[3] :
+                use_last_state_for_initial_state ? backward_state_h_.unsafe_get_just() :
+                tensor(tensor_shape(n_units_), static_cast<float_type>(0)); // use_zero_initial_state
+
+            tensor backward_state_c = use_input_initial_state ? inputs[4] :
+                use_last_state_for_initial_state ? backward_state_c_.unsafe_get_just() :
+                tensor(tensor_shape(n_units_), static_cast<float_type>(0)); // use_zero_initial_state

             result_forward = lstm_impl(input, forward_state_h, forward_state_c,
                 n_units_, use_bias_, return_sequences_, stateful_,
@@ -147,24 +150,26 @@ class bidirectional_layer : public layer
                 forward_state_c_ = forward_state_c;
                 backward_state_h_ = backward_state_h;
                 backward_state_c_ = backward_state_c;
+                use_avail_input_state_for_stateful_ = false;
             }
         }
         else if (wrapped_layer_type_ == "GRU" || wrapped_layer_type_ == "CuDNNGRU")
         {
             assertion(inputs.size() == 1 || inputs.size() == 3,
                 "Invalid number of input tensors.");

-            tensor forward_state_h = inputs.size() == 3
-                ? inputs[1]
-                : is_stateful()
-                    ? forward_state_h_.unsafe_get_just()
-                    : tensor(tensor_shape(n_units_), static_cast<float_type>(0));
+            bool initial_state_provided = inputs.size() == 3;
+            bool use_last_state_for_initial_state = is_stateful() && !use_avail_input_state_for_stateful_;
+            bool use_input_initial_state = initial_state_provided && !use_last_state_for_initial_state;
+            // bool use_zero_initial_state = !use_input_initial_state && !use_last_state_for_initial_state;
+
+            tensor forward_state_h = use_input_initial_state ? inputs[1] :
+                use_last_state_for_initial_state ? forward_state_h_.unsafe_get_just() :
+                tensor(tensor_shape(n_units_), static_cast<float_type>(0)); // use_zero_initial_state

-            tensor backward_state_h = inputs.size() == 3
-                ? inputs[2]
-                : is_stateful()
-                    ? backward_state_h_.unsafe_get_just()
-                    : tensor(tensor_shape(n_units_), static_cast<float_type>(0));
+            tensor backward_state_h = use_input_initial_state ? inputs[2] :
+                use_last_state_for_initial_state ? backward_state_h_.unsafe_get_just() :
+                tensor(tensor_shape(n_units_), static_cast<float_type>(0)); // use_zero_initial_state

             result_forward = gru_impl(input, forward_state_h, n_units_, use_bias_, reset_after_, return_sequences_, false,
                 forward_weights_, forward_recurrent_weights_,
@@ -175,6 +180,7 @@ class bidirectional_layer : public layer
             if (is_stateful()) {
                 forward_state_h_ = forward_state_h;
                 backward_state_h_ = backward_state_h;
+                use_avail_input_state_for_stateful_ = false;
             }
         }
         else
@@ -223,6 +229,7 @@ class bidirectional_layer : public layer
     mutable fplus::maybe<tensor> forward_state_c_;
     mutable fplus::maybe<tensor> backward_state_h_;
     mutable fplus::maybe<tensor> backward_state_c_;
+    mutable bool use_avail_input_state_for_stateful_;
 };

 } // namespace internal
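For reference, a minimal standalone sketch (not part of the patch) of the initial-state selection the new flag enables. Names here are simplified assumptions: `tensor` is reduced to `std::vector<float>`, `maybe_tensor` stands in for `fplus::maybe<tensor>`, and `pick_initial_state` is a hypothetical free function, whereas the real logic sits inline in the LSTM/GRU branches above.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Illustrative stand-ins; the real types come from frugally-deep.
using tensor = std::vector<float>;
using maybe_tensor = std::optional<tensor>;

// Mirrors the decision in the patched branches: an explicitly provided
// initial state is used on the first call, a stateful layer reuses its
// carried-over state on later calls, and otherwise zeros are used.
tensor pick_initial_state(bool initial_state_provided,
                          bool stateful,
                          bool use_avail_input_state_for_stateful,
                          const maybe_tensor& last_state,
                          const tensor& provided_state,
                          std::size_t n_units)
{
    const bool use_last_state_for_initial_state =
        stateful && !use_avail_input_state_for_stateful;
    const bool use_input_initial_state =
        initial_state_provided && !use_last_state_for_initial_state;

    if (use_input_initial_state)
        return provided_state;        // state passed in as an extra input tensor
    if (use_last_state_for_initial_state)
        return *last_state;           // state kept from the previous call
    return tensor(n_units, 0.0f);     // zero initial state
}

int main()
{
    const std::size_t n_units = 4;
    const maybe_tensor last_state = tensor(n_units, 1.0f);
    const tensor provided_state(n_units, 2.0f);

    // First call after construction/reset_states(): the provided state is used.
    assert(pick_initial_state(true, true, true, last_state, provided_state, n_units)[0] == 2.0f);

    // Later calls of a stateful layer: the carried-over state is used instead.
    assert(pick_initial_state(true, true, false, last_state, provided_state, n_units)[0] == 1.0f);
}
```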