[PyTorch] Build custom ORT ops before running ONNX export tests (#1252)
* Build custom ORT ops before running ONNX tests
* Remove ONNX from context parallelism tests
* Export ONNX ops that do compute in FP32 (matches internal impl of TE kernels)
* Add build script for custom ORT ops

Signed-off-by: Tim Moon <[email protected]>
Showing 11 changed files with 301 additions and 54 deletions.
tests/pytorch/custom_ort_ops/.gitignore
@@ -0,0 +1,3 @@
build
onnxruntime
libcustom_ort_ops.so
tests/pytorch/custom_ort_ops/CMakeLists.txt
@@ -0,0 +1,29 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

cmake_minimum_required(VERSION 3.21)
project(custom_ort_ops LANGUAGES CXX)

# Dependencies
find_package(CUDAToolkit REQUIRED)
set(ONNX_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/onnxruntime/include)
if(NOT EXISTS "${ONNX_INCLUDE_DIR}")
  message(FATAL_ERROR
          "Could not find ONNX Runtime headers. "
          "Please clone https://github.com/microsoft/onnxruntime "
          "into TransformerEngine/tests/pytorch/onnx.")
endif()
include_directories(${ONNX_INCLUDE_DIR})

# Configure library
add_library(custom_ort_ops SHARED custom_op_library.cc)
target_link_libraries(custom_ort_ops PUBLIC CUDA::cudart)
target_include_directories(custom_ort_ops PUBLIC
                           ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(custom_ort_ops PRIVATE
                           ${ONNX_INCLUDE_DIR}/onnxruntime
                           ${ONNX_INCLUDE_DIR}/onnxruntime/core/session)

# Install library
install(TARGETS custom_ort_ops DESTINATION .)
tests/pytorch/custom_ort_ops/README.md
@@ -0,0 +1,22 @@
# Custom ONNX Runtime operators for Transformer Engine tests

This directory contains code that builds custom ONNX operators for use
in Transformer Engine tests. It includes basic, non-performant
implementations of the FP8 quantization and dequantization operators
that are used when exporting Transformer Engine models to ONNX.

For more information, see [the ONNX Runtime reference for custom
operators](https://onnxruntime.ai/docs/reference/operators/add-custom-op.html).
Much of the code has been adapted from [an ONNX Runtime
test](https://github.com/microsoft/onnxruntime/blob/de93f40240459953a6e3bbb86b6ad83eaeab681f/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc).

## Usage

* Build the custom operators:
  ```bash
  $ bash TransformerEngine/tests/pytorch/custom_ort_ops/build.sh
  ```
* Run the ONNX export tests with pytest:
  ```bash
  $ python -m pytest TransformerEngine/tests/pytorch/test_onnx_export.py
  ```
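Once the library is built, it can be registered with an ONNX Runtime session before an exported model is loaded. The sketch below is not part of this commit; it uses ONNX Runtime's standard `SessionOptions.register_custom_ops_library` API, and the library and model paths are placeholders.

```python
import onnxruntime as ort

sess_options = ort.SessionOptions()
# Register the shared library produced by build.sh (placeholder path)
sess_options.register_custom_ops_library("libcustom_ort_ops.so")

session = ort.InferenceSession(
    "exported_te_model.onnx",            # hypothetical model exported from Transformer Engine
    sess_options=sess_options,
    providers=["CPUExecutionProvider"],  # the custom kernels are registered for CPU
)
```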
tests/pytorch/custom_ort_ops/build.sh
@@ -0,0 +1,17 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

set -ex

: ${CUSTOM_ORT_OPS_PATH=$(dirname $(realpath $0))}
cd ${CUSTOM_ORT_OPS_PATH}

# Download ONNX Runtime source
git clone --depth=1 -b rel-1.19.2 --single-branch https://github.com/microsoft/onnxruntime.git || true

# Configure and build with CMake
mkdir -p build
cmake -S . -B build -DCMAKE_INSTALL_PREFIX=.
cmake --build build --verbose
cmake --install build --verbose
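The install step places `libcustom_ort_ops.so` next to the script (which is why the `.gitignore` above lists it). A quick, hypothetical sanity check that the build exports the `RegisterCustomOps` entry point ONNX Runtime looks up when loading a custom-op library:

```python
import ctypes

# Load the freshly built library; the relative path assumes you run this
# from tests/pytorch/custom_ort_ops after build.sh has finished.
lib = ctypes.CDLL("./libcustom_ort_ops.so")
assert hasattr(lib, "RegisterCustomOps"), "RegisterCustomOps symbol not found"
print("libcustom_ort_ops.so exports RegisterCustomOps")
```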
tests/pytorch/custom_ort_ops/custom_op_library.cc
@@ -0,0 +1,102 @@
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "custom_op_library.h"

#define ORT_API_MANUAL_INIT
#include "onnxruntime_c_api.h"
#include "onnxruntime_cxx_api.h"
#undef ORT_API_MANUAL_INIT

#include <exception>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>

#include "core/common/common.h"
#include "core/session/onnxruntime_lite_custom_op.h"
#include <cuda_fp8.h>

namespace {

template <typename IType, typename OType, typename CType>
void Quantize(OrtKernelContext* context,
              const Ort::Custom::Tensor<IType>& input,
              const Ort::Custom::Tensor<CType>& scale_inv,
              Ort::Custom::Tensor<unsigned char>& output) {
  auto raw_input = input.Data();
  auto raw_scale_inv = scale_inv.Data();
  auto raw_output = reinterpret_cast<OType*>(output.Allocate(input.Shape()));
  const auto rs = static_cast<CType>(raw_scale_inv[0]);
  const size_t N = input.NumberOfElement();
  for (size_t i = 0; i < N; ++i) {
    const auto x = static_cast<CType>(raw_input[i]);
    raw_output[i] = static_cast<OType>(x / rs);
  }
}

template <typename IType, typename OType, typename CType>
void Dequantize(OrtKernelContext* context,
                const Ort::Custom::Tensor<unsigned char>& input,
                const Ort::Custom::Tensor<CType>& scale_inv,
                Ort::Custom::Tensor<OType>& output) {
  auto raw_input = reinterpret_cast<const IType*>(input.Data());
  auto raw_scale_inv = scale_inv.Data();
  auto raw_output = output.Allocate(input.Shape());
  const auto rs = static_cast<CType>(raw_scale_inv[0]);
  const size_t N = input.NumberOfElement();
  for (size_t i = 0; i < N; ++i) {
    const auto x = rs * static_cast<CType>(raw_input[i]);
    raw_output[i] = static_cast<OType>(x);
  }
}

static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) {
  static std::vector<Ort::CustomOpDomain> ort_custom_op_domain_container;
  static std::mutex ort_custom_op_domain_mutex;
  std::lock_guard<std::mutex> lock(ort_custom_op_domain_mutex);
  ort_custom_op_domain_container.push_back(std::move(domain));
}

}  // namespace

OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
  Ort::Global<void>::api_ = api->GetApi(ORT_API_VERSION);

  // Namespace for custom ops
  static const char* c_OpDomain = "trt";

  // Construct custom ops
  static const std::unique_ptr<Ort::Custom::OrtLiteCustomOp> c_Quantize{
    Ort::Custom::CreateLiteCustomOp("TRT_FP8QuantizeLinear",
                                    "CPUExecutionProvider",
                                    Quantize<float, __nv_fp8_e4m3, float>)
  };
  static const std::unique_ptr<Ort::Custom::OrtLiteCustomOp> c_Dequantize{
    Ort::Custom::CreateLiteCustomOp("TRT_FP8DequantizeLinear",
                                    "CPUExecutionProvider",
                                    Dequantize<__nv_fp8_e4m3, float, float>)
  };

  // Register custom ops
  OrtStatus* result = nullptr;
  ORT_TRY {
    Ort::CustomOpDomain domain{c_OpDomain};
    domain.Add(c_Quantize.get());
    domain.Add(c_Dequantize.get());
    Ort::UnownedSessionOptions session_options(options);
    session_options.Add(domain);
    AddOrtCustomOpDomainToContainer(std::move(domain));
  }
  ORT_CATCH(const std::exception& e) {
    ORT_HANDLE_EXCEPTION([&]() {
      Ort::Status status{e};
      result = status.release();
    });
  }
  return result;
}
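For reference, the numerics these CPU kernels implement (compute in FP32, round-trip through FP8 E4M3, as noted in the commit message) can be sketched with PyTorch's `float8_e4m3fn` dtype. This is illustrative only, is not part of the commit, and assumes a PyTorch build with FP8 dtype support (2.1+).

```python
import torch

def fp8_quantize(x: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    # Compute in FP32, then cast to FP8 E4M3, mirroring Quantize<float, __nv_fp8_e4m3, float>
    return (x.float() / scale_inv.float()).to(torch.float8_e4m3fn)

def fp8_dequantize(x_fp8: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    # Cast FP8 back to FP32 and rescale, mirroring Dequantize<__nv_fp8_e4m3, float, float>
    return x_fp8.float() * scale_inv.float()

x = torch.randn(8)
scale_inv = torch.tensor(0.5)
out = fp8_dequantize(fp8_quantize(x, scale_inv), scale_inv)
print(torch.max(torch.abs(out - x)))  # small, limited by FP8 E4M3 precision
```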
tests/pytorch/custom_ort_ops/custom_op_library.h
@@ -0,0 +1,18 @@
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#pragma once
#include "onnxruntime/core/session/onnxruntime_c_api.h"

#ifdef __cplusplus
extern "C" {
#endif

ORT_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api);

#ifdef __cplusplus
}
#endif