Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MlasTranspose multi-threading support. #22912

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 20 additions & 17 deletions onnxruntime/core/framework/transpose_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ struct has_mlas_transpose<uint32_t> : std::true_type {};
template <typename T>
typename std::enable_if<!has_mlas_transpose<T>::value, void>::type SimpleTransposeSingleAxisOutwards(
const T* input_data, T* output_data, int64_t num_loops, int64_t num_writers, int64_t writes_per_loop,
int64_t writes_per_writer_per_loop) {
int64_t writes_per_writer_per_loop, concurrency::ThreadPool* tp = nullptr) {
ORT_UNUSED_PARAMETER(tp);
const T* end;
for (int64_t l = 0; l < num_loops; ++l) {
T* output_for_first_writer = output_data;
Expand All @@ -48,10 +49,10 @@ typename std::enable_if<!has_mlas_transpose<T>::value, void>::type SimpleTranspo
template <typename T>
typename std::enable_if<has_mlas_transpose<T>::value, void>::type SimpleTransposeSingleAxisOutwards(
const T* input_data, T* output_data, int64_t num_loops, int64_t num_writers, int64_t writes_per_loop,
int64_t writes_per_writer_per_loop) {
int64_t writes_per_writer_per_loop, concurrency::ThreadPool* tp = nullptr) {
for (int64_t l = 0; l < num_loops; ++l) {
MlasTranspose(input_data, output_data, static_cast<size_t>(writes_per_writer_per_loop),
static_cast<size_t>(num_writers));
static_cast<size_t>(num_writers), tp);
input_data += writes_per_loop;
output_data += writes_per_loop;
}
Expand Down Expand Up @@ -82,25 +83,25 @@ void TransposeSingleAxisOutwards(gsl::span<const size_t> permutations, const Ten
switch (bytes_per_write) {
case (sizeof(uint8_t)): {
SimpleTransposeSingleAxisOutwards(input_data, output_data, num_loops, num_writers, writes_per_loop,
writes_per_writer_per_loop);
writes_per_writer_per_loop, tp);
break;
}
case (sizeof(uint16_t)): {
SimpleTransposeSingleAxisOutwards(reinterpret_cast<const uint16_t*>(input_data),
reinterpret_cast<uint16_t*>(output_data), num_loops, num_writers,
writes_per_loop, writes_per_writer_per_loop);
writes_per_loop, writes_per_writer_per_loop, tp);
break;
}
case (sizeof(uint32_t)): {
SimpleTransposeSingleAxisOutwards(reinterpret_cast<const uint32_t*>(input_data),
reinterpret_cast<uint32_t*>(output_data), num_loops, num_writers,
writes_per_loop, writes_per_writer_per_loop);
writes_per_loop, writes_per_writer_per_loop, tp);
break;
}
case (sizeof(uint64_t)): {
SimpleTransposeSingleAxisOutwards(reinterpret_cast<const uint64_t*>(input_data),
reinterpret_cast<uint64_t*>(output_data), num_loops, num_writers,
writes_per_loop, writes_per_writer_per_loop);
writes_per_loop, writes_per_writer_per_loop, tp);
break;
}
default: {
Expand All @@ -125,7 +126,8 @@ void TransposeSingleAxisOutwards(gsl::span<const size_t> permutations, const Ten
template <typename T>
typename std::enable_if<!has_mlas_transpose<T>::value, void>::type SimpleTransposeSingleAxisInwards(
const T* input_data, T* output_data, int64_t num_loops, int64_t num_readers, int64_t reads_per_loop,
int64_t reads_per_reader_per_loop) {
int64_t reads_per_reader_per_loop, concurrency::ThreadPool* tp = nullptr) {
ORT_UNUSED_PARAMETER(tp);
T* end;
for (int64_t l = 0; l < num_loops; ++l) {
const T* input_for_first_reader = input_data;
Expand All @@ -150,10 +152,10 @@ typename std::enable_if<!has_mlas_transpose<T>::value, void>::type SimpleTranspo
template <typename T>
typename std::enable_if<has_mlas_transpose<T>::value, void>::type SimpleTransposeSingleAxisInwards(
const T* input_data, T* output_data, int64_t num_loops, int64_t num_readers, int64_t reads_per_loop,
int64_t reads_per_reader_per_loop) {
int64_t reads_per_reader_per_loop, concurrency::ThreadPool* tp = nullptr) {
for (int64_t l = 0; l < num_loops; ++l) {
MlasTranspose(input_data, output_data, static_cast<size_t>(num_readers),
static_cast<size_t>(reads_per_reader_per_loop));
static_cast<size_t>(reads_per_reader_per_loop), tp);
input_data += reads_per_loop;
output_data += reads_per_loop;
}
Expand All @@ -162,7 +164,8 @@ typename std::enable_if<has_mlas_transpose<T>::value, void>::type SimpleTranspos
// moving a single axis inwards where the read/write size is a power of 2 and between 8 and 64 bits.
// `input_shape_override` overrides the shape of `input` for compute purposes.
void TransposeSingleAxisInwards(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output,
size_t from, size_t to, const TensorShape* input_shape_override = nullptr) {
size_t from, size_t to, const TensorShape* input_shape_override = nullptr,
concurrency::ThreadPool* tp = nullptr) {
ORT_UNUSED_PARAMETER(permutations);

const auto& input_shape = input_shape_override ? *input_shape_override : input.Shape();
Expand All @@ -184,25 +187,25 @@ void TransposeSingleAxisInwards(gsl::span<const size_t> permutations, const Tens
switch (bytes_per_read) {
case (sizeof(uint8_t)): {
SimpleTransposeSingleAxisInwards(input_data, output_data, num_loops, num_readers, reads_per_loop,
reads_per_reader_per_loop);
reads_per_reader_per_loop, tp);
break;
}
case (sizeof(uint16_t)): {
SimpleTransposeSingleAxisInwards(reinterpret_cast<const uint16_t*>(input_data),
reinterpret_cast<uint16_t*>(output_data), num_loops, num_readers, reads_per_loop,
reads_per_reader_per_loop);
reads_per_reader_per_loop, tp);
break;
}
case (sizeof(uint32_t)): {
SimpleTransposeSingleAxisInwards(reinterpret_cast<const uint32_t*>(input_data),
reinterpret_cast<uint32_t*>(output_data), num_loops, num_readers, reads_per_loop,
reads_per_reader_per_loop);
reads_per_reader_per_loop, tp);
break;
}
case (sizeof(uint64_t)): {
SimpleTransposeSingleAxisInwards(reinterpret_cast<const uint64_t*>(input_data),
reinterpret_cast<uint64_t*>(output_data), num_loops, num_readers, reads_per_loop,
reads_per_reader_per_loop);
reads_per_reader_per_loop, tp);
break;
}
default: {
Expand Down Expand Up @@ -236,7 +239,7 @@ void SingleAxisTranspose(gsl::span<const size_t> permutations, const Tensor& inp
if (from > to) {
TransposeSingleAxisOutwards(permutations, input, output, from, to, input_shape_override, tp);
} else {
TransposeSingleAxisInwards(permutations, input, output, from, to, input_shape_override);
TransposeSingleAxisInwards(permutations, input, output, from, to, input_shape_override, tp);
}
}

Expand Down Expand Up @@ -309,4 +312,4 @@ bool IsTransposeMovingSingleAxis(gsl::span<const size_t> permutations, size_t& f
return single_axis_moved;
}

} // namespace onnxruntime
} // namespace onnxruntime
52 changes: 10 additions & 42 deletions onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -1030,49 +1030,15 @@ MlasComputeTanh(
// Transpose routines.
//

//
// Transposes an M-by-N matrix of elements into the N-by-M output buffer.
// Replaces the former per-type (uint8_t/int8_t/uint16_t/uint32_t/float)
// overloads with a single template declaration; the set of supported
// DataType instantiations is provided by the MLAS implementation.
//
// ThreadPool may be nullptr, in which case the transpose runs on the
// calling thread; otherwise work can be partitioned across the pool.
//
template <typename DataType>
void
MLASCALL
MlasTranspose(
    const DataType* Input,
    DataType* Output,
    size_t M,
    size_t N,
    MLAS_THREADPOOL* ThreadPool
    );

//
Expand Down Expand Up @@ -1780,20 +1746,22 @@ MlasConvDepthwise(
MLAS_HALF_GEMM_POSTPROCESSOR* PostProc
);


inline
void
MlasTranspose(
const MLAS_FP16* Input,
MLAS_FP16* Output,
size_t M,
size_t N
size_t N,
MLAS_THREADPOOL* ThreadPool
)
{
MlasTranspose(
reinterpret_cast<const uint16_t*>(Input),
reinterpret_cast<uint16_t*>(Output),
M, N);
M,
N,
ThreadPool);
}


Expand Down
Loading
Loading