From c3442f9f787fc8cc3925a09420d1e60aa4ffb234 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Wed, 6 Mar 2024 10:45:41 -0800 Subject: [PATCH] [AOT Compile] E2E compile->execute->test framework using TorchDynamo export, TCP compilation, CPU codegen (#46) ### [Large PR] For ease of review, I've stacked the changes into 3 commits: - https://github.com/cruise-automation/mlir-tcp/pull/46/commits/ada575353cd2381504447c93bc204773134a9fc2: *[Please review]* AOT compile flow for TD export + TCP compilation and CPU codegen - https://github.com/cruise-automation/mlir-tcp/pull/46/commits/75031b6aad77aa11a53909293c22b00a24d74cd0: *[Please skim]* AOT README (also pasted as PR description below) - https://github.com/cruise-automation/mlir-tcp/pull/46/commits/705aeba4b65db3a676667d87642acd910c8488f1: *[Please skip]* NFC / mechanical changes Inlining `tools/aot/README.md` below for readability. ------- AOT Compile (Developer Guide) ============================= The [`aot_compile`](https://github.com/cruise-automation/mlir-tcp/blob/main/tools/aot/aot_compile.bzl) bazel macro implements an end-to-end framework to compile PyTorch (or TCP) programs to a CPU library, execute it and test for functional correctness of the generated code. It comprises starting with TorchDynamo export of PyTorch programs, conversion and lowerings through {Torch, TCP, Linalg, LLVM} MLIR dialects, translation to LLVM assembly, compilation to assembly source for the host architecture (CPU), and lastly generation of shared object that can be dynamically linked into an executable at runtime. It leverages a series of genrules to stitch the compilation pipeline together, and an unsophisticated meta-programming trick for auto-generating C++ tests (specialized to the input program's function signature) that execute the compiled code and validate its numerics against reference PyTorch. When authoring new TCP ops with dialect conversions from/to Torch and Linalg, adding an `aot_compile` target is a fast, automated and standardized way to test the e2e compilation and validate that the op lowerings are implemented consistent with PyTorch semantics. ## Compile PyTorch programs Onboarding to the `aot_compile` macro is quite easy (examples [here](https://github.com/cruise-automation/mlir-tcp/blob/main/test/AotCompile/BUILD)). Start by adding the following line to the `BUILD` to load the macro: ```starlark load("//tools/aot:aot_compile.bzl", "aot_compile") ``` Then call the macro like this: ```starlark aot_compile( name = "broadcast_add_mixed_ranks", torch_loader_lib = ":add_mul_loader_lib", torch_loader_path = "test.AotCompile.add_mul_loader_lib.broadcast_add_mixed_ranks_loader", ) ``` Here, `torch_loader_lib` expects a `py_library` target for the module that defines the PyTorch program to be AOT compiled, and `torch_loader_path` is the full python import path (dot separated) to the loader function. ```starlark py_library( name = "add_mul_loader_lib", srcs = ["add_mul_loader_lib.py"], visibility = ["//visibility:public"], deps = [ requirement("torch"), "//tools/aot:torch_loader_utils", ], ) ``` The loader function can be called anything really, but it should define the PyTorch program, sample inputs and dynamic dim constraints (if any), and always return a `TorchLoaderOutput` object. The PyTorch program's forward function must always consume and return tensors, like so: ```python import torch from torch.export import dynamic_dim from tools.aot.torch_loader_utils import TorchLoaderOutput def broadcast_add_mixed_ranks_loader() -> TorchLoaderOutput: class BroadcastAddMixedRanks(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: add = torch.add(x, y) return add # Sample inputs x = torch.tensor(10.0) y = torch.randn(2) # Dynamic dim constraints constraints = [dynamic_dim(y, 0)] return TorchLoaderOutput( model=BroadcastAddMixedRanks(), inputs=[x, y], constraints=constraints, ) ``` An invocation of `aot_compile(name="foo", ...)` generates a bunch of targets (see [here](https://github.com/cruise-automation/mlir-tcp/blob/main/tools/aot/aot_compile.bzl#L43) for the list) that can be helpful in debugging the intermediate steps in the compilation process. To get the full list of `aot_compile` macro generated targets for `broadcast_add_mixed_ranks`, run the query: ```shell $ bazel query 'attr(name, "broadcast_add_mixed_ranks", //test/AotCompile/...)' //test/AotCompile:aot_compiled_broadcast_add_mixed_ranks //test/AotCompile:broadcast_add_mixed_ranks_compile_execute_test //test/AotCompile:broadcast_add_mixed_ranks_execute_test_generator //test/AotCompile:broadcast_add_mixed_ranks_torch_exporter //test/AotCompile:gen_broadcast_add_mixed_ranks_execute_test //test/AotCompile:gen_broadcast_add_mixed_ranks_host_asm //test/AotCompile:gen_broadcast_add_mixed_ranks_llvm_ir //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_llvm //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_tcp //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_torch //test/AotCompile:gen_broadcast_add_mixed_ranks_reference_tensors ``` Lets walk through a series of steps involved in debugging an e2e compilation pipeline. Note that these steps are not required to be manually run one at a time (although they can be). Bazel automatically identifies the DAG of dependencies and executes just what is needed to build the specified target. #### 1. Inspect the Torch dialect (`*_torch.mlir`) exported from the PyTorch program: ```shell $ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_torch INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_torch (61 packages loaded, 16582 targets configured). INFO: Found 1 target... Target //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_torch up-to-date: bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_torch.mlir INFO: Elapsed time: 6.085s, Critical Path: 0.69s INFO: 1 process: 1 internal. INFO: Build completed successfully, 1 total action ``` ```ll $ cat bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_torch.mlir module { func.func @func_main(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { %int1 = torch.constant.int 1 %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32> return %0 : !torch.vtensor<[?],f32> } } ``` #### 2. Inspect the TCP dialect (`*_tcp.mlir`) lowered from the Torch dialect: ```shell $ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_tcp INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_tcp (0 packages loaded, 0 targets configured). INFO: Found 1 target... Target //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_tcp up-to-date: bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_tcp.mlir INFO: Elapsed time: 0.572s, Critical Path: 0.03s INFO: 1 process: 1 internal. INFO: Build completed successfully, 1 total action ``` ```ll $ cat bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_tcp.mlir module { func.func @func_main(%arg0: tensor, %arg1: tensor) -> tensor { %c0 = arith.constant 0 : index %expanded = tensor.expand_shape %arg0 [] : tensor into tensor<1xf32> %dim = tensor.dim %arg1, %c0 : tensor %0 = tcp.broadcast %expanded, %dim {axes = [0]} : tensor<1xf32>, index -> tensor %1 = tcp.add %0, %arg1 : tensor, tensor -> tensor return %1 : tensor } } ``` #### 3. Inspect the LLVM dialect (`*_llvm.mlir`) lowered from the TCP dialect: ```shell $ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_llvm INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_llvm (0 packages loaded, 0 targets configured). INFO: Found 1 target... Target //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_llvm up-to-date: bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_llvm.mlir INFO: Elapsed time: 0.305s, Critical Path: 0.00s INFO: 1 process: 1 internal. INFO: Build completed successfully, 1 total action ``` ```ll $ cat bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_llvm.mlir module { llvm.func @malloc(i64) -> !llvm.ptr llvm.func @func_main(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: !llvm.ptr, %arg4: !llvm.ptr, %arg5: i64, %arg6: i64, %arg7: i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> { %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64)> %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr, ptr, i64)> %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64)> %3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr, ptr, i64)> . . . %57 = llvm.load %56 : !llvm.ptr -> f32 %58 = llvm.fadd %54, %57 : f32 %59 = llvm.getelementptr %44[%51] : (!llvm.ptr, i64) -> !llvm.ptr, f32 llvm.store %58, %59 : f32, !llvm.ptr %60 = llvm.add %51, %14 : i64 llvm.br ^bb4(%60 : i64) ^bb6: // pred: ^bb4 llvm.return %50 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> } } ``` #### 4. Inspect the LLVM assembly (`*.ll`) translated from the LLVM dialect: ```shell $ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_llvm_ir INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_llvm_ir (0 packages loaded, 0 targets configured). INFO: Found 1 target... Target //test/AotCompile:gen_broadcast_add_mixed_ranks_llvm_ir up-to-date: bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks.ll INFO: Elapsed time: 0.312s, Critical Path: 0.00s INFO: 1 process: 1 internal. INFO: Build completed successfully, 1 total action ``` ```ll $ cat bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks.ll ; ModuleID = 'LLVMDialectModule' source_filename = "LLVMDialectModule" declare ptr @malloc(i64) define { ptr, ptr, i64, [1 x i64], [1 x i64] } @func_main(ptr %0, ptr %1, i64 %2, ptr %3, ptr %4, i64 %5, i64 %6, i64 %7) { %9 = insertvalue { ptr, ptr, i64 } undef, ptr %0, 0 %10 = insertvalue { ptr, ptr, i64 } %9, ptr %1, 1 %11 = insertvalue { ptr, ptr, i64 } %10, i64 %2, 2 %12 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } undef, ptr %3, 0 . . . 63: ; preds = %51 ret { ptr, ptr, i64, [1 x i64], [1 x i64] } %50 } !llvm.module.flags = !{!0} !0 = !{i32 2, !"Debug Info Version", i32 3} ``` #### 5. Inspect the assembly source (`*.S`) compiled for the host architecture (CPU): ```shell $ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_host_asm INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_host_asm (0 packages loaded, 0 targets configured). INFO: Found 1 target... Target //test/AotCompile:gen_broadcast_add_mixed_ranks_host_asm up-to-date: bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks.S INFO: Elapsed time: 0.360s, Critical Path: 0.03s INFO: 1 process: 1 internal. INFO: Build completed successfully, 1 total action ``` ```ll $ cat bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks.S .text .file "LLVMDialectModule" .globl func_main # -- Begin function func_main .p2align 4, 0x90 .type func_main,@function func_main: # @func_main .cfi_startproc # %bb.0: pushq %rbp .cfi_def_cfa_offset 16 . . . popq %r14 .cfi_def_cfa_offset 24 popq %r15 .cfi_def_cfa_offset 16 popq %rbp .cfi_def_cfa_offset 8 retq .Lfunc_end0: .size func_main, .Lfunc_end0-func_main .cfi_endproc # -- End function .section ".note.GNU-stack","",@progbits ``` #### 6. Build the shared object (`*.so`) from the host assembly that can be dynamically linked into an executable at runtime: ```shell $ bazel build //test/AotCompile:aot_compiled_broadcast_add_mixed_ranks INFO: Analyzed target //test/AotCompile:aot_compiled_broadcast_add_mixed_ranks (8 packages loaded, 8403 targets configured). INFO: Found 1 target... Target //test/AotCompile:aot_compiled_broadcast_add_mixed_ranks up-to-date: bazel-bin/test/AotCompile/libaot_compiled_broadcast_add_mixed_ranks.a bazel-bin/test/AotCompile/libaot_compiled_broadcast_add_mixed_ranks.so INFO: Elapsed time: 2.264s, Critical Path: 0.12s INFO: 1 process: 1 internal. INFO: Build completed successfully, 1 total action ``` #### 7. Save the reference input and output tensors needed for validation of the compiled code: ```shell $ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_reference_tensors INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_reference_tensors (0 packages loaded, 5 targets configured). INFO: Found 1 target... Target //test/AotCompile:gen_broadcast_add_mixed_ranks_reference_tensors up-to-date: bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_reference_tensors.npz INFO: Elapsed time: 0.743s, Critical Path: 0.15s INFO: 1 process: 1 internal. INFO: Build completed successfully, 1 total action ``` #### 8. Inspect the C++ test (`*_execute_test.cpp`) auto-generated from the template: ```shell $ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_execute_test INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_execute_test (22 packages loaded, 91 targets configured). INFO: Found 1 target...Target //test/AotCompile:gen_broadcast_add_mixed_ranks_execute_test up-to-date: bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_execute_test.cpp INFO: Elapsed time: 0.329s, Critical Path: 0.02s INFO: 1 process: 1 internal. INFO: Build completed successfully, 1 total action ``` ```cpp $ cat bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_execute_test.cpp //===------------------------------------------------------------*- C++ -*-===// // // Licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "mlir/ExecutionEngine/CRunnerUtils.h" #include "tools/aot/abi.h" #include "cnpy.h" #include "gtest/gtest.h" using namespace mlir::tcp; #pragma clang diagnostic ignored "-Wreturn-type-c-linkage" template static StridedMemRefType CreateMemRefFromNpyArray(cnpy::NpyArray &arr) { StridedMemRefType Result; Result.basePtr = arr.data(); Result.data = arr.data(); Result.offset = 0; // Check if the Rank matches if (arr.shape.size() != Rank) { std::cerr << "Error: Rank mismatch." << std::endl; // Return an uninitialized memref return Result; } // Check if the DataType matches if (arr.word_size != sizeof(DataType)) { std::cerr << "Error: Data type mismatch." << std::endl; // Return an uninitialized memref return Result; } // Set sizes and strides based on the shape of the numpy array int stride = 1; for (int i = Rank - 1; i >= 0; --i) { Result.sizes[i] = arr.shape[i]; Result.strides[i] = stride; stride *= arr.shape[i]; } return Result; } // CreateMemRefFromNpyArray function specialized for rank 0 template static StridedMemRefType CreateMemRefFromNpyArray(cnpy::NpyArray &arr) { StridedMemRefType Result; Result.basePtr = arr.data(); Result.data = arr.data(); Result.offset = 0; // Check if the Rank matches if (!arr.shape.empty()) { std::cerr << "Error: Rank mismatch. Expected rank-0 array." << std::endl; // Return an uninitialized memref return Result; } // Check if the DataType matches if (arr.word_size != sizeof(DataType)) { std::cerr << "Error: Data type mismatch." << std::endl; // Return an uninitialized memref return Result; } return Result; } // ### DO NOT MODIFY ### // // This template file is pre-processed by `aot_compile` bazel macro // to materialize the templated parameters based on the inputs // passed by the callsite where the macro is instantiated. struct OutputMemRefDescriptor { StridedMemRefType Output0; }; extern "C" OutputMemRefDescriptor func_main( DECL_RANK_0_MEMREF_ABI(float), DECL_RANK_1_MEMREF_ABI(float) ); TEST(AotCompiled, ExecuteTest) { cnpy::npz_t reference_tensors = cnpy::npz_load( "test/AotCompile/_internal_broadcast_add_mixed_ranks_reference_tensors.npz" ); cnpy::NpyArray refInput0 = reference_tensors["Input0"]; cnpy::NpyArray refInput1 = reference_tensors["Input1"]; cnpy::NpyArray refOutput0 = reference_tensors["Output0"]; StridedMemRefType Input0 = CreateMemRefFromNpyArray(refInput0); StridedMemRefType Input1 = CreateMemRefFromNpyArray(refInput1); OutputMemRefDescriptor Result = func_main( PASS_RANK_0_MEMREF(Input0), PASS_RANK_1_MEMREF(Input1) ); ASSERT_EQ(Result.Output0.sizes[0], refOutput0.shape[0]); for (int i = 0; i < refOutput0.num_vals; i++) EXPECT_EQ(Result.Output0.data[i], refOutput0.data()[i]); free(Result.Output0.basePtr); } ``` #### 9. Run the C++ test to execute the generated code and validate functional correctness against reference PyTorch ```shell $ bazel run //test/AotCompile:broadcast_add_mixed_ranks_compile_execute_test INFO: Analyzed target //test/AotCompile:broadcast_add_mixed_ranks_compile_execute_test (0 packages loaded, 0 targets configured). INFO: Found 1 target... Target //test/AotCompile:broadcast_add_mixed_ranks_compile_execute_test up-to-date: bazel-bin/test/AotCompile/broadcast_add_mixed_ranks_compile_execute_test INFO: Elapsed time: 0.215s, Critical Path: 0.00s INFO: 1 process: 1 internal. INFO: Build completed successfully, 1 total action INFO: Running command line: external/bazel_tools/tools/test/test-setup.sh test/AotCompile/broadcast_add_mixed_ranks_compile_execute_test exec ${PAGER:-/usr/bin/less} "$0" || exit 1 Executing tests from //test/AotCompile:broadcast_add_mixed_ranks_compile_execute_test ----------------------------------------------------------------------------- Running main() from gmock_main.cc [==========] Running 1 test from 1 test suite. [----------] Global test environment set-up. [----------] 1 test from AotCompiled [ RUN ] AotCompiled.ExecuteTest [ OK ] AotCompiled.ExecuteTest (1 ms) [----------] 1 test from AotCompiled (1 ms total) [----------] Global test environment tear-down [==========] 1 test from 1 test suite ran. (1 ms total) [ PASSED ] 1 test. ``` ## Compile TCP programs The `aot_compile` macro also accepts TCP dialect programs as inputs (instead of PyTorch programs). This is useful to maintain framework neutrality by allowing alternate ingress pathways (like Stablehlo, JAX, TensorFlow, ONNX etc.) into the TCP dialect. When `tcp_source` is specified, the generated `aot_compiled_foo` CPU library has one global function for every function in the TCP program. Let's look at an example. ```starlark aot_compile( name = "basic_tcp_ops", tcp_source = "basic_tcp_ops.mlir", ) ``` Here, `tcp_source` expects a `.mlir` file containing TCP programs, like so: ```ll # basic_tcp_ops.mlir func.func @func_1(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { %0 = tcp.add %arg0, %arg1 : tensor, tensor -> tensor %1 = tcp.mul %0, %arg2 : tensor, tensor -> tensor return %1 : tensor } func.func @func_2(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor) { %0 = tcp.add %arg0, %arg1 : tensor, tensor -> tensor %1 = tcp.mul %0, %arg2 : tensor, tensor -> tensor return %0, %1 : tensor, tensor } func.func @func_3(%arg0: tensor, %arg1: tensor) -> tensor { %c0 = arith.constant 0 : index %dim = tensor.dim %arg1, %c0 : tensor %arg0_ex = tensor.expand_shape %arg0 [] : tensor into tensor<1xf32> %arg0_bcast = tcp.broadcast %arg0_ex, %dim {axes = [0]} : tensor<1xf32>, index -> tensor %0 = tcp.add %arg0_bcast, %arg1 : tensor, tensor -> tensor return %0 : tensor } ``` Now run the query to get all the relevant targets created. ```shell $ bazel query 'attr(name, "basic_tcp_ops", //test/AotCompile/...)' //test/AotCompile:aot_compiled_basic_tcp_ops //test/AotCompile:gen_basic_tcp_ops_host_asm //test/AotCompile:gen_basic_tcp_ops_llvm_ir //test/AotCompile:gen_basic_tcp_ops_mlir_llvm ``` Note we're missing the `//test/AotCompile:basic_tcp_ops_compile_execute_test` target. As there is no access to PyTorch reference implementation, the `aot_compile` macro does not auto-generate C++ execute tests but they can be manually written (example [here](https://github.com/cruise-automation/mlir-tcp/blob/main/test/AotCompile/test_aot_compiled_basic_tcp_ops.cpp)). These tests should include `extern "C"` function declarations with the same name and for every function in the input TCP source. The rest of the steps to debug the e2e compilation pipeline are pretty much the same. --- .bazelignore | 5 +- .github/workflows/bazelBuildAndTestTcp.yml | 2 +- .gitignore | 6 +- README.md | 2 +- deps.bzl | 10 +- requirements.txt | 1 + requirements_lock.txt | 16 +- test/AotCompile/BUILD | 34 +- test/AotCompile/add_mul_loader_lib.py | 100 ++++ ...=> basic_tcp_ops_compile_execute_test.cpp} | 14 +- test/BUILD | 2 +- test/lit.cfg.py | 20 +- test/python/BUILD | 2 +- test/python/fx_import/basic_test.py | 3 +- test/python_lit/fx_import/basic_test.py | 2 +- third_party/BUILD | 4 + third_party/cnpy.BUILD | 30 ++ tools/aot/BUILD | 32 ++ tools/aot/README.md | 505 ++++++++++++++++++ tools/aot/abi.h | 15 +- tools/aot/aot_compile.bzl | 249 ++++++++- tools/aot/execute_test.template.cpp | 134 +++++ tools/aot/execute_test_generator.py | 138 +++++ tools/aot/torch_exporter_harness.py | 81 +++ tools/aot/torch_loader_utils.py | 15 + tools/clangd/BUILD | 4 +- 26 files changed, 1370 insertions(+), 56 deletions(-) create mode 100644 test/AotCompile/add_mul_loader_lib.py rename test/AotCompile/{test_aot_compiled_basic_tcp_ops.cpp => basic_tcp_ops_compile_execute_test.cpp} (91%) create mode 100644 third_party/BUILD create mode 100644 third_party/cnpy.BUILD create mode 100644 tools/aot/README.md create mode 100644 tools/aot/execute_test.template.cpp create mode 100644 tools/aot/execute_test_generator.py create mode 100644 tools/aot/torch_exporter_harness.py create mode 100644 tools/aot/torch_loader_utils.py diff --git a/.bazelignore b/.bazelignore index 6e1e2bd..dac859a 100644 --- a/.bazelignore +++ b/.bazelignore @@ -3,4 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -third_party/ +# ignore local_repos of llvm-project, torch-mlir, stablehlo +third_party/llvm-project +third_party/torch-mlir +third_party/stablehlo diff --git a/.github/workflows/bazelBuildAndTestTcp.yml b/.github/workflows/bazelBuildAndTestTcp.yml index 2fcd718..c21602e 100644 --- a/.github/workflows/bazelBuildAndTestTcp.yml +++ b/.github/workflows/bazelBuildAndTestTcp.yml @@ -55,7 +55,7 @@ jobs: mlir-tcp:ci \ find . -type f -name "*.cpp" -o -name "*.h" | xargs clang-format -i if [ -n "$(git status --porcelain)" ]; then - echo "Please run clang-format 'find . -type f -name "*.cpp" -o -name "*.h" | xargs clang-format -i' and commit changes." + echo "Please run 'find . -type f -name "*.cpp" -o -name "*.h" | xargs clang-format -i' and commit changes." exit 1 fi diff --git a/.gitignore b/.gitignore index 8bf3e09..e57e392 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,11 @@ bazel-bin bazel-out bazel-mlir-tcp bazel-testlogs -third_party/ + +# ignore local_repos of llvm, torch-mlir, stablehlo +third_party/llvm-project +third_party/torch-mlir +third_party/stablehlo # clangd related .cache diff --git a/README.md b/README.md index 6da0f52..f57a5c6 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ bazel build //:tcp-opt bazel test //... ``` -We welcome contributions to `mlir-tcp`. If you do contribute, please finalize your PR with clang-format and bazel buildifier to ensure the C++ sources and BUILD files are formatted consistently: +We welcome contributions to `mlir-tcp`. When authoring new TCP ops with dialect conversions from/to Torch and Linalg, please include lit tests for dialect and conversions, as well as [aot_compile](https://github.com/cruise-automation/mlir-tcp/blob/main/tools/aot/README.md) generated e2e integration tests. Finally, please finalize your PR with clang-format and bazel buildifier to ensure the C++ sources and BUILD files are formatted consistently: ```shell # clang-format find . -type f -name "*.cpp" -o -name "*.h" | xargs clang-format -i diff --git a/deps.bzl b/deps.bzl index 747cd39..be25de3 100644 --- a/deps.bzl +++ b/deps.bzl @@ -43,8 +43,8 @@ def third_party_deps(): TORCH_MLIR_SHA256 = "205ffab6683d5bcbe9bff6afca5aa547826990b0d9d7d58644f9777c37558fd1" http_archive( name = "torch-mlir-raw", - sha256 = TORCH_MLIR_SHA256, build_file_content = "# empty", + sha256 = TORCH_MLIR_SHA256, strip_prefix = "torch-mlir-" + TORCH_MLIR_COMMIT, urls = ["https://github.com/llvm/torch-mlir/archive/{commit}.tar.gz".format(commit = TORCH_MLIR_COMMIT)], ) @@ -151,3 +151,11 @@ def third_party_deps(): strip_prefix = "bazel-compile-commands-extractor-6d58fa6bf39f612304e55566fa628fd160b38177", url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/6d58fa6bf39f612304e55566fa628fd160b38177.tar.gz", ) + + http_archive( + name = "cnpy", + build_file = "//third_party:cnpy.BUILD", + sha256 = "5120abc54a564efa92c642cc0199cc4fd3f345901157de9fbbdcedbb34d28d8a", + strip_prefix = "cnpy-4e8810b1a8637695171ed346ce68f6984e585ef4", + urls = ["https://github.com/rogersce/cnpy/archive/4e8810b1a8637695171ed346ce68f6984e585ef4.tar.gz"], + ) diff --git a/requirements.txt b/requirements.txt index bc18c5e..c432e29 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ -f https://github.com/llvm/torch-mlir-release/releases/expanded_assets/dev-wheels torch torch-mlir +numpy \ No newline at end of file diff --git a/requirements_lock.txt b/requirements_lock.txt index 415bd00..79199f9 100644 --- a/requirements_lock.txt +++ b/requirements_lock.txt @@ -127,7 +127,9 @@ numpy==1.26.4 \ --hash=sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef \ --hash=sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3 \ --hash=sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f - # via torch-mlir + # via + # -r requirements.txt + # torch-mlir packaging==23.2 \ --hash=sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5 \ --hash=sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7 @@ -142,11 +144,11 @@ torch==2.3.0.dev20240220+cpu \ # via # -r requirements.txt # torch-mlir -torch-mlir==20240223.16 \ - --hash=sha256:429bce23c3830485b2c35ae48c1308ad66f63d3e86e4a9e19a17b11b58a1f06d \ - --hash=sha256:8b463ad944951781a3c2c51f4240bfc15eb459a6565762499a279d332cc611cf +torch-mlir==20240306.29 \ + --hash=sha256:c505cb254196f694ba1447af0ba3300d514fafe5520c2b3266d5d9c2e9ee4f93 \ + --hash=sha256:da91a833acfba6e80def65295a805156af1399447a141c46f1becdf5f7ae13a3 # via -r requirements.txt -typing-extensions==4.9.0 \ - --hash=sha256:23478f88c37f27d76ac8aee6c905017a143b0b1b886c3c9f66bc2fd94f9f5783 \ - --hash=sha256:af72aea155e91adfc61c3ae9e0e342dbc0cba726d6cba4b6c72c1f34e47291cd +typing-extensions==4.10.0 \ + --hash=sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475 \ + --hash=sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb # via torch diff --git a/test/AotCompile/BUILD b/test/AotCompile/BUILD index 1540b5d..864e1c2 100644 --- a/test/AotCompile/BUILD +++ b/test/AotCompile/BUILD @@ -5,6 +5,36 @@ load("//tools/aot:aot_compile.bzl", "aot_compile") load("@rules_cc//cc:defs.bzl", "cc_test") +load("@rules_python//python:defs.bzl", "py_library") +load("@pip_deps//:requirements.bzl", "requirement") + +py_library( + name = "add_mul_loader_lib", + srcs = ["add_mul_loader_lib.py"], + visibility = ["//visibility:public"], + deps = [ + requirement("torch"), + "//tools/aot:torch_loader_utils", + ], +) + +aot_compile( + name = "add_mul_single_output", + torch_loader_lib = ":add_mul_loader_lib", + torch_loader_path = "test.AotCompile.add_mul_loader_lib.add_mul_single_output_loader", +) + +aot_compile( + name = "add_mul_multi_output", + torch_loader_lib = ":add_mul_loader_lib", + torch_loader_path = "test.AotCompile.add_mul_loader_lib.add_mul_multi_output_loader", +) + +aot_compile( + name = "broadcast_add_mixed_ranks", + torch_loader_lib = ":add_mul_loader_lib", + torch_loader_path = "test.AotCompile.add_mul_loader_lib.broadcast_add_mixed_ranks_loader", +) aot_compile( name = "basic_tcp_ops", @@ -12,8 +42,8 @@ aot_compile( ) cc_test( - name = "test_aot_compiled_basic_tcp_ops", - srcs = ["test_aot_compiled_basic_tcp_ops.cpp"], + name = "basic_tcp_ops_compile_execute_test", + srcs = ["basic_tcp_ops_compile_execute_test.cpp"], tags = ["aot_tests"], deps = [ ":aot_compiled_basic_tcp_ops", diff --git a/test/AotCompile/add_mul_loader_lib.py b/test/AotCompile/add_mul_loader_lib.py new file mode 100644 index 0000000..ae94c42 --- /dev/null +++ b/test/AotCompile/add_mul_loader_lib.py @@ -0,0 +1,100 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import torch +from torch.export import dynamic_dim + +from tools.aot.torch_loader_utils import TorchLoaderOutput + + +def add_mul_single_output_loader() -> TorchLoaderOutput: + class AddMulNetSingleOutput(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor + ) -> torch.Tensor: + add = torch.add(x, y) + mul = torch.mul(add, z) + return mul + + # Sample inputs + x = torch.randn(2, 3) + y = torch.randn(2, 3) + z = torch.randn(2, 3) + + # Dynamic dim constraints + constraints = [ + # Dim 1 + dynamic_dim(x, 0) == dynamic_dim(y, 0), + dynamic_dim(y, 0) == dynamic_dim(z, 0), + # Dim 2 + dynamic_dim(x, 1) == dynamic_dim(y, 1), + dynamic_dim(y, 1) == dynamic_dim(z, 1), + ] + + return TorchLoaderOutput( + model=AddMulNetSingleOutput(), + inputs=[x, y, z], + constraints=constraints, + ) + + +def add_mul_multi_output_loader() -> TorchLoaderOutput: + class AddMulNetMultiOutput(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor + ) -> tuple[torch.Tensor]: + add = torch.add(x, y) + mul = torch.mul(add, z) + return add, mul + + # Sample inputs + x = torch.randn(2, 3) + y = torch.randn(2, 3) + z = torch.randn(2, 3) + + # Dynamic dim constraints + constraints = [ + # Dim 1 + dynamic_dim(x, 0) == dynamic_dim(y, 0), + dynamic_dim(y, 0) == dynamic_dim(z, 0), + # Dim 2 + dynamic_dim(x, 1) == dynamic_dim(y, 1), + dynamic_dim(y, 1) == dynamic_dim(z, 1), + ] + + return TorchLoaderOutput( + model=AddMulNetMultiOutput(), + inputs=[x, y, z], + constraints=constraints, + ) + + +def broadcast_add_mixed_ranks_loader() -> TorchLoaderOutput: + class BroadcastAddMixedRanks(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + add = torch.add(x, y) + return add + + # Sample inputs + x = torch.tensor(10.0) + y = torch.randn(2) + + # Dynamic dim constraints + constraints = [dynamic_dim(y, 0)] + + return TorchLoaderOutput( + model=BroadcastAddMixedRanks(), + inputs=[x, y], + constraints=constraints, + ) diff --git a/test/AotCompile/test_aot_compiled_basic_tcp_ops.cpp b/test/AotCompile/basic_tcp_ops_compile_execute_test.cpp similarity index 91% rename from test/AotCompile/test_aot_compiled_basic_tcp_ops.cpp rename to test/AotCompile/basic_tcp_ops_compile_execute_test.cpp index bb3c756..fe24c84 100644 --- a/test/AotCompile/test_aot_compiled_basic_tcp_ops.cpp +++ b/test/AotCompile/basic_tcp_ops_compile_execute_test.cpp @@ -61,7 +61,7 @@ TEST(AotCompiled, SingleOutput) { for (int i = 0; i < 2; i++) for (int j = 0; j < 3; j++) { float Expected = (5 + (2 + i)) * (3 + j); - EXPECT_EQ(Result.data[3 * i + j], Expected); + EXPECT_FLOAT_EQ(Result.data[3 * i + j], Expected); } free(Result.basePtr); @@ -109,10 +109,10 @@ TEST(AotCompiled, MultiOutput) { for (int i = 0; i < 2; i++) for (int j = 0; j < 3; j++) { float ExpectedA = 5 + (2 + i); - EXPECT_EQ(Result.A.data[3 * i + j], ExpectedA); + EXPECT_FLOAT_EQ(Result.A.data[3 * i + j], ExpectedA); float ExpectedB = ExpectedA * (3 + j); - EXPECT_EQ(Result.B.data[3 * i + j], ExpectedB); + EXPECT_FLOAT_EQ(Result.B.data[3 * i + j], ExpectedB); } free(Result.A.basePtr); @@ -129,10 +129,10 @@ TEST(AotCompiled, MixedRanks) { StridedMemRefType Result = func_3(&Arr0, &Arr0, 0, Arr1, Arr1, 0, 2, 1); - EXPECT_EQ(Result.sizes[0], 2); - EXPECT_EQ(Result.strides[0], 1); - EXPECT_EQ(Result.data[0], 11.0); - EXPECT_EQ(Result.data[1], 12.0); + ASSERT_EQ(Result.sizes[0], 2); + ASSERT_EQ(Result.strides[0], 1); + EXPECT_FLOAT_EQ(Result.data[0], 11.0); + EXPECT_FLOAT_EQ(Result.data[1], 12.0); free(Result.basePtr); } diff --git a/test/BUILD b/test/BUILD index 876de93..da80bfa 100644 --- a/test/BUILD +++ b/test/BUILD @@ -46,7 +46,7 @@ filegroup( tags = ["lit_tests"], deps = [ requirement("torch"), - requirement("torch_mlir"), + requirement("torch-mlir"), ], ) for src in glob([ diff --git a/test/lit.cfg.py b/test/lit.cfg.py index 9838ab5..453485f 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -12,24 +12,24 @@ # Populate Lit configuration with the minimal required metadata. # Some metadata is populated in lit.site.cfg.py.in. -config.name = 'MLIR_TCP_TESTS_SUITE' +config.name = "MLIR_TCP_TESTS_SUITE" config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) -config.suffixes = ['.mlir', '.py'] +config.suffixes = [".mlir", ".py"] tool_dirs = [ - config.llvm_tools_dir, - config.tcp_tools_dir, + config.llvm_tools_dir, + config.tcp_tools_dir, ] # Make LLVM, TCP and PYTHON tools available in RUN directives tools = [ - 'tcp-opt', - 'FileCheck', - 'count', - 'not', - ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'), + "tcp-opt", + "FileCheck", + "count", + "not", + ToolSubst("%PYTHON", config.python_executable, unresolved="ignore"), ] llvm_config.add_tool_substitutions(tools, tool_dirs) -llvm_config.with_environment('PYTHONPATH', config.python_path, append_path=True) +llvm_config.with_environment("PYTHONPATH", config.python_path, append_path=True) diff --git a/test/python/BUILD b/test/python/BUILD index 8d4e439..fa8aef4 100644 --- a/test/python/BUILD +++ b/test/python/BUILD @@ -12,7 +12,7 @@ py_test( tags = ["python_tests"], deps = [ requirement("torch"), - requirement("torch_mlir"), + requirement("torch-mlir"), ], ) diff --git a/test/python/fx_import/basic_test.py b/test/python/fx_import/basic_test.py index 0a82dfe..b222c6a 100644 --- a/test/python/fx_import/basic_test.py +++ b/test/python/fx_import/basic_test.py @@ -1,4 +1,4 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. @@ -11,6 +11,7 @@ from torch_mlir import fx + def test_import_frozen_exported_program(): @torch._dynamo.assume_constant_result def get_a(): diff --git a/test/python_lit/fx_import/basic_test.py b/test/python_lit/fx_import/basic_test.py index 36c5548..4ba5524 100644 --- a/test/python_lit/fx_import/basic_test.py +++ b/test/python_lit/fx_import/basic_test.py @@ -1,4 +1,4 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. diff --git a/third_party/BUILD b/third_party/BUILD new file mode 100644 index 0000000..6155f07 --- /dev/null +++ b/third_party/BUILD @@ -0,0 +1,4 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. diff --git a/third_party/cnpy.BUILD b/third_party/cnpy.BUILD new file mode 100644 index 0000000..d917f54 --- /dev/null +++ b/third_party/cnpy.BUILD @@ -0,0 +1,30 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") + +licenses(["notice"]) # MIT + +package( + default_visibility = [ + "//visibility:public", + ], +) + +cc_library( + name = "cnpy", + srcs = ["cnpy.cpp"], + hdrs = ["cnpy.h"], + copts = [ + "-Wno-unused-variable", + ], + deps = ["@llvm_zlib//:zlib"], +) + +cc_test( + name = "test_cnpy", + srcs = ["example1.cpp"], + deps = [":cnpy"], +) diff --git a/tools/aot/BUILD b/tools/aot/BUILD index 29e470c..6e7d467 100644 --- a/tools/aot/BUILD +++ b/tools/aot/BUILD @@ -3,13 +3,45 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. +load("@rules_cc//cc:defs.bzl", "cc_library") +load("@rules_python//python:defs.bzl", "py_library") +load("@pip_deps//:requirements.bzl", "requirement") + package( default_visibility = [ "//visibility:public", ], ) +# Used by `aot_compile` bazel macro +exports_files([ + "torch_exporter_harness.py", + "execute_test_generator.py", + "execute_test.template.cpp", +]) + +py_library( + name = "torch_loader_utils", + srcs = ["torch_loader_utils.py"], + deps = [requirement("torch")], +) + cc_library( name = "abi", hdrs = ["abi.h"], ) + +# Dummy target for clangd compilation database purposes only. +# This specific target is not used by the `aot_compile` bazel +# macro, but an equivalent target is. +cc_test( + name = "execute_test_template", + srcs = ["execute_test.template.cpp"], + tags = ["manual"], + deps = [ + ":abi", + "@cnpy", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:mlir_c_runner_utils_hdrs", + ], +) diff --git a/tools/aot/README.md b/tools/aot/README.md new file mode 100644 index 0000000..a5bf890 --- /dev/null +++ b/tools/aot/README.md @@ -0,0 +1,505 @@ +AOT Compile (Developer Guide) +============================= + +The [`aot_compile`](https://github.com/cruise-automation/mlir-tcp/blob/main/tools/aot/aot_compile.bzl) bazel macro implements an end-to-end framework to compile PyTorch (or TCP) programs to a CPU library, execute it and test for functional correctness of the generated code. It comprises starting with TorchDynamo export of PyTorch programs, conversion and lowerings through {Torch, TCP, Linalg, LLVM} MLIR dialects, translation to LLVM assembly, compilation to assembly source for the host architecture (CPU), and lastly generation of shared object that could be dynamically linked into an executable/test at runtime. It leverages a series of genrules to stitch the compilation pipeline together, and an unsophisticated meta-programming trick for auto-generating C++ tests (specialized to the input program's function signature) that execute the compiled code and validate its numerics against reference PyTorch. + +When authoring new TCP ops with dialect conversions from/to Torch and Linalg, adding an `aot_compile` target is a fast, automated and standardized way to test the e2e compilation and validate that the op lowerings are implemented consistent with PyTorch semantics. + +Caveat: The AOT compile framework's primary objective is to serve as an end-to-end `compile -> execute -> test` harness for functional correctness, and *not* as an optimizing compiler for production usecases. In the future we might be interested in reusing pieces of infrastructure here to construct an optimizing compiler, but it entails more work to get there (such as a runtime and performance benchmark apparatus). + +## Compile PyTorch programs + +Onboarding to the `aot_compile` macro is quite easy (examples [here](https://github.com/cruise-automation/mlir-tcp/blob/main/test/AotCompile/BUILD)). Start by adding the following line to the `BUILD` to load the macro: +```starlark +load("//tools/aot:aot_compile.bzl", "aot_compile") +``` + +Then call the macro like this: +```starlark +aot_compile( + name = "broadcast_add_mixed_ranks", + torch_loader_lib = ":add_mul_loader_lib", + torch_loader_path = "test.AotCompile.add_mul_loader_lib.broadcast_add_mixed_ranks_loader", +) +``` + +Here, `torch_loader_lib` expects a `py_library` target for the module that defines the PyTorch program to be AOT compiled, and `torch_loader_path` is the full python import path (dot separated) to the loader function. +```starlark +py_library( + name = "add_mul_loader_lib", + srcs = ["add_mul_loader_lib.py"], + visibility = ["//visibility:public"], + deps = [ + requirement("torch"), + "//tools/aot:torch_loader_utils", + ], +) +``` + +The loader function can be called anything really, but it should define the PyTorch program, sample inputs and dynamic dim constraints (if any), and always return a `TorchLoaderOutput` object. The PyTorch program's forward function must always consume and return tensors, like so: +```python +import torch +from torch.export import dynamic_dim + +from tools.aot.torch_loader_utils import TorchLoaderOutput + + +def broadcast_add_mixed_ranks_loader() -> TorchLoaderOutput: + class BroadcastAddMixedRanks(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + add = torch.add(x, y) + return add + + # Sample inputs + x = torch.tensor(10.0) + y = torch.randn(2) + + # Dynamic dim constraints + constraints = [dynamic_dim(y, 0)] + + return TorchLoaderOutput( + model=BroadcastAddMixedRanks(), + inputs=[x, y], + constraints=constraints, + ) +``` + +An invocation of `aot_compile(name="foo", ...)` generates a bunch of targets (see [here](https://github.com/cruise-automation/mlir-tcp/blob/main/tools/aot/aot_compile.bzl#L43) for the list) that can be helpful in debugging the intermediate steps in the compilation process. + +To get the full list of `aot_compile` macro generated targets for `broadcast_add_mixed_ranks`, run the query: +```shell +$ bazel query 'attr(name, "broadcast_add_mixed_ranks", //test/AotCompile/...)' + +//test/AotCompile:aot_compiled_broadcast_add_mixed_ranks +//test/AotCompile:broadcast_add_mixed_ranks_compile_execute_test +//test/AotCompile:broadcast_add_mixed_ranks_execute_test_generator +//test/AotCompile:broadcast_add_mixed_ranks_torch_exporter +//test/AotCompile:gen_broadcast_add_mixed_ranks_execute_test +//test/AotCompile:gen_broadcast_add_mixed_ranks_host_asm +//test/AotCompile:gen_broadcast_add_mixed_ranks_llvm_ir +//test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_llvm +//test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_tcp +//test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_torch +//test/AotCompile:gen_broadcast_add_mixed_ranks_reference_tensors +``` + +### Debugging e2e compilation pipeline + +Lets walk through a series of steps involved in debugging an e2e compilation pipeline. Note that these steps are not required to be manually run one at a time (although they can be). Bazel automatically identifies the DAG of dependencies and executes just what is needed to build the specified target. + +#### 1. Inspect the Torch dialect (`*_torch.mlir`) exported from the PyTorch program: +```shell +$ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_torch + +INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_torch (61 packages loaded, 16582 targets configured). +INFO: Found 1 target... +Target //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_torch up-to-date: + bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_torch.mlir +INFO: Elapsed time: 6.085s, Critical Path: 0.69s +INFO: 1 process: 1 internal. +INFO: Build completed successfully, 1 total action +``` +```ll +$ cat bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_torch.mlir + +module { + func.func @func_main(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { + %int1 = torch.constant.int 1 + %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32> + return %0 : !torch.vtensor<[?],f32> + } +} +``` + +#### 2. Inspect the TCP dialect (`*_tcp.mlir`) lowered from the Torch dialect: +```shell +$ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_tcp + +INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_tcp (0 packages loaded, 0 targets configured). +INFO: Found 1 target... +Target //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_tcp up-to-date: + bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_tcp.mlir +INFO: Elapsed time: 0.572s, Critical Path: 0.03s +INFO: 1 process: 1 internal. +INFO: Build completed successfully, 1 total action +``` +```ll +$ cat bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_tcp.mlir + +module { + func.func @func_main(%arg0: tensor, %arg1: tensor) -> tensor { + %c0 = arith.constant 0 : index + %expanded = tensor.expand_shape %arg0 [] : tensor into tensor<1xf32> + %dim = tensor.dim %arg1, %c0 : tensor + %0 = tcp.broadcast %expanded, %dim {axes = [0]} : tensor<1xf32>, index -> tensor + %1 = tcp.add %0, %arg1 : tensor, tensor -> tensor + return %1 : tensor + } +} +``` + +#### 3. Inspect the LLVM dialect (`*_llvm.mlir`) lowered from the TCP dialect: +```shell +$ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_llvm + +INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_llvm (0 packages loaded, 0 targets configured). +INFO: Found 1 target... +Target //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_llvm up-to-date: + bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_llvm.mlir +INFO: Elapsed time: 0.305s, Critical Path: 0.00s +INFO: 1 process: 1 internal. +INFO: Build completed successfully, 1 total action +``` +```ll +$ cat bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_llvm.mlir + +module { + llvm.func @malloc(i64) -> !llvm.ptr + llvm.func @func_main(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: !llvm.ptr, %arg4: !llvm.ptr, %arg5: i64, %arg6: i64, %arg7: i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> { + %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64)> + %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr, ptr, i64)> + %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64)> + %3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr, ptr, i64)> +. +. +. + %57 = llvm.load %56 : !llvm.ptr -> f32 + %58 = llvm.fadd %54, %57 : f32 + %59 = llvm.getelementptr %44[%51] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + llvm.store %58, %59 : f32, !llvm.ptr + %60 = llvm.add %51, %14 : i64 + llvm.br ^bb4(%60 : i64) + ^bb6: // pred: ^bb4 + llvm.return %50 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + } +} +``` + +#### 4. Inspect the LLVM assembly (`*.ll`) translated from the LLVM dialect: +```shell +$ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_llvm_ir + +INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_llvm_ir (0 packages loaded, 0 targets configured). +INFO: Found 1 target... +Target //test/AotCompile:gen_broadcast_add_mixed_ranks_llvm_ir up-to-date: + bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks.ll +INFO: Elapsed time: 0.312s, Critical Path: 0.00s +INFO: 1 process: 1 internal. +INFO: Build completed successfully, 1 total action +``` +```ll +$ cat bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks.ll + +; ModuleID = 'LLVMDialectModule' +source_filename = "LLVMDialectModule" + +declare ptr @malloc(i64) + +define { ptr, ptr, i64, [1 x i64], [1 x i64] } @func_main(ptr %0, ptr %1, i64 %2, ptr %3, ptr %4, i64 %5, i64 %6, i64 %7) { + %9 = insertvalue { ptr, ptr, i64 } undef, ptr %0, 0 + %10 = insertvalue { ptr, ptr, i64 } %9, ptr %1, 1 + %11 = insertvalue { ptr, ptr, i64 } %10, i64 %2, 2 + %12 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } undef, ptr %3, 0 +. +. +. +63: ; preds = %51 + ret { ptr, ptr, i64, [1 x i64], [1 x i64] } %50 +} + +!llvm.module.flags = !{!0} + +!0 = !{i32 2, !"Debug Info Version", i32 3} +``` + +#### 5. Inspect the assembly source (`*.S`) compiled for the host architecture (CPU): +```shell +$ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_host_asm + +INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_host_asm (0 packages loaded, 0 targets configured). +INFO: Found 1 target... +Target //test/AotCompile:gen_broadcast_add_mixed_ranks_host_asm up-to-date: + bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks.S +INFO: Elapsed time: 0.360s, Critical Path: 0.03s +INFO: 1 process: 1 internal. +INFO: Build completed successfully, 1 total action +``` +```ll +$ cat bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks.S + + .text + .file "LLVMDialectModule" + .globl func_main # -- Begin function func_main + .p2align 4, 0x90 + .type func_main,@function +func_main: # @func_main + .cfi_startproc +# %bb.0: + pushq %rbp + .cfi_def_cfa_offset 16 +. +. +. + popq %r14 + .cfi_def_cfa_offset 24 + popq %r15 + .cfi_def_cfa_offset 16 + popq %rbp + .cfi_def_cfa_offset 8 + retq +.Lfunc_end0: + .size func_main, .Lfunc_end0-func_main + .cfi_endproc + # -- End function + .section ".note.GNU-stack","",@progbits +``` + +#### 6. Build the shared object (`*.so`) from the host assembly that can be dynamically linked into an executable/test at runtime: +```shell +$ bazel build //test/AotCompile:aot_compiled_broadcast_add_mixed_ranks + +INFO: Analyzed target //test/AotCompile:aot_compiled_broadcast_add_mixed_ranks (8 packages loaded, 8403 targets configured). +INFO: Found 1 target... +Target //test/AotCompile:aot_compiled_broadcast_add_mixed_ranks up-to-date: + bazel-bin/test/AotCompile/libaot_compiled_broadcast_add_mixed_ranks.a + bazel-bin/test/AotCompile/libaot_compiled_broadcast_add_mixed_ranks.so +INFO: Elapsed time: 2.264s, Critical Path: 0.12s +INFO: 1 process: 1 internal. +INFO: Build completed successfully, 1 total action +``` + +Note that this `cc_library` target (called `aot_compiled_*`) is marked `testonly`, which only allows it to be added as a dependency for test targets. The goal is to avoid any inadvertent use of the compiled artifacts in production usecases. + +#### 7. Save the reference input and output tensors needed for validation of the compiled code: +```shell +$ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_reference_tensors + +INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_reference_tensors (0 packages loaded, 5 targets configured). +INFO: Found 1 target... +Target //test/AotCompile:gen_broadcast_add_mixed_ranks_reference_tensors up-to-date: + bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_reference_tensors.npz +INFO: Elapsed time: 0.743s, Critical Path: 0.15s +INFO: 1 process: 1 internal. +INFO: Build completed successfully, 1 total action +``` + +#### 8. Inspect the C++ test (`*_execute_test.cpp`) auto-generated from the template: +```shell +$ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_execute_test + +INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_execute_test (22 packages loaded, 91 targets configured). +INFO: Found 1 target...Target //test/AotCompile:gen_broadcast_add_mixed_ranks_execute_test up-to-date: + bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_execute_test.cpp +INFO: Elapsed time: 0.329s, Critical Path: 0.02s +INFO: 1 process: 1 internal. +INFO: Build completed successfully, 1 total action +``` +```cpp +$ cat bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_execute_test.cpp + +//===------------------------------------------------------------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "mlir/ExecutionEngine/CRunnerUtils.h" +#include "tools/aot/abi.h" + +#include "cnpy.h" +#include "gtest/gtest.h" + +using namespace mlir::tcp; + +#pragma clang diagnostic ignored "-Wreturn-type-c-linkage" + +template +static StridedMemRefType +CreateMemRefFromNpyArray(cnpy::NpyArray &arr) { + StridedMemRefType Result; + Result.basePtr = arr.data(); + Result.data = arr.data(); + Result.offset = 0; + + // Check if the Rank matches + if (arr.shape.size() != Rank) { + std::cerr << "Error: Rank mismatch." << std::endl; + // Return an uninitialized memref + return Result; + } + + // Check if the DataType matches + if (arr.word_size != sizeof(DataType)) { + std::cerr << "Error: Data type mismatch." << std::endl; + // Return an uninitialized memref + return Result; + } + + // Set sizes and strides based on the shape of the numpy array + int stride = 1; + for (int i = Rank - 1; i >= 0; --i) { + Result.sizes[i] = arr.shape[i]; + Result.strides[i] = stride; + stride *= arr.shape[i]; + } + + return Result; +} + +// CreateMemRefFromNpyArray function specialized for rank 0 +template +static StridedMemRefType +CreateMemRefFromNpyArray(cnpy::NpyArray &arr) { + StridedMemRefType Result; + Result.basePtr = arr.data(); + Result.data = arr.data(); + Result.offset = 0; + + // Check if the Rank matches + if (!arr.shape.empty()) { + std::cerr << "Error: Rank mismatch. Expected rank-0 array." << std::endl; + // Return an uninitialized memref + return Result; + } + + // Check if the DataType matches + if (arr.word_size != sizeof(DataType)) { + std::cerr << "Error: Data type mismatch." << std::endl; + // Return an uninitialized memref + return Result; + } + + return Result; +} + +// ### DO NOT MODIFY ### // +// This template file is pre-processed by `aot_compile` bazel macro +// to materialize the templated parameters based on the inputs +// passed by the callsite where the macro is instantiated. + +struct OutputMemRefDescriptor { + StridedMemRefType Output0; +}; + +extern "C" OutputMemRefDescriptor func_main( + DECL_RANK_0_MEMREF_ABI(float), + DECL_RANK_1_MEMREF_ABI(float) +); + +TEST(AotCompiled, ExecuteTest) { + cnpy::npz_t reference_tensors = cnpy::npz_load( + "test/AotCompile/_internal_broadcast_add_mixed_ranks_reference_tensors.npz" + ); + + cnpy::NpyArray refInput0 = reference_tensors["Input0"]; + cnpy::NpyArray refInput1 = reference_tensors["Input1"]; + cnpy::NpyArray refOutput0 = reference_tensors["Output0"]; + + StridedMemRefType Input0 = + CreateMemRefFromNpyArray(refInput0); + StridedMemRefType Input1 = + CreateMemRefFromNpyArray(refInput1); + + OutputMemRefDescriptor Result = func_main( + PASS_RANK_0_MEMREF(Input0), + PASS_RANK_1_MEMREF(Input1) + ); + + ASSERT_EQ(Result.Output0.sizes[0], refOutput0.shape[0]); + + for (int i = 0; i < refOutput0.num_vals; i++) + EXPECT_EQ(Result.Output0.data[i], refOutput0.data()[i]); + + free(Result.Output0.basePtr); +} +``` + +#### 9. Run the C++ test to execute the generated code and validate functional correctness against reference PyTorch +```shell +$ bazel run //test/AotCompile:broadcast_add_mixed_ranks_compile_execute_test + +INFO: Analyzed target //test/AotCompile:broadcast_add_mixed_ranks_compile_execute_test (0 packages loaded, 0 targets configured). +INFO: Found 1 target... +Target //test/AotCompile:broadcast_add_mixed_ranks_compile_execute_test up-to-date: + bazel-bin/test/AotCompile/broadcast_add_mixed_ranks_compile_execute_test +INFO: Elapsed time: 0.215s, Critical Path: 0.00s +INFO: 1 process: 1 internal. +INFO: Build completed successfully, 1 total action +INFO: Running command line: external/bazel_tools/tools/test/test-setup.sh test/AotCompile/broadcast_add_mixed_ranks_compile_execute_test +exec ${PAGER:-/usr/bin/less} "$0" || exit 1 +Executing tests from //test/AotCompile:broadcast_add_mixed_ranks_compile_execute_test +----------------------------------------------------------------------------- +Running main() from gmock_main.cc +[==========] Running 1 test from 1 test suite. +[----------] Global test environment set-up. +[----------] 1 test from AotCompiled +[ RUN ] AotCompiled.ExecuteTest +[ OK ] AotCompiled.ExecuteTest (1 ms) +[----------] 1 test from AotCompiled (1 ms total) + +[----------] Global test environment tear-down +[==========] 1 test from 1 test suite ran. (1 ms total) +[ PASSED ] 1 test. +``` + +## Compile TCP programs + +The `aot_compile` macro also accepts TCP dialect programs as inputs (instead of PyTorch programs). This is useful to maintain framework neutrality by allowing alternate ingress pathways (like Stablehlo, JAX, TensorFlow, ONNX etc.) into the TCP dialect. When `tcp_source` is specified, the generated `aot_compiled_foo` CPU library has one global function for every function in the TCP program. Let's look at an example. + +```starlark +aot_compile( + name = "basic_tcp_ops", + tcp_source = "basic_tcp_ops.mlir", +) +``` + +Here, `tcp_source` expects a `.mlir` file containing TCP programs, like so: +```mlir +// basic_tcp_ops.mlir + +func.func @func_1(%arg0: tensor, + %arg1: tensor, + %arg2: tensor) -> tensor { + %0 = tcp.add %arg0, %arg1 : tensor, tensor -> tensor + %1 = tcp.mul %0, %arg2 : tensor, tensor -> tensor + return %1 : tensor +} + +func.func @func_2(%arg0: tensor, + %arg1: tensor, + %arg2: tensor) -> (tensor, tensor) { + %0 = tcp.add %arg0, %arg1 : tensor, tensor -> tensor + %1 = tcp.mul %0, %arg2 : tensor, tensor -> tensor + return %0, %1 : tensor, tensor +} + +func.func @func_3(%arg0: tensor, + %arg1: tensor) -> tensor { + %c0 = arith.constant 0 : index + %dim = tensor.dim %arg1, %c0 : tensor + %arg0_ex = tensor.expand_shape %arg0 [] : tensor into tensor<1xf32> + %arg0_bcast = tcp.broadcast %arg0_ex, %dim {axes = [0]} : tensor<1xf32>, index -> tensor + %0 = tcp.add %arg0_bcast, %arg1 : tensor, tensor -> tensor + return %0 : tensor +} +``` + +Now run the query to get all the relevant targets created. +```shell +$ bazel query 'attr(name, "basic_tcp_ops", //test/AotCompile/...)' + +//test/AotCompile:aot_compiled_basic_tcp_ops +//test/AotCompile:gen_basic_tcp_ops_host_asm +//test/AotCompile:gen_basic_tcp_ops_llvm_ir +//test/AotCompile:gen_basic_tcp_ops_mlir_llvm +``` + +Note we're missing the `//test/AotCompile:basic_tcp_ops_compile_execute_test` target. As there is no access to PyTorch reference implementation, the `aot_compile` macro does not auto-generate C++ execute tests but they can be manually written (example [here](https://github.com/cruise-automation/mlir-tcp/blob/main/test/AotCompile/test_aot_compiled_basic_tcp_ops.cpp)). These tests should include `extern "C"` function declarations with the same name and for every function in the input TCP source. + +The rest of the steps to debug the e2e compilation pipeline are pretty much the same. diff --git a/tools/aot/abi.h b/tools/aot/abi.h index c35de26..624fd99 100644 --- a/tools/aot/abi.h +++ b/tools/aot/abi.h @@ -28,10 +28,19 @@ using IndexTy = long; // StridedMemRefType B; // }; // -// StridedMemRefType is described at +// and StridedMemRefType is defined as: // https://mlir.llvm.org/docs/TargetLLVMIR/#ranked-memref-types -// and defined at // https://sourcegraph.com/github.com/llvm/llvm-project@b5048700fc31f3bf6dd32ace7730815d4cfef411/-/blob/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h?L131 +// +// template +// struct StridedMemRefType { +// T *basePtr; +// T *data; +// int64_t offset; +// int64_t sizes[N]; +// int64_t strides[N]; +// ... +// }; #define DECL_RANK_2_MEMREF_ABI(data_type) \ data_type *, data_type *, IndexTy, IndexTy, IndexTy, IndexTy, IndexTy @@ -47,7 +56,7 @@ using IndexTy = long; (memref).sizes[1], (memref).strides[0], (memref).strides[1] #define PASS_RANK_1_MEMREF(memref) \ (memref).basePtr, (memref).data, (memref).offset, (memref).sizes[0], \ - (memref).strides[0], + (memref).strides[0] #define PASS_RANK_0_MEMREF(memref) \ (memref).basePtr, (memref).data, (memref).offset diff --git a/tools/aot/aot_compile.bzl b/tools/aot/aot_compile.bzl index 4db8ebf..0f21b38 100644 --- a/tools/aot/aot_compile.bzl +++ b/tools/aot/aot_compile.bzl @@ -3,29 +3,246 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -def aot_compile(name, tcp_source): +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") +load("@rules_python//python:defs.bzl", "py_binary") +load("@pip_deps//:requirements.bzl", "requirement") + +def aot_compile( + name, + tcp_source = None, + torch_loader_lib = None, + torch_loader_path = "", + skip_ci = False): """ - AOT compiles `tcp_source` to a CPU library. + AOT compile Torch or TCP programs to a CPU library and execute it to + validate functional correctness of the compiled code against PyTorch + semantics. Exposes a target named `aot_compiled_${name}` that has one global function - for every function in `tcp_source`. Each of the functions in `tcp_source` - must consume and return tensors. The ABI of the generated code is exposed - in abi.h. + for every function in the TCP program (when `tcp_source` is specified), or + one global function corresponding to the PyTorch program's forward function + (when `torch_loader_lib` is specified). + + The functions in Torch or TCP sources must always consume and return tensors. + The ABI of the generated code is exposed in `abi.h`. + + Parameters + ---------- + name + Name of the program to be AOT compiled. + tcp_source + Path to the "*.mlir" source containing the TCP program. + torch_loader_lib + Label of the `py_library` target for the torch_loader module containing + the PyTorch program. + torch_loader_path + Full python import path (dot separated) to the torch_loader function. + skip_ci + When `True`, skip execute tests from CI (and `bazel test //...` expansions). + + Generated Targets + ----------------- + An invocation of `aot_compile(name="foo", ...)` generates the following targets: + aot_compiled_foo: + cc_library wrapper around the AOT compiled assembly source targeting CPU. + This has one global function for every function in the TCP program + (when `tcp_source` is specified), or one global function corresponding to + the PyTorch program's forward function (when `torch_loader_lib` is specified). + When built, generates a shared object that can by dynamically linked into + an executable at runtime. + foo_compile_execute_test: + cc_test that executes the compiled code on CPU using reference inputs, + and validates the outputs against PyTorch. + foo_torch_exporter: + py_binary that runs the torch_loader function to get the `TorchLoaderOutput` + (containing the PyTorch program and inputs), then calls the upstream + `fx.export_and_import` API to generate Torch dialect, and finally runs the + PyTorch program on reference inputs and saves the reference outputs (as .npz) + which will eventually be used for validation of the AOT compiled code. + foo_execute_test_generator: + py_binary that reads the reference tensors to infer the function signature + (rank, element type for each input/output tensor) and then materializes the + templatized parameters in `execute_test.template.cpp`. + gen_foo_mlir_torch: + genrule that invokes `foo_torch_exporter` and saves the torch dialect program + (*_torch.mlir). + gen_foo_mlir_tcp: + genrule that invokes `tcp-opt` to convert the torch dialect program to the + tcp dialect program (*_tcp.mlir) using `-torch-backend-to-tcp-backend-pipeline`. + gen_foo_mlir_llvm: + genrule that invokes `tcp-opt` to convert the tcp dialect program to the + llvm dialect program (*_llvm.mlir) using `-tcp-to-llvm-pipeline`. + gen_foo_llvm_ir: + genrule that invokes `mlir-translate` to convert the llvm dialect program to + the llvm assembly (*.ll) using `-mlir-to-llvmir`. + gen_foo_host_asm: + genrule that invokes `llc` on the llvm assembly to generate assembly source + (*.S) for the host architecture (CPU). + gen_foo_reference_tensors: + genrule that invokes `foo_torch_exporter` and saves the reference tensors + to a numpy archive (*.npz). + gen_foo_execute_test: + genrule that invokes `foo_execute_test_generator` to generate a materialized + execute_test.cpp for foo. + + The set of auto-generated targets can be obtained by running the following query: + bazel query 'attr(name, "foo", //test/AotCompile/...)' + """ + if not tcp_source and not torch_loader_lib: + fail("aot_compile macro requires either `tcp_source` or `torch_loader_lib` " + + "to be specified.") + if tcp_source and torch_loader_lib: + fail("aot_compile macro cannot accept both `tcp_source` and `torch_loader_lib`. " + + "Please specify either one.") + if torch_loader_lib != None and torch_loader_path == "": + fail("aot_compile macro requires `torch_loader_path` to be specified along with " + + "`torch_loader_lib`.") + if tcp_source and torch_loader_path != "": + fail("aot_compile macro cannot accept `torch_loader_path` when `tcp_source` " + + "is specified.") + + _name = "_internal_" + name + + # Use torch_export based compilation if tcp_source is not specified + if not tcp_source: + torch_exporter = name + "_torch_exporter" + reference_tensors_file = _name + "_reference_tensors.npz" + + py_binary( + name = torch_exporter, + srcs = ["//tools/aot:torch_exporter_harness.py"], + main = "torch_exporter_harness.py", + deps = [ + torch_loader_lib, + requirement("numpy"), + requirement("torch"), + requirement("torch-mlir"), + "//tools/aot:torch_loader_utils", + ], + # This is needed for testing the binary standalone + args = ["--torch_loader_path=" + torch_loader_path], + ) + + native.genrule( + name = "gen_" + name + "_reference_tensors", + srcs = [], + outs = [reference_tensors_file], + cmd = "./$(location " + torch_exporter + ")" + + " --torch_loader_path=" + torch_loader_path + + " --reference_tensors_path=$(location " + reference_tensors_file + ")", + tools = [torch_exporter], + ) + + native.genrule( + name = "gen_" + name + "_mlir_torch", + srcs = [], + outs = [_name + "_torch.mlir"], + cmd = "./$(location " + torch_exporter + ")" + + " --torch_loader_path=" + torch_loader_path + + " > $(OUTS)", + tools = [torch_exporter], + ) + + native.genrule( + name = "gen_" + name + "_mlir_tcp", + srcs = [_name + "_torch.mlir"], + outs = [_name + "_tcp.mlir"], + cmd = "./$(location //:tcp-opt)" + + " -torch-backend-to-tcp-backend-pipeline $(SRCS)" + + " > $(OUTS)", + tools = ["//:tcp-opt"], + ) + + native.genrule( + name = "gen_" + name + "_mlir_llvm", + # When tcp_source is provided, prefer that as the start for aot_compile; + # else continue using genrule generated *_tcp.mlir (torch_export workflow) + srcs = [tcp_source or (_name + "_tcp.mlir")], + outs = [_name + "_llvm.mlir"], + cmd = "./$(location //:tcp-opt)" + + " -tcp-to-llvm-pipeline $(SRCS)" + + " > $(OUTS)", + tools = ["//:tcp-opt"], + ) + native.genrule( - name = "_internal_gen_asm_" + name, - srcs = [tcp_source], - outs = ["_internal_" + name + ".S"], - cmd = "./$(location //:tcp-opt) -tcp-to-llvm-pipeline $(SRCS) | ./$(location @llvm-project//mlir:mlir-translate) -mlir-to-llvmir | ./$(location @llvm-project//llvm:llc) -O3 > \"$@\"", - tools = [ - "//:tcp-opt", - "@llvm-project//mlir:mlir-translate", - "@llvm-project//llvm:llc", - ], + name = "gen_" + name + "_llvm_ir", + srcs = [_name + "_llvm.mlir"], + outs = [_name + ".ll"], + cmd = "./$(location @llvm-project//mlir:mlir-translate)" + + " -mlir-to-llvmir $(SRCS)" + + " > $(OUTS)", + tools = ["@llvm-project//mlir:mlir-translate"], ) - native.cc_library( + # TODO: Replace llc with clang for optimized `.o` generation + native.genrule( + name = "gen_" + name + "_host_asm", + srcs = [_name + ".ll"], + outs = [_name + ".S"], + cmd = "./$(location @llvm-project//llvm:llc) -O3 < $(SRCS)" + + " > $(OUTS)", + tools = ["@llvm-project//llvm:llc"], + ) + + cc_library( name = "aot_compiled_" + name, - srcs = ["_internal_" + name + ".S"], + srcs = [_name + ".S"], + # Can only be consumed (depended on) by test targets. + # Prevents inadvertent use in a production usecase. testonly = True, ) + + # Can't use auto-generated tests for tcp_source based compilations due to + # lack of reference inputs/outputs for comparisons; write tests manually. + if not tcp_source: + execute_test_generator = name + "_execute_test_generator" + test_template_file = "//tools/aot:execute_test.template.cpp" + + py_binary( + name = execute_test_generator, + srcs = ["//tools/aot:execute_test_generator.py"], + main = "execute_test_generator.py", + deps = [requirement("numpy")], + # This is needed for testing the binary standalone + args = [ + "--test_template_path=$(location " + test_template_file + ")", + "--reference_tensors_path=$(location " + reference_tensors_file + ")", + ], + data = [ + test_template_file, + reference_tensors_file, + ], + ) + + native.genrule( + name = "gen_" + name + "_execute_test", + srcs = [ + test_template_file, + reference_tensors_file, + ], + outs = [_name + "_execute_test.cpp"], + cmd = "./$(location " + execute_test_generator + ")" + + " --test_template_path=$(location " + test_template_file + ")" + + " --reference_tensors_path=$(location " + reference_tensors_file + ")" + + " > $(OUTS)", + tools = [execute_test_generator], + ) + + cc_test( + name = name + "_compile_execute_test", + srcs = [_name + "_execute_test.cpp"], + tags = [ + "aot_tests", + "manual" if skip_ci else "", + ], + deps = [ + ":aot_compiled_" + name, + "//tools/aot:abi", + "@cnpy//:cnpy", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:mlir_c_runner_utils_hdrs", + ], + data = [reference_tensors_file], + ) diff --git a/tools/aot/execute_test.template.cpp b/tools/aot/execute_test.template.cpp new file mode 100644 index 0000000..19a9c43 --- /dev/null +++ b/tools/aot/execute_test.template.cpp @@ -0,0 +1,134 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "mlir/ExecutionEngine/CRunnerUtils.h" +#include "tools/aot/abi.h" + +#include "cnpy.h" +#include "gtest/gtest.h" + +using namespace mlir::tcp; + +#pragma clang diagnostic ignored "-Wreturn-type-c-linkage" + +template +static StridedMemRefType +CreateMemRefFromNpyArray(cnpy::NpyArray &arr) { + StridedMemRefType Result; + Result.basePtr = arr.data(); + Result.data = arr.data(); + Result.offset = 0; + + // Check if the Rank matches + if (arr.shape.size() != Rank) { + std::cerr << "Error: Rank mismatch." << std::endl; + // Return an uninitialized memref + return Result; + } + + // Check if the DataType matches + if (arr.word_size != sizeof(DataType)) { + std::cerr << "Error: Data type mismatch." << std::endl; + // Return an uninitialized memref + return Result; + } + + // Set sizes and strides based on the shape of the numpy array + int stride = 1; + for (int i = Rank - 1; i >= 0; --i) { + Result.sizes[i] = arr.shape[i]; + Result.strides[i] = stride; + stride *= arr.shape[i]; + } + + return Result; +} + +// CreateMemRefFromNpyArray function specialized for rank 0 +template +static StridedMemRefType +CreateMemRefFromNpyArray(cnpy::NpyArray &arr) { + StridedMemRefType Result; + Result.basePtr = arr.data(); + Result.data = arr.data(); + Result.offset = 0; + + // Check if the Rank matches + if (!arr.shape.empty()) { + std::cerr << "Error: Rank mismatch. Expected rank-0 array." << std::endl; + // Return an uninitialized memref + return Result; + } + + // Check if the DataType matches + if (arr.word_size != sizeof(DataType)) { + std::cerr << "Error: Data type mismatch." << std::endl; + // Return an uninitialized memref + return Result; + } + + return Result; +} + +// ### DO NOT MODIFY ### // +// This template file is pre-processed by `aot_compile` bazel macro +// to materialize the templated parameters based on the inputs +// passed by the callsite where the macro is instantiated. + +struct OutputMemRefDescriptor { + //##OUTPUT_MEMREF_VARIABLE_DECLARATIONS##// + // StridedMemRefType Output0; +}; + +extern "C" OutputMemRefDescriptor func_main( + //##INPUT_MEMREF_ABI_DECLARATIONS##// + // DECL_RANK_2_MEMREF_ABI(float), + // DECL_RANK_2_MEMREF_ABI(float), + // DECL_RANK_2_MEMREF_ABI(float) +); + +TEST(AotCompiled, ExecuteTest) { + + cnpy::npz_t reference_tensors = cnpy::npz_load( + "//##REFERENCE_TENSORS_PATH##//" + // "test/AotCompile/_internal_add_mul_single_output_reference_tensors.npz" + ); + + //##READ_REFERENCE_TENSORS_INTO_NPY_ARRAY##// + // cnpy::NpyArray refInput0 = reference_tensors["Input0"]; + // cnpy::NpyArray refInput1 = reference_tensors["Input1"]; + // cnpy::NpyArray refInput2 = reference_tensors["Input2"]; + // cnpy::NpyArray refOutput0 = reference_tensors["Output0"]; + + //##CREATE_MEMREF_FROM_NPY_ARRAY##// + // StridedMemRefType Input0 = + // CreateMemRefFromNpyArray(refInput0); + // StridedMemRefType Input1 = + // CreateMemRefFromNpyArray(refInput1); + // StridedMemRefType Input2 = + // CreateMemRefFromNpyArray(refInput2); + + OutputMemRefDescriptor Result = func_main( + //##PASS_INPUT_MEMREF_ARGUMENTS##// + // PASS_RANK_2_MEMREF(Input0), + // PASS_RANK_2_MEMREF(Input1), + // PASS_RANK_2_MEMREF(Input2) + ); + + //##ASSERT_RESULT_SHAPE_MATCHES_REFERENCE##// + // ASSERT_EQ(Result.Output0.sizes[0], refOutput0.shape[0]); + // ASSERT_EQ(Result.Output0.sizes[1], refOutput0.shape[1]); + + //##EXPECT_RESULT_DATA_MATCHES_REFERENCE##// + // for (int i = 0; i < refOutput0.num_vals; i++) + // EXPECT_FLOAT_EQ(Result.Output0.data[i], refOutput0.data()[i]); + + //##DEALLOCATE_RESULT_MEMREF##// + // free(Result.Output0.basePtr); +} diff --git a/tools/aot/execute_test_generator.py b/tools/aot/execute_test_generator.py new file mode 100644 index 0000000..7830599 --- /dev/null +++ b/tools/aot/execute_test_generator.py @@ -0,0 +1,138 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import re +import argparse +import numpy as np + +parser = argparse.ArgumentParser( + description="Test generator for AOT compiled programs using a test template" +) +parser.add_argument( + "--test_template_path", + required=True, + help="Path to the execute_test.template.cpp file", +) +parser.add_argument( + "--reference_tensors_path", + required=True, + help="Path to the file containing the reference inputs and outputs (.npz)", +) + +NUMPY_TO_MEMREF_DTYPE_MAP = { + # Add more mappings as needed + "int8": "int8_t", + "uint8": "uint8_t", + "int16": "int16_t", + "uint16": "uint16_t", + "int32": "int32_t", + "uint32": "uint32_t", + "int64": "int64_t", + "uint64": "uint64_t", + "float32": "float", + "float64": "double", +} + +MEMREF_DTYPE_TO_GTEST_ASSERT_MAP = { + "int8_t": "EXPECT_EQ", + "uint8_t": "EXPECT_EQ", + "int16_t": "EXPECT_EQ", + "uint16_t": "EXPECT_EQ", + "int32_t": "EXPECT_EQ", + "uint32_t": "EXPECT_EQ", + "int64_t": "EXPECT_EQ", + "uint64_t": "EXPECT_EQ", + "float": "EXPECT_FLOAT_EQ", + "double": "EXPECT_DOUBLE_EQ", +} + + +def main(): + args = parser.parse_args() + + # Track string substitutions + input_memref_abi_declarations = [] + pass_input_memref_arguments = [] + create_memref_from_npy_array_str = "" + output_memref_variable_declarations_str = "" + assert_result_shape_matches_reference_str = "" + expect_result_data_matches_reference_str = "" + deallocate_result_memref_str = "" + read_reference_tensors_into_npy_array_str = "" + reference_tensors_path_str = args.reference_tensors_path.removeprefix( + "bazel-out/k8-fastbuild/bin/" + ) + + # Interpret function signature (num_args, rank, dtype, num_returns) + # from the saved reference tensors and build string substitutions + reference_tensors = np.load(args.reference_tensors_path) + for key in reference_tensors.keys(): + tensor = reference_tensors[key] + rank = tensor.ndim + dtype = NUMPY_TO_MEMREF_DTYPE_MAP[str(tensor.dtype)] + + if "Input" in key: + input_memref_abi_declarations.append( + f""" + DECL_RANK_{rank}_MEMREF_ABI({dtype})""" + ) + pass_input_memref_arguments.append( + f""" + PASS_RANK_{rank}_MEMREF({key})""" + ) + if rank == 0: + create_memref_from_npy_array_str += f""" + StridedMemRefType<{dtype}, {rank}> {key} = + CreateMemRefFromNpyArray<{dtype}>(ref{key});""" + else: + create_memref_from_npy_array_str += f""" + StridedMemRefType<{dtype}, {rank}> {key} = + CreateMemRefFromNpyArray<{dtype}, {rank}>(ref{key});""" + + if "Output" in key: + output_memref_variable_declarations_str += f""" + StridedMemRefType<{dtype}, {rank}> {key};""" + for n in range(rank): + assert_result_shape_matches_reference_str += f""" + ASSERT_EQ(Result.{key}.sizes[{n}], ref{key}.shape[{n}]);""" + expect_result_data_matches_reference_str += f""" + for (int i = 0; i < ref{key}.num_vals; i++) + {MEMREF_DTYPE_TO_GTEST_ASSERT_MAP[dtype]}(Result.{key}.data[i], ref{key}.data<{dtype}>()[i]);""" + deallocate_result_memref_str += f""" + free(Result.{key}.basePtr);""" + + read_reference_tensors_into_npy_array_str += f""" + cnpy::NpyArray ref{key} = reference_tensors["{key}"];""" + + # Comma separated, except at the end. + input_memref_abi_declarations_str = ",".join(input_memref_abi_declarations) + pass_input_memref_arguments_str = ",".join(pass_input_memref_arguments) + + substitutions = { + r"//##OUTPUT_MEMREF_VARIABLE_DECLARATIONS##//": output_memref_variable_declarations_str, + r"//##INPUT_MEMREF_ABI_DECLARATIONS##//": input_memref_abi_declarations_str, + r"//##REFERENCE_TENSORS_PATH##//": reference_tensors_path_str, + r"//##PASS_INPUT_MEMREF_ARGUMENTS##//": pass_input_memref_arguments_str, + r"//##READ_REFERENCE_TENSORS_INTO_NPY_ARRAY##//": read_reference_tensors_into_npy_array_str, + r"//##CREATE_MEMREF_FROM_NPY_ARRAY##//": create_memref_from_npy_array_str, + r"//##ASSERT_RESULT_SHAPE_MATCHES_REFERENCE##//": assert_result_shape_matches_reference_str, + r"//##EXPECT_RESULT_DATA_MATCHES_REFERENCE##//": expect_result_data_matches_reference_str, + r"//##DEALLOCATE_RESULT_MEMREF##//": deallocate_result_memref_str, + } + + # Open template test + with open(args.test_template_path, "r") as test_template_file: + test_template_source = test_template_file.read() + + # Perform the regex search and replace for each pattern + for pattern, replacement_string in substitutions.items(): + test_template_source = re.sub(pattern, replacement_string, test_template_source) + + # Important: This print is needed to pipe outputs in aot_compile's genrule + print(test_template_source) + + +if __name__ == "__main__": + main() diff --git a/tools/aot/torch_exporter_harness.py b/tools/aot/torch_exporter_harness.py new file mode 100644 index 0000000..c7a7c58 --- /dev/null +++ b/tools/aot/torch_exporter_harness.py @@ -0,0 +1,81 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import argparse +import importlib +import numpy as np + +import torch +from torch_mlir import fx + +from tools.aot.torch_loader_utils import TorchLoaderOutput + +parser = argparse.ArgumentParser( + description="Harness for running user provided torch_loader_lib to export Torch dialect programs" +) +parser.add_argument( + "--torch_loader_path", + required=True, + help="Full python import path (dot separated) to the torch loader function.", +) +parser.add_argument( + "--reference_tensors_path", + required=False, + help="Path to the file to save the reference inputs and outputs to (as .npz)", +) + + +def main(): + args = parser.parse_args() + loader_module, loader_function = args.torch_loader_path.rsplit(".", 1) + m = importlib.import_module(loader_module) + loader_result = getattr(m, loader_function)() + + assert isinstance( + loader_result, TorchLoaderOutput + ), "Please use tools.aot.torch_loader_utils.TorchLoaderOutput as your torch_loader function's return type" + assert isinstance( + loader_result.inputs, list + ), "Please provide List[torch.Tensor] as TorchLoaderOutput.inputs in your torch_loader function" + + # Used by gen_{name}_mlir_torch genrule + if not args.reference_tensors_path: + # torch.export + fx_importer + torch_program = fx.export_and_import( + loader_result.model, + *loader_result.inputs, # unpack list of input tensors + constraints=loader_result.constraints, + func_name=loader_result.func_name, + ) + + # Important: This print is needed to pipe outputs in aot_compile's genrule + print(torch_program) + + # Used by gen_{name}_reference_tensors genrule + else: + # Feed sample inputs to the model to get reference outputs + reference_outputs = loader_result.model(*loader_result.inputs) + + reference_tensors = {} + + # Collect reference inputs + for i, reference_input in enumerate(loader_result.inputs): + reference_tensors[f"Input{i}"] = reference_input.numpy() + + # Collect reference outputs + if isinstance(reference_outputs, tuple): + # output is a tuple of torch.Tensor's + for i, reference_output in enumerate(reference_outputs): + reference_tensors[f"Output{i}"] = reference_output.numpy() + elif isinstance(reference_outputs, torch.Tensor): + # output is a single torch.Tensor + reference_tensors["Output0"] = reference_outputs.numpy() + + # Save reference tensors as numpy archive (.npz) + np.savez(args.reference_tensors_path, **reference_tensors) + + +if __name__ == "__main__": + main() diff --git a/tools/aot/torch_loader_utils.py b/tools/aot/torch_loader_utils.py new file mode 100644 index 0000000..0b790e2 --- /dev/null +++ b/tools/aot/torch_loader_utils.py @@ -0,0 +1,15 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +from typing import List, NamedTuple, Optional + +import torch + + +class TorchLoaderOutput(NamedTuple): + model: torch.nn.Module + inputs: List[torch.Tensor] + constraints: Optional[List[torch.export.dynamic_dim]] = None + func_name: Optional[str] = "func_main" diff --git a/tools/clangd/BUILD b/tools/clangd/BUILD index 2e2c71b..ff64f3d 100644 --- a/tools/clangd/BUILD +++ b/tools/clangd/BUILD @@ -25,8 +25,8 @@ refresh_compile_commands( "//:TcpTypesIncGen", "//:TorchToTcp", "//:tcp-opt", - "//test/AotCompile:aot_compiled_basic_tcp_ops", - "//test/AotCompile:test_aot_compiled_basic_tcp_ops", + "//test/AotCompile/...", "//tools/aot:abi", + "//tools/aot:execute_test_template", ], )