[PyTorch] Build custom ORT ops before running ONNX export tests (#1252)
* Build custom ORT ops before running ONNX tests

Signed-off-by: Tim Moon <[email protected]>

* Remove ONNX from context parallelism tests

Signed-off-by: Tim Moon <[email protected]>

* Export ONNX ops that do compute in FP32

Matches internal impl of TE kernels.

Signed-off-by: Tim Moon <[email protected]>

* Add build script for custom ORT ops

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
timmoon10 authored Oct 16, 2024
1 parent 54aa12a commit f6b766b
Showing 11 changed files with 301 additions and 54 deletions.
9 changes: 7 additions & 2 deletions qa/L0_pytorch_unittest/test.sh
@@ -6,7 +6,7 @@ set -e

: ${TE_PATH:=/opt/transformerengine}

-pip install pytest==8.2.1 onnxruntime==1.13.1
+pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
@@ -15,7 +15,6 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
-NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
@@ -24,3 +23,9 @@ pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py
pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py
+
+# Build custom ONNX extensions for ONNX export test
+pip install onnxruntime==1.19.2
+export CUSTOM_ORT_OPS_PATH=$TE_PATH/tests/pytorch/custom_ort_ops
+bash $CUSTOM_ORT_OPS_PATH/build.sh
+NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
2 changes: 1 addition & 1 deletion qa/L1_pytorch_context_parallel_test/test.sh
@@ -6,5 +6,5 @@ set -e

: ${TE_PATH:=/opt/transformerengine}

-pip install pytest==7.2.0 onnxruntime==1.13.1
+pip install pytest==7.2.0
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
3 changes: 3 additions & 0 deletions tests/pytorch/custom_ort_ops/.gitignore
@@ -0,0 +1,3 @@
build
onnxruntime
libcustom_ort_ops.so
29 changes: 29 additions & 0 deletions tests/pytorch/custom_ort_ops/CMakeLists.txt
@@ -0,0 +1,29 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

cmake_minimum_required(VERSION 3.21)
project(custom_ort_ops LANGUAGES CXX)

# Dependencies
find_package(CUDAToolkit REQUIRED)
set(ONNX_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/onnxruntime/include)
if(NOT EXISTS "${ONNX_INCLUDE_DIR}")
  message(FATAL_ERROR
          "Could not find ONNX Runtime headers. "
          "Please clone https://github.com/microsoft/onnxruntime "
          "into TransformerEngine/tests/pytorch/custom_ort_ops.")
endif()
include_directories(${ONNX_INCLUDE_DIR})

# Configure library
add_library(custom_ort_ops SHARED custom_op_library.cc)
target_link_libraries(custom_ort_ops PUBLIC CUDA::cudart)
target_include_directories(custom_ort_ops PUBLIC
                           ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(custom_ort_ops PRIVATE
                           ${ONNX_INCLUDE_DIR}/onnxruntime
                           ${ONNX_INCLUDE_DIR}/onnxruntime/core/session)

# Install library
install(TARGETS custom_ort_ops DESTINATION .)
22 changes: 22 additions & 0 deletions tests/pytorch/custom_ort_ops/README.md
@@ -0,0 +1,22 @@
# Custom ONNX Runtime operators for Transformer Engine tests

This directory contains code that builds custom ONNX Runtime operators
for use in Transformer Engine tests. It includes basic, non-performant
implementations of the FP8 quantization and dequantization operators
that are used when exporting Transformer Engine models to ONNX.

For more information, see [the ONNX Runtime reference for custom
operators](https://onnxruntime.ai/docs/reference/operators/add-custom-op.html).
Much of the code has been adapted from [an ONNX Runtime
test](https://github.com/microsoft/onnxruntime/blob/de93f40240459953a6e3bbb86b6ad83eaeab681f/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc).

## Usage

* Build the custom operators:
```bash
$ bash TransformerEngine/tests/pytorch/custom_ort_ops/build.sh
```
* Run the ONNX export tests with pytest:
```bash
$ python -m pytest TransformerEngine/tests/pytorch/test_onnx_export.py
```
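
Internally, the export tests attach this library to their ONNX Runtime
sessions. As a rough standalone sketch of that step (the model path is a
placeholder; `register_custom_ops_library` is ONNX Runtime's hook for
loading custom-op shared libraries):

```python
import os
import onnxruntime as ort

# Path where build.sh installs the shared library (placeholder for this sketch).
lib_path = os.path.join("TransformerEngine", "tests", "pytorch",
                        "custom_ort_ops", "libcustom_ort_ops.so")

sess_options = ort.SessionOptions()
sess_options.register_custom_ops_library(lib_path)

# "model.onnx" stands in for a TE model exported with FP8 Q/DQ nodes.
session = ort.InferenceSession("model.onnx", sess_options,
                               providers=["CPUExecutionProvider"])
```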
17 changes: 17 additions & 0 deletions tests/pytorch/custom_ort_ops/build.sh
@@ -0,0 +1,17 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

set -ex

: ${CUSTOM_ORT_OPS_PATH=$(dirname $(realpath $0))}
cd ${CUSTOM_ORT_OPS_PATH}

# Download ONNX Runtime source
git clone --depth=1 -b rel-1.19.2 --single-branch https://github.com/microsoft/onnxruntime.git || true

# Configure and build with CMake
mkdir -p build
cmake -S . -B build -DCMAKE_INSTALL_PREFIX=.
cmake --build build --verbose
cmake --install build --verbose
102 changes: 102 additions & 0 deletions tests/pytorch/custom_ort_ops/custom_op_library.cc
@@ -0,0 +1,102 @@
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include "custom_op_library.h"

#define ORT_API_MANUAL_INIT
#include "onnxruntime_c_api.h"
#include "onnxruntime_cxx_api.h"
#undef ORT_API_MANUAL_INIT

#include <exception>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>

#include "core/common/common.h"
#include "core/session/onnxruntime_lite_custom_op.h"
#include <cuda_fp8.h>

namespace {

template <typename IType, typename OType, typename CType>
void Quantize(OrtKernelContext* context,
              const Ort::Custom::Tensor<IType>& input,
              const Ort::Custom::Tensor<CType>& scale_inv,
              Ort::Custom::Tensor<unsigned char>& output) {
  auto raw_input = input.Data();
  auto raw_scale_inv = scale_inv.Data();
  auto raw_output = reinterpret_cast<OType*>(output.Allocate(input.Shape()));
  const auto rs = static_cast<CType>(raw_scale_inv[0]);
  const size_t N = input.NumberOfElement();
  // Quantize by dividing by the inverse scale, computing in CType (FP32)
  // to match the internal implementation of the TE kernels.
  for (size_t i = 0; i < N; ++i) {
    const auto x = static_cast<CType>(raw_input[i]);
    raw_output[i] = static_cast<OType>(x / rs);
  }
}

template <typename IType, typename OType, typename CType>
void Dequantize(OrtKernelContext* context,
                const Ort::Custom::Tensor<unsigned char>& input,
                const Ort::Custom::Tensor<CType>& scale_inv,
                Ort::Custom::Tensor<OType>& output) {
  auto raw_input = reinterpret_cast<const IType*>(input.Data());
  auto raw_scale_inv = scale_inv.Data();
  auto raw_output = output.Allocate(input.Shape());
  const auto rs = static_cast<CType>(raw_scale_inv[0]);
  const size_t N = input.NumberOfElement();
  // Dequantize by multiplying by the inverse scale, again computing in CType.
  for (size_t i = 0; i < N; ++i) {
    const auto x = rs * static_cast<CType>(raw_input[i]);
    raw_output[i] = static_cast<OType>(x);
  }
}

// Keep registered custom-op domains alive for the lifetime of the process,
// since the session options do not take ownership of them.
static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) {
  static std::vector<Ort::CustomOpDomain> ort_custom_op_domain_container;
  static std::mutex ort_custom_op_domain_mutex;
  std::lock_guard<std::mutex> lock(ort_custom_op_domain_mutex);
  ort_custom_op_domain_container.push_back(std::move(domain));
}

} // namespace

OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
  Ort::Global<void>::api_ = api->GetApi(ORT_API_VERSION);

  // Namespace for custom ops
  static const char* c_OpDomain = "trt";

  // Construct custom ops
  static const std::unique_ptr<Ort::Custom::OrtLiteCustomOp> c_Quantize{
      Ort::Custom::CreateLiteCustomOp("TRT_FP8QuantizeLinear",
                                      "CPUExecutionProvider",
                                      Quantize<float, __nv_fp8_e4m3, float>)};
  static const std::unique_ptr<Ort::Custom::OrtLiteCustomOp> c_Dequantize{
      Ort::Custom::CreateLiteCustomOp("TRT_FP8DequantizeLinear",
                                      "CPUExecutionProvider",
                                      Dequantize<__nv_fp8_e4m3, float, float>)};

  // Register custom ops
  OrtStatus* result = nullptr;
  ORT_TRY {
    Ort::CustomOpDomain domain{c_OpDomain};
    domain.Add(c_Quantize.get());
    domain.Add(c_Dequantize.get());
    Ort::UnownedSessionOptions session_options(options);
    session_options.Add(domain);
    AddOrtCustomOpDomainToContainer(std::move(domain));
  }
  ORT_CATCH(const std::exception& e) {
    ORT_HANDLE_EXCEPTION([&]() {
      Ort::Status status{e};
      result = status.release();
    });
  }
  return result;
}
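
For intuition, the two kernels above amount to a divide-and-cast and a
cast-and-multiply, with all arithmetic in FP32. A rough PyTorch analogue, not
part of this patch (assumes a recent PyTorch with `torch.float8_e4m3fn`;
function names are illustrative):

```python
import torch

def quantize_ref(x: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    """Divide by the inverse scale in FP32, then cast to FP8 E4M3."""
    return (x.float() / scale_inv.float()).to(torch.float8_e4m3fn)

def dequantize_ref(q: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    """Upcast FP8 values to FP32, then multiply by the inverse scale."""
    return q.float() * scale_inv.float()
```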
18 changes: 18 additions & 0 deletions tests/pytorch/custom_ort_ops/custom_op_library.h
@@ -0,0 +1,18 @@
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#pragma once
#include "onnxruntime/core/session/onnxruntime_c_api.h"

#ifdef __cplusplus
extern "C" {
#endif

ORT_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api);

#ifdef __cplusplus
}
#endif
Binary file removed tests/pytorch/libcustom_ort_fp8_qdq_ops.so
4 changes: 2 additions & 2 deletions tests/pytorch/test_onnx_export.py
@@ -72,7 +72,7 @@
assert OPSET >= TRILU_OPSET

# Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT).
-ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")
+ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "custom_ort_ops", "libcustom_ort_ops.so")

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@@ -85,7 +85,7 @@
@pytest.fixture()
def seed_default_rng():
"""Reseed the PRNG for test reproducibility"""
torch.random.seed()
torch.manual_seed(1234)


@pytest.fixture()