diff --git a/onnxruntime/core/providers/rocm/rnn/miopen_rnn_base.cc b/onnxruntime/core/providers/rocm/rnn/miopen_rnn_base.cc
index 27dd24d85f7d3..ea0bd7796b7c9 100644
--- a/onnxruntime/core/providers/rocm/rnn/miopen_rnn_base.cc
+++ b/onnxruntime/core/providers/rocm/rnn/miopen_rnn_base.cc
@@ -279,62 +279,62 @@ Status MiopenRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
                                                      workspace_rocm.get(),
                                                      workspace_bytes));
   }
-  else {
-    // miopen doesn't support 0 sequence inside the batch, find the 0 sequence and set it to 1
-    // there's a ZeroMask kernel to reset the result to 0 for the 0 sequence
-    std::vector<int32_t> seq_len_array(sequence_lens_data, sequence_lens_data + batch_size);
-    for (int i = 0; i < batch_size; ++i) {
-      if (0 == seq_len_array[i]) {
-        seq_len_array[i] = 1;
-        zero_seq_index_cache[zero_seq_count] = i;
-        ++zero_seq_count;
-      }
-    }
-
-    // Calculate the zero position cache for reverse direction if it's bidirectional
-    // The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since
-    // we hacked the 0 sequence to 1
-    if (zero_seq_count && num_directions_ > 1) {
-      zero_seq_index_cache_size = zero_seq_count * num_directions_;
-      zero_seq_index_cache.resize(zero_seq_index_cache_size);
-      for (int64_t i = 0; i < zero_seq_count; ++i) {
-        zero_seq_index_cache[static_cast<size_t>(zero_seq_count) + i] = static_cast<int32_t>(batch_size + zero_seq_index_cache[i]);
-      }
-    }
-
-    miopenTensorDescriptor_t x_desc1;
-    ORT_RETURN_IF_ERROR(x_desc1.Set(MiopenTensor::GetDataType<HipT>(), seq_length, batch_size, input_size, seq_len_array.data()));
-    miopenTensorDescriptor_t y_desc1;
-    ORT_RETURN_IF_ERROR(y_desc1.Set(MiopenTensor::GetDataType<HipT>(), seq_length, batch_size, hidden_size_ * num_directions_, seq_len_array.data()));
-
-    MIOPEN_RETURN_IF_ERROR(miopenRNNForwardInference(GetMiopenHandle(ctx),
-                                                     rnn_desc,
-                                                     x_desc1,
-                                                     x_data_input,
-                                                     hx_desc,
-                                                     hx_data,
-                                                     cx_desc,
-                                                     cx_data,
-                                                     weight_cached_ ? w_desc_cache_ : w_desc,
-                                                     weight_cached_ ? w_data_cache_.get() : w_data.get(),
-                                                     y_desc1,
-                                                     y_data,
-                                                     y_h_desc,
-                                                     y_h_data,
-                                                     y_c_desc,
-                                                     y_c_data,
-                                                     workspace_rocm.get(),
-                                                     workspace_bytes));
-
-    // Early terminate for this case since Y data is not required, and Y_h is obtained correctly, no need the following code to retrive Y_h from Y data.
-    if (nullptr == Y) {
-      // Mask on output for 0 sequence batches
-      if (zero_seq_count > 0) {
-        SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream());
-      }
-      return Status::OK();
-    }
-  }
+  // else {
+  //   // miopen doesn't support 0 sequence inside the batch, find the 0 sequence and set it to 1
+  //   // there's a ZeroMask kernel to reset the result to 0 for the 0 sequence
+  //   std::vector<int32_t> seq_len_array(sequence_lens_data, sequence_lens_data + batch_size);
+  //   for (int i = 0; i < batch_size; ++i) {
+  //     if (0 == seq_len_array[i]) {
+  //       seq_len_array[i] = 1;
+  //       zero_seq_index_cache[zero_seq_count] = i;
+  //       ++zero_seq_count;
+  //     }
+  //   }
+
+  //   // Calculate the zero position cache for reverse direction if it's bidirectional
+  //   // The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since
+  //   // we hacked the 0 sequence to 1
+  //   if (zero_seq_count && num_directions_ > 1) {
+  //     zero_seq_index_cache_size = zero_seq_count * num_directions_;
+  //     zero_seq_index_cache.resize(zero_seq_index_cache_size);
+  //     for (int64_t i = 0; i < zero_seq_count; ++i) {
+  //       zero_seq_index_cache[static_cast<size_t>(zero_seq_count) + i] = static_cast<int32_t>(batch_size + zero_seq_index_cache[i]);
+  //     }
+  //   }
+
+  //   miopenTensorDescriptor_t x_desc1;
+  //   MIOPEN_RETURN_IF_ERROR(x_desc1.Set(MiopenTensor::GetDataType<HipT>(), seq_length, batch_size, input_size, seq_len_array.data()));
+  //   miopenTensorDescriptor_t y_desc1;
+  //   MIOPEN_RETURN_IF_ERROR(y_desc1.Set(MiopenTensor::GetDataType<HipT>(), seq_length, batch_size, hidden_size_ * num_directions_, seq_len_array.data()));
+
+  //   MIOPEN_RETURN_IF_ERROR(miopenRNNForwardInference(GetMiopenHandle(ctx),
+  //                                                    rnn_desc,
+  //                                                    x_desc1,
+  //                                                    x_data_input,
+  //                                                    hx_desc,
+  //                                                    hx_data,
+  //                                                    cx_desc,
+  //                                                    cx_data,
+  //                                                    weight_cached_ ? w_desc_cache_ : w_desc,
+  //                                                    weight_cached_ ? w_data_cache_.get() : w_data.get(),
+  //                                                    y_desc1,
+  //                                                    y_data,
+  //                                                    y_h_desc,
+  //                                                    y_h_data,
+  //                                                    y_c_desc,
+  //                                                    y_c_data,
+  //                                                    workspace_rocm.get(),
+  //                                                    workspace_bytes));
+
+  //   // Early terminate for this case since Y data is not required, and Y_h is obtained correctly, no need the following code to retrive Y_h from Y data.
+  //   if (nullptr == Y) {
+  //     // Mask on output for 0 sequence batches
+  //     if (zero_seq_count > 0) {
+  //       SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream());
+  //     }
+  //     return Status::OK();
+  //   }
+  // }
 
   IAllocatorUniquePtr<T> y_reorganized_data;
   if (reverse_ || num_directions_ == 2) {
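Note on the hunk above: the branch being commented out is the zero-length-sequence workaround described in its own comments. MIOpen does not accept a sequence length of 0 inside a batch, so the code clamps such entries to 1, records their batch indices (doubled for the reverse direction when bidirectional), runs the forward pass, and later relies on a zero-mask step (SetZeroSequences) to reset the corresponding Y/Y_h/Y_c outputs to 0. Below is a minimal standalone sketch of the clamp-and-record step only, under the assumption of 32-bit sequence lengths; the names ZeroSeqWorkaround and PrepareSequenceLengths are hypothetical and do not appear in the file.

#include <cstdint>
#include <vector>

// Sketch: clamp zero-length sequences to 1 so the RNN API accepts them, and
// remember which batch entries were clamped so their outputs can be zeroed
// afterwards (mirrors the idea in the commented-out branch).
struct ZeroSeqWorkaround {
  std::vector<int32_t> seq_lens;          // lengths to pass to the RNN call
  std::vector<int32_t> zero_seq_indices;  // batch entries that were originally 0
};

ZeroSeqWorkaround PrepareSequenceLengths(const int32_t* sequence_lens, int batch_size) {
  ZeroSeqWorkaround w;
  w.seq_lens.assign(sequence_lens, sequence_lens + batch_size);
  for (int i = 0; i < batch_size; ++i) {
    if (w.seq_lens[i] == 0) {
      w.seq_lens[i] = 1;                // run one step instead of zero
      w.zero_seq_indices.push_back(i);  // mask this entry's outputs to 0 later
    }
  }
  return w;
}

In the kernel itself, the recorded indices are what the zero-mask step consumes after miopenRNNForwardInference to zero out the hidden/cell-state outputs for those batch entries.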