Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffdaily committed Oct 15, 2024
1 parent 37da60a commit 16c1b49
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 53 deletions.
38 changes: 19 additions & 19 deletions onnxruntime/core/providers/rocm/reduction/reduction_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -623,13 +623,13 @@ template Status ReduceComputeCore<float, MIOPEN_REDUCE_TENSOR_NO_INDICES>(
Stream* ort_stream,
const TensorShape* input_shape_override);

//template Status ReduceComputeCore<double, MIOPEN_REDUCE_TENSOR_NO_INDICES>(
// const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata,
// /*out*/ Tensor& output, miopenReduceTensorOp_t miopen_reduce_op,
// gsl::span<const int64_t> axes,
// bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction,
// Stream* ort_stream,
// const TensorShape* input_shape_override);
// template Status ReduceComputeCore<double, MIOPEN_REDUCE_TENSOR_NO_INDICES>(
// const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata,
// /*out*/ Tensor& output, miopenReduceTensorOp_t miopen_reduce_op,
// gsl::span<const int64_t> axes,
// bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction,
// Stream* ort_stream,
// const TensorShape* input_shape_override);

template Status ReduceComputeCore<MLFloat16, MIOPEN_REDUCE_TENSOR_NO_INDICES>(
const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata,
Expand Down Expand Up @@ -833,71 +833,71 @@ template std::unique_ptr<Tensor> ReduceCompute<MLFloat16, MIOPEN_REDUCE_TENSOR_N
// ROCM ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, MLFloat16, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, float, 11)
//REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, double, 11)
// REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, double, 11)

REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, MLFloat16, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, float, 11)
//REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, double, 11)
// REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, double, 11)

REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, float, 17, 18)
//REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, double, 17, 18)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, double, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, int32_t, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, int64_t, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, int8_t, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, uint8_t, 17, 18)

REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, float, 17, 18)
//REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, double, 17, 18)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, double, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, BFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, int32_t, 17, 18)

REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, float, 17, 18)
//REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, double, 17, 18)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, double, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, int32_t, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, int64_t, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, int8_t, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, uint8_t, 17, 18)

REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, float, 17, 18)
//REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, double, 17, 18)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, double, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, BFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, int32_t, 17, 18)

REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, MLFloat16, 12, 13)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, float, 12, 13)
//REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, double, 12, 13)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, double, 12, 13)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, int32_t, 12, 13)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, int64_t, 12, 13)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, BFloat16, 12, 13)

REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSum, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSum, float, 17, 18)
//REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSum, double, 17, 18)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSum, double, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSum, BFloat16, 17, 18)

REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSumSquare, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSumSquare, float, 17, 18)
//REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSumSquare, double, 17, 18)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSumSquare, double, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSumSquare, BFloat16, 17, 18)

REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSumExp, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSumExp, float, 17, 18)
//REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSumExp, double, 17, 18)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSumExp, double, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSumExp, BFloat16, 17, 18)

REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL1, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL1, float, 17, 18)
//REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL1, double, 17, 18)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL1, double, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL1, BFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL1, int32_t, 17, 18)

REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, float, 17, 18)
//REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, double, 17, 18)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, double, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, BFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, int32_t, 17, 18)

Expand Down
64 changes: 32 additions & 32 deletions onnxruntime/core/providers/rocm/rocm_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1308,39 +1308,39 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, uint32_t, Cast);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, uint64_t, Cast);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, bool, Cast);
//#if !defined(DISABLE_FLOAT8_TYPES)
//class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E4M3FN, Cast);
//class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E5M2, Cast);
//#endif
// #if !defined(DISABLE_FLOAT8_TYPES)
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E4M3FN, Cast);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E5M2, Cast);
// #endif

class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, float, DequantizeLinear);
class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, float, DequantizeLinear);
//#if !defined(DISABLE_FLOAT8_TYPES)
//class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, float, DequantizeLinear);
//class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, float, DequantizeLinear);
//#endif
// #if !defined(DISABLE_FLOAT8_TYPES)
// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, float, DequantizeLinear);
// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, float, DequantizeLinear);
// #endif
class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, MLFloat16, DequantizeLinear);
class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, MLFloat16, DequantizeLinear);
//#if !defined(DISABLE_FLOAT8_TYPES)
//class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, MLFloat16, DequantizeLinear);
//class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, MLFloat16, DequantizeLinear);
//#endif
// #if !defined(DISABLE_FLOAT8_TYPES)
// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, MLFloat16, DequantizeLinear);
// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, MLFloat16, DequantizeLinear);
// #endif

class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Identity);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, If);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Loop);
class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, float, QuantizeLinear);
class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, float, QuantizeLinear);
//#if !defined(DISABLE_FLOAT8_TYPES)
//class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, float, QuantizeLinear);
//class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, float, QuantizeLinear);
//#endif
// #if !defined(DISABLE_FLOAT8_TYPES)
// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, float, QuantizeLinear);
// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, float, QuantizeLinear);
// #endif
class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, MLFloat16, QuantizeLinear);
class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, MLFloat16, QuantizeLinear);
//#if !defined(DISABLE_FLOAT8_TYPES)
//class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, MLFloat16, QuantizeLinear);
//class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, MLFloat16, QuantizeLinear);
//#endif
// #if !defined(DISABLE_FLOAT8_TYPES)
// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, MLFloat16, QuantizeLinear);
// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, MLFloat16, QuantizeLinear);
// #endif
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Reshape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Scan);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Shape);
Expand All @@ -1362,12 +1362,12 @@ class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, Int4x2, float, DequantizeLinear);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, UInt4x2, MLFloat16, DequantizeLinear);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, Int4x2, MLFloat16, DequantizeLinear);
//#if !defined(DISABLE_FLOAT8_TYPES)
//class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, float, DequantizeLinear);
//class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, float, DequantizeLinear);
//class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, MLFloat16, DequantizeLinear);
//class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, MLFloat16, DequantizeLinear);
//#endif
// #if !defined(DISABLE_FLOAT8_TYPES)
// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, float, DequantizeLinear);
// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, float, DequantizeLinear);
// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, MLFloat16, DequantizeLinear);
// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, MLFloat16, DequantizeLinear);
// #endif

class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, float, QuantizeLinear);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, float, QuantizeLinear);
Expand All @@ -1377,12 +1377,12 @@ class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, Int4x2, float, QuantizeLinear);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, UInt4x2, MLFloat16, QuantizeLinear);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, Int4x2, MLFloat16, QuantizeLinear);
//#if !defined(DISABLE_FLOAT8_TYPES)
//class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, float, QuantizeLinear);
//class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, float, QuantizeLinear);
//class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, MLFloat16, QuantizeLinear);
//class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, MLFloat16, QuantizeLinear);
//#endif
// #if !defined(DISABLE_FLOAT8_TYPES)
// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, float, QuantizeLinear);
// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, float, QuantizeLinear);
// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, MLFloat16, QuantizeLinear);
// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, MLFloat16, QuantizeLinear);
// #endif

template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/rocm/rocm_stream_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#pragma once
#include "core/providers/rocm/rocm_pch.h"
//#include "core/providers/rocm/shared_inc/rocm_utils.h"
// #include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/shared_inc/rocm_call.h"
#include "core/framework/stream_handles.h"
#include "core/providers/rocm/rocm_execution_provider_info.h"
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/rocm/tunable/rocm_tunable.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ using TunableOp = TunableOp<ParamsT, Timer>;
// As a convenience for authoring TunableOp in contrib namespace
namespace contrib {
namespace rocm {
using onnxruntime::rocm::tunable::RocmTuningContext;
using onnxruntime::rocm::tunable::Op;
using onnxruntime::rocm::tunable::OpParams;
using onnxruntime::rocm::tunable::RocmTuningContext;
using onnxruntime::rocm::tunable::TunableOp;
} // namespace rocm
} // namespace contrib
Expand Down

0 comments on commit 16c1b49

Please sign in to comment.