diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 7b21b997cd..bf2581217d 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -6,7 +6,7 @@ set -e : ${TE_PATH:=/opt/transformerengine} -pip install pytest==8.2.1 onnxruntime==1.13.1 +pip install pytest==8.2.1 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py @@ -15,7 +15,6 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v pytest -v -s $TE_PATH/tests/pytorch/test_jit.py NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py -NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py @@ -24,3 +23,9 @@ pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py + +# Build custom ONNX extensions for ONNX export test +pip install onnxruntime==1.19.2 +export CUSTOM_ORT_OPS_PATH=$TE_PATH/tests/pytorch/custom_ort_ops +bash $CUSTOM_ORT_OPS_PATH/build.sh +NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/qa/L1_pytorch_context_parallel_test/test.sh b/qa/L1_pytorch_context_parallel_test/test.sh index 7f3c289b36..81ab8ee20b 100644 --- a/qa/L1_pytorch_context_parallel_test/test.sh +++ b/qa/L1_pytorch_context_parallel_test/test.sh @@ -6,5 +6,5 @@ set -e : ${TE_PATH:=/opt/transformerengine} -pip install pytest==7.2.0 onnxruntime==1.13.1 +pip install pytest==7.2.0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py diff --git a/tests/pytorch/custom_ort_ops/.gitignore b/tests/pytorch/custom_ort_ops/.gitignore new file mode 100644 index 0000000000..d491fb774c --- /dev/null +++ b/tests/pytorch/custom_ort_ops/.gitignore @@ -0,0 +1,3 @@ +build +onnxruntime +libcustom_ort_ops.so diff --git a/tests/pytorch/custom_ort_ops/CMakeLists.txt b/tests/pytorch/custom_ort_ops/CMakeLists.txt new file mode 100644 index 0000000000..90fb3624c1 --- /dev/null +++ b/tests/pytorch/custom_ort_ops/CMakeLists.txt @@ -0,0 +1,29 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +cmake_minimum_required(VERSION 3.21) +project(custom_ort_ops LANGUAGES CXX) + +# Dependencies +find_package(CUDAToolkit REQUIRED) +set(ONNX_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/onnxruntime/include) +if(NOT EXISTS "${ONNX_INCLUDE_DIR}") + message(FATAL_ERROR + "Could not find ONNX Runtime headers. 
" + "Please clone https://github.com/microsoft/onnxruntime " + "into TransformerEngine/tests/pytorch/onnx.") +endif() +include_directories(${ONNX_INCLUDE_DIR}) + +# Configure library +add_library(custom_ort_ops SHARED custom_op_library.cc) +target_link_libraries(custom_ort_ops PUBLIC CUDA::cudart) +target_include_directories(custom_ort_ops PUBLIC + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +target_include_directories(custom_ort_ops PRIVATE + ${ONNX_INCLUDE_DIR}/onnxruntime + ${ONNX_INCLUDE_DIR}/onnxruntime/core/session) + +# Install library +install(TARGETS custom_ort_ops DESTINATION .) diff --git a/tests/pytorch/custom_ort_ops/README.md b/tests/pytorch/custom_ort_ops/README.md new file mode 100644 index 0000000000..ca392805be --- /dev/null +++ b/tests/pytorch/custom_ort_ops/README.md @@ -0,0 +1,22 @@ +# Custom ONNX Runtime operators for Transformer Engine tests + +This directory contains code that builds custom ONNX operators for use +in Transformer Engine tests. It includes basic, non-performant +implementations of the FP8 quantization and dequantization operators +that are used when exporting Transformer Engine models to ONNX. + +For more information, see [the ONNX Runtime reference for custom +operators](https://onnxruntime.ai/docs/reference/operators/add-custom-op.html). +Much of the code has been adapted from [an ONNX Runtime +test](https://github.com/microsoft/onnxruntime/blob/de93f40240459953a6e3bbb86b6ad83eaeab681f/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc). + +## Usage + +* Build the custom operators: +```bash +$ bash TransformerEngine/tests/pytorch/custom_ort_ops/build.sh +``` +* Run the ONNX export tests with pytest: +```bash +$ python -m pytest TransformerEngine/tests/pytorch/test_onnx_export.py +``` \ No newline at end of file diff --git a/tests/pytorch/custom_ort_ops/build.sh b/tests/pytorch/custom_ort_ops/build.sh new file mode 100644 index 0000000000..989da2f4ef --- /dev/null +++ b/tests/pytorch/custom_ort_ops/build.sh @@ -0,0 +1,17 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -ex + +: ${CUSTOM_ORT_OPS_PATH=$(dirname $(realpath $0))} +cd ${CUSTOM_ORT_OPS_PATH} + +# Download ONNX Runtime source +git clone --depth=1 -b rel-1.19.2 --single-branch https://github.com/microsoft/onnxruntime.git || true + +# Configure and build with CMake +mkdir -p build +cmake -S . -B build -DCMAKE_INSTALL_PREFIX=. +cmake --build build --verbose +cmake --install build --verbose diff --git a/tests/pytorch/custom_ort_ops/custom_op_library.cc b/tests/pytorch/custom_ort_ops/custom_op_library.cc new file mode 100755 index 0000000000..f46e897152 --- /dev/null +++ b/tests/pytorch/custom_ort_ops/custom_op_library.cc @@ -0,0 +1,102 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "custom_op_library.h" + +#define ORT_API_MANUAL_INIT +#include "onnxruntime_c_api.h" +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +#include +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/session/onnxruntime_lite_custom_op.h" +#include + +namespace { + +template +void Quantize(OrtKernelContext* context, + const Ort::Custom::Tensor& input, + const Ort::Custom::Tensor& scale_inv, + Ort::Custom::Tensor& output) { + auto raw_input = input.Data(); + auto raw_scale_inv = scale_inv.Data(); + auto raw_output = reinterpret_cast(output.Allocate(input.Shape())); + const auto rs = static_cast(raw_scale_inv[0]); + const size_t N = input.NumberOfElement(); + for (size_t i = 0; i < N; ++i) { + const auto x = static_cast(raw_input[i]); + raw_output[i] = static_cast(x / rs); + } +} + +template +void Dequantize(OrtKernelContext* context, + const Ort::Custom::Tensor& input, + const Ort::Custom::Tensor& scale_inv, + Ort::Custom::Tensor& output) { + auto raw_input = reinterpret_cast(input.Data()); + auto raw_scale_inv = scale_inv.Data(); + auto raw_output = output.Allocate(input.Shape()); + const auto rs = static_cast(raw_scale_inv[0]); + const size_t N = input.NumberOfElement(); + for (size_t i = 0; i < N; ++i) { + const auto x = rs * static_cast(raw_input[i]); + raw_output[i] = static_cast(x); + } +} + +static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) { + static std::vector ort_custom_op_domain_container; + static std::mutex ort_custom_op_domain_mutex; + std::lock_guard lock(ort_custom_op_domain_mutex); + ort_custom_op_domain_container.push_back(std::move(domain)); +} + +} // namespace + +OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) { + Ort::Global::api_ = api->GetApi(ORT_API_VERSION); + + // Namespace for custom ops + static const char* c_OpDomain = "trt"; + + // Construct custom ops + static const std::unique_ptr c_Quantize{ + Ort::Custom::CreateLiteCustomOp("TRT_FP8QuantizeLinear", + "CPUExecutionProvider", + Quantize) + }; + static const std::unique_ptr c_Dequantize{ + Ort::Custom::CreateLiteCustomOp("TRT_FP8DequantizeLinear", + "CPUExecutionProvider", + Dequantize<__nv_fp8_e4m3, float, float>) + }; + + // Register custom ops + OrtStatus* result = nullptr; + ORT_TRY { + Ort::CustomOpDomain domain{c_OpDomain}; + domain.Add(c_Quantize.get()); + domain.Add(c_Dequantize.get()); + Ort::UnownedSessionOptions session_options(options); + session_options.Add(domain); + AddOrtCustomOpDomainToContainer(std::move(domain)); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + Ort::Status status{e}; + result = status.release(); + }); + } + return result; +} diff --git a/tests/pytorch/custom_ort_ops/custom_op_library.h b/tests/pytorch/custom_ort_ops/custom_op_library.h new file mode 100755 index 0000000000..7e4b8256bc --- /dev/null +++ b/tests/pytorch/custom_ort_ops/custom_op_library.h @@ -0,0 +1,18 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#pragma once +#include "onnxruntime/core/session/onnxruntime_c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +ORT_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api); + +#ifdef __cplusplus +} +#endif diff --git a/tests/pytorch/libcustom_ort_fp8_qdq_ops.so b/tests/pytorch/libcustom_ort_fp8_qdq_ops.so deleted file mode 100755 index 61d9232e3a..0000000000 Binary files a/tests/pytorch/libcustom_ort_fp8_qdq_ops.so and /dev/null differ diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index bdc459cdcc..6a463b556a 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -72,7 +72,7 @@ assert OPSET >= TRILU_OPSET # Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT). -ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so") +ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "custom_ort_ops", "libcustom_ort_ops.so") fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @@ -85,7 +85,7 @@ @pytest.fixture() def seed_default_rng(): """Reseed the PRNG for test reproducibility""" - torch.random.seed() + torch.manual_seed(1234) @pytest.fixture() diff --git a/transformer_engine/pytorch/te_onnx_extensions.py b/transformer_engine/pytorch/te_onnx_extensions.py index 0fa9401163..9b4b2df145 100755 --- a/transformer_engine/pytorch/te_onnx_extensions.py +++ b/transformer_engine/pytorch/te_onnx_extensions.py @@ -146,89 +146,136 @@ def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype): @symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): +def onnx_fp8_gelu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_gelu""" # pylint: disable=unused-argument # TE computes GELU using float32 precision so wrap the GELU subgraph with # conversion to/from float32. 
- gelu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.gelu, "tanh") + dtype = get_TensorProtoDataType(inp) + if dtype != _type_utils.JitScalarType.FLOAT: + inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) + out = torch.onnx.symbolic_opset9.gelu(g, inp, "tanh") if scale: - gelu = quantize(g, gelu, scale, fp8_tensor) - return gelu + out = quantize(g, out, scale, fp8_tensor) + elif dtype != _type_utils.JitScalarType.FLOAT: + out = g.op("Cast", out, to_i=dtype) + return out @symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_relu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): +def onnx_fp8_relu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_relu""" # pylint: disable=unused-argument - relu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.relu) + out = torch.onnx.symbolic_opset9.relu(g, inp) if scale: - relu = quantize(g, relu, scale, fp8_tensor) - return relu + out = quantize(g, out, scale, fp8_tensor) + return out @symbolic_helper.parse_args("v", "i") def onnx_swiglu(g: jit_utils.GraphContext, inp, dim): """ONNX graph for swiglu""" + + # Check dimensions dim_size = symbolic_helper._get_tensor_dim_size(inp, dim) if dim_size is not None: assert dim_size % 2 == 0 + # Perform compute in FP32 + dtype = get_TensorProtoDataType(inp) + if dtype != _type_utils.JitScalarType.FLOAT: + inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) first, second = g.op("Split", inp, axis_i=dim, outputs=2) - return g.op("Mul", g.op("Sigmoid", first), second) + out = g.op("Mul", g.op("Sigmoid", first), second) + if dtype != _type_utils.JitScalarType.FLOAT: + out = g.op("Cast", out, to_i=dtype) + return out @symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_swiglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): +def onnx_fp8_swiglu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_swiglu""" # pylint: disable=unused-argument - swiglu = compute_in_fp32(g, inputs, onnx_swiglu, 1) + dtype = get_TensorProtoDataType(inp) + if dtype != _type_utils.JitScalarType.FLOAT: + inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) + out = onnx_swiglu(g, inp, 1) if scale: - swiglu = quantize(g, swiglu, scale, fp8_tensor) - return swiglu + out = quantize(g, out, scale, fp8_tensor) + elif dtype != _type_utils.JitScalarType.FLOAT: + out = g.op("Cast", out, to_i=dtype) + return out @symbolic_helper.parse_args("v", "i") def onnx_reglu(g: jit_utils.GraphContext, inp, dim): """ONNX graph for reglu""" + + # Check dimensions dim_size = symbolic_helper._get_tensor_dim_size(inp, dim) if dim_size is not None: assert dim_size % 2 == 0 + # Perform compute in FP32 + dtype = get_TensorProtoDataType(inp) + if dtype != _type_utils.JitScalarType.FLOAT: + inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) first, second = g.op("Split", inp, axis_i=dim, outputs=2) - return g.op("Mul", g.op("Relu", first), second) + out = g.op("Mul", g.op("Relu", first), second) + if dtype != _type_utils.JitScalarType.FLOAT: + out = g.op("Cast", out, to_i=dtype) + return out @symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_reglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): +def onnx_fp8_reglu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_reglu""" # pylint: disable=unused-argument - reglu = compute_in_fp32(g, inputs, onnx_reglu, 1) + dtype = get_TensorProtoDataType(inp) + if dtype != _type_utils.JitScalarType.FLOAT: + inp = g.op("Cast", inp, 
to_i=_C_onnx.TensorProtoDataType.FLOAT) + out = onnx_reglu(g, inp, 1) if scale: - reglu = quantize(g, reglu, scale, fp8_tensor) - return reglu + out = quantize(g, out, scale, fp8_tensor) + elif dtype != _type_utils.JitScalarType.FLOAT: + out = g.op("Cast", out, to_i=dtype) + return out @symbolic_helper.parse_args("v", "i") def onnx_geglu(g: jit_utils.GraphContext, inp, dim): """ONNX graph for geglu""" + + # Check dimensions dim_size = symbolic_helper._get_tensor_dim_size(inp, dim) if dim_size is not None: assert dim_size % 2 == 0 + # Perform compute in FP32 + dtype = get_TensorProtoDataType(inp) + if dtype != _type_utils.JitScalarType.FLOAT: + inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) first, second = g.op("Split", inp, axis_i=dim, outputs=2) - first_gelu = torch.onnx.symbolic_opset9.gelu(g, first, "tanh") - return g.op("Mul", first_gelu, second) + first = torch.onnx.symbolic_opset9.gelu(g, first, "tanh") + out = g.op("Mul", first, second) + if dtype != _type_utils.JitScalarType.FLOAT: + out = g.op("Cast", out, to_i=dtype) + return out @symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): +def onnx_fp8_geglu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_geglu""" # pylint: disable=unused-argument - geglu = compute_in_fp32(g, inputs, onnx_geglu, 1) + dtype = get_TensorProtoDataType(inp) + if dtype != _type_utils.JitScalarType.FLOAT: + inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) + out = onnx_geglu(g, inp, 1) if scale: - geglu = quantize(g, geglu, scale, fp8_tensor) - return geglu + out = quantize(g, out, scale, fp8_tensor) + elif dtype != _type_utils.JitScalarType.FLOAT: + out = g.op("Cast", out, to_i=dtype) + return out @symbolic_helper.parse_args( @@ -394,7 +441,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_ga @symbolic_helper.parse_args("v", "v", "f", "fs", "v", "v", "i", "i", "i", "b") def onnx_rmsnorm_fwd_fp8( g, - inputs, + inp, weight, eps, scale, @@ -407,50 +454,54 @@ def onnx_rmsnorm_fwd_fp8( ): """ONNX graph for rmsnorm_fwd_fp8""" # pylint: disable=unused-argument - inp_dtype = get_TensorProtoDataType(inputs) - - if inp_dtype != get_TensorProtoDataType(weight): - weight = g.op("Cast", weight, to_i=inp_dtype) - - ln = onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma) - fp8_ln = quantize(g, ln, scale, fp8_tensor) - return fp8_ln + dtype = get_TensorProtoDataType(inp) + if dtype != _type_utils.JitScalarType.FLOAT: + inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) + out = onnx_rmsnorm_fwd(g, inp, weight, eps, sm_margin, zero_centered_gamma) + out = quantize(g, out, scale, fp8_tensor) + return out @symbolic_helper.parse_args("v", "v", "f", "i", "b") -def onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma): +def onnx_rmsnorm_fwd(g, inp, weight, eps, sm_margin, zero_centered_gamma): """ONNX graph for rmsnorm_fwd""" # pylint: disable=unused-argument - normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) + # Check dimensions + normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inp) if normalized_shape is None: - ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs) + ndim = torch.onnx.symbolic_helper._get_tensor_rank(inp) assert ndim is not None normalized_shape = list(range(0, ndim)) # Normalization axis = 0, so normalized_shape uses all dims except dim = 0 normalized_shape = normalized_shape[1:] + axis = 
-len(normalized_shape) + + # Cast input tensors to FP32 if needed + dtype = get_TensorProtoDataType(inp) + if dtype != _type_utils.JitScalarType.FLOAT: + inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) + if get_TensorProtoDataType(weight) != _type_utils.JitScalarType.FLOAT: + weight = g.op("Cast", weight, to_i=_C_onnx.TensorProtoDataType.FLOAT) + # Adjust zero-centered weights if zero_centered_gamma: - inputs_dtype = inputs.type().dtype() - one = _ones_like(g, weight, inputs_dtype) + one = _ones_like(g, weight, torch.float32) weight = g.op("Add", weight, one) - axis = -len(normalized_shape) - - inputs_float = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT) - - sum_square = g.op("ReduceSumSquare", inputs_float, axes_i=[axis]) - shape = g.op("Shape", inputs_float, start_i=-1) + # Perform compute in FP32 + sum_square = g.op("ReduceSumSquare", inp, axes_i=[axis]) + shape = g.op("Shape", inp, start_i=-1) shape_f = g.op("Cast", shape, to_i=_C_onnx.TensorProtoDataType.FLOAT) mean_squared = g.op("Div", sum_square, shape_f) eps_tensor = g.op("ConstantOfShape", shape, value_t=torch.tensor([eps], dtype=torch.float32)) rms_squared = g.op("Add", mean_squared, eps_tensor) rms_eps = g.op("Sqrt", rms_squared) - normalized_input = g.op("Div", inputs_float, rms_eps) - result = g.op("Mul", weight, normalized_input) - result = g.op("Cast", result, to_i=get_TensorProtoDataType(inputs)) - - return result + normalized_input = g.op("Div", inp, rms_eps) + out = g.op("Mul", weight, normalized_input) + if dtype != _type_utils.JitScalarType.FLOAT: + out = g.op("Cast", out, to_i=dtype) + return out register_custom_op_symbolic("tex_ts::cast_to_fp8_ts", onnx_cast_to_fp8, VER)
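
For reference, the custom ops built by this change are consumed by ONNX Runtime through its custom-op library hook. The sketch below is illustrative only: it assumes the library has been built at tests/pytorch/custom_ort_ops/libcustom_ort_ops.so (as the updated test.sh and ORT_CUSTOM_OPS_LIB path set up), and "model.onnx" is a placeholder for a model exported by test_onnx_export.py.

```python
import os
import onnxruntime as ort

# Path mirrors ORT_CUSTOM_OPS_LIB in test_onnx_export.py after this change.
lib_path = os.path.join("tests", "pytorch", "custom_ort_ops", "libcustom_ort_ops.so")

sess_options = ort.SessionOptions()
# Loads the shared library and invokes its RegisterCustomOps entry point,
# making the "trt" domain FP8 Q/DQ ops visible to the session.
sess_options.register_custom_ops_library(lib_path)

# "model.onnx" is a placeholder for a model exported by the ONNX export tests.
session = ort.InferenceSession(
    "model.onnx", sess_options, providers=["CPUExecutionProvider"]
)
```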
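
The Quantize/Dequantize kernels in custom_op_library.cc are intentionally simple: quantization divides by scale_inv and casts to FP8 E4M3, and dequantization upcasts and multiplies by scale_inv. A rough PyTorch equivalent is sketched below for illustration only; torch.float8_e4m3fn needs a recent PyTorch, and its rounding/saturation behavior may differ slightly from the CUDA __nv_fp8_e4m3 cast used in the kernels.

```python
import torch

def fp8_quantize_ref(x: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    # Divide by scale_inv in FP32, then cast to FP8 E4M3 (mirrors Quantize<...>).
    return (x.float() / scale_inv.float()).to(torch.float8_e4m3fn)

def fp8_dequantize_ref(x_fp8: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    # Upcast to FP32 and multiply by scale_inv (mirrors Dequantize<...>).
    return x_fp8.float() * scale_inv.float()

x = torch.randn(4, 8)
scale_inv = torch.tensor(0.5)
x_roundtrip = fp8_dequantize_ref(fp8_quantize_ref(x, scale_inv), scale_inv)
```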
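
The reworked onnx_rmsnorm_fwd emits an explicit FP32 subgraph (ReduceSumSquare, Div by the last-dimension size, Add eps, Sqrt, Div, Mul) instead of going through compute_in_fp32. For the common case where normalization is over the last dimension, that graph corresponds to the reference below; rmsnorm_ref is a hypothetical helper shown only as a sanity check, not part of the test suite.

```python
import torch

def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float,
                zero_centered_gamma: bool = False) -> torch.Tensor:
    # Compute in FP32 and cast back to the input dtype, as the exported graph does.
    x32, w32 = x.float(), weight.float()
    if zero_centered_gamma:
        w32 = w32 + 1.0  # the graph adds a ones-like tensor to the weights
    rms = torch.sqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (w32 * (x32 / rms)).to(x.dtype)
```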