Skip to content

Commit

Permalink
REVERTME
Browse files Browse the repository at this point in the history
  • Loading branch information
groenenboomj committed Mar 22, 2024
1 parent 9c3f949 commit 312e98a
Showing 1 changed file with 56 additions and 56 deletions.
112 changes: 56 additions & 56 deletions onnxruntime/core/providers/rocm/rnn/miopen_rnn_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,62 +279,62 @@ Status MiopenRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
workspace_rocm.get(),
workspace_bytes));
}
else {
// MIOpen doesn't support zero-length sequences inside the batch, so find each zero-length sequence and set its length to 1.
// There's a ZeroMask kernel to reset the result to 0 for those zero-length sequences.
std::vector<int32_t> seq_len_array(sequence_lens_data, sequence_lens_data + batch_size);
for (int i = 0; i < batch_size; ++i) {
if (0 == seq_len_array[i]) {
seq_len_array[i] = 1;
zero_seq_index_cache[zero_seq_count] = i;
++zero_seq_count;
}
}

// Calculate the zero position cache for reverse direction if it's bidirectional
// The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since
// we hacked the 0 sequence to 1
if (zero_seq_count && num_directions_ > 1) {
zero_seq_index_cache_size = zero_seq_count * num_directions_;
zero_seq_index_cache.resize(zero_seq_index_cache_size);
for (int64_t i = 0; i < zero_seq_count; ++i) {
zero_seq_index_cache[static_cast<size_t>(zero_seq_count) + i] = static_cast<int32_t>(batch_size + zero_seq_index_cache[i]);
}
}

miopenTensorDescriptor_t x_desc1;
ORT_RETURN_IF_ERROR(x_desc1.Set(MiopenTensor::GetDataType<HipT>(), seq_length, batch_size, input_size, seq_len_array.data()));
miopenTensorDescriptor_t y_desc1;
ORT_RETURN_IF_ERROR(y_desc1.Set(MiopenTensor::GetDataType<HipT>(), seq_length, batch_size, hidden_size_ * num_directions_, seq_len_array.data()));

MIOPEN_RETURN_IF_ERROR(miopenRNNForwardInference(GetMiopenHandle(ctx),
rnn_desc,
x_desc1,
x_data_input,
hx_desc,
hx_data,
cx_desc,
cx_data,
weight_cached_ ? w_desc_cache_ : w_desc,
weight_cached_ ? w_data_cache_.get() : w_data.get(),
y_desc1,
y_data,
y_h_desc,
y_h_data,
y_c_desc,
y_c_data,
workspace_rocm.get(),
workspace_bytes));

// Terminate early in this case: Y data is not required and Y_h is already obtained correctly, so the following code that retrieves Y_h from Y data is not needed.
if (nullptr == Y) {
// Mask on output for 0 sequence batches
if (zero_seq_count > 0) {
SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream());
}
return Status::OK();
}
}
// else {
// // miopen doesn't support 0 sequence inside the batch, find the 0 sequence and set it to 1
// // there's a ZeroMask kernel to reset the result to 0 for the 0 sequence
// std::vector<int32_t> seq_len_array(sequence_lens_data, sequence_lens_data + batch_size);
// for (int i = 0; i < batch_size; ++i) {
// if (0 == seq_len_array[i]) {
// seq_len_array[i] = 1;
// zero_seq_index_cache[zero_seq_count] = i;
// ++zero_seq_count;
// }
// }

// // Calculate the zero position cache for reverse direction if it's bidirectional
// // The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since
// // we hacked the 0 sequence to 1
// if (zero_seq_count && num_directions_ > 1) {
// zero_seq_index_cache_size = zero_seq_count * num_directions_;
// zero_seq_index_cache.resize(zero_seq_index_cache_size);
// for (int64_t i = 0; i < zero_seq_count; ++i) {
// zero_seq_index_cache[static_cast<size_t>(zero_seq_count) + i] = static_cast<int32_t>(batch_size + zero_seq_index_cache[i]);
// }
// }

// miopenTensorDescriptor_t x_desc1;
// MIOPEN_RETURN_IF_ERROR(x_desc1.Set(MiopenTensor::GetDataType<HipT>(), seq_length, batch_size, input_size, seq_len_array.data()));
// miopenTensorDescriptor_t y_desc1;
// MIOPEN_RETURN_IF_ERROR(y_desc1.Set(MiopenTensor::GetDataType<HipT>(), seq_length, batch_size, hidden_size_ * num_directions_, seq_len_array.data()));

// MIOPEN_RETURN_IF_ERROR(miopenRNNForwardInference(GetMiopenHandle(ctx),
// rnn_desc,
// x_desc1,
// x_data_input,
// hx_desc,
// hx_data,
// cx_desc,
// cx_data,
// weight_cached_ ? w_desc_cache_ : w_desc,
// weight_cached_ ? w_data_cache_.get() : w_data.get(),
// y_desc1,
// y_data,
// y_h_desc,
// y_h_data,
// y_c_desc,
// y_c_data,
// workspace_rocm.get(),
// workspace_bytes));

// // Terminate early in this case: Y data is not required and Y_h is already obtained correctly, so the following code that retrieves Y_h from Y data is not needed.
// if (nullptr == Y) {
// // Mask on output for 0 sequence batches
// if (zero_seq_count > 0) {
// SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream());
// }
// return Status::OK();
// }
// }

IAllocatorUniquePtr<T> y_reorganized_data;
if (reverse_ || num_directions_ == 2) {
Expand Down

0 comments on commit 312e98a

Please sign in to comment.