-
Notifications
You must be signed in to change notification settings - Fork 88
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a self contained python calling MatX (calling python calling MatX…
…) integration example
- Loading branch information
1 parent
a63cd2f
commit 12724da
Showing
4 changed files
with
330 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# This is a cmake project showing how to build a python importable library | ||
# using pybind11, how to pass tensors between MatX and python, and | ||
# how to call MatX operators from python | ||
|
||
cmake_minimum_required(VERSION 3.26) | ||
|
||
if(NOT DEFINED CMAKE_BUILD_TYPE) | ||
message(WARNING "CMAKE_BUILD_TYPE not defined. Defaulting to release.") | ||
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type: Debug;Release;MinSizeRel;RelWithDebInfo") | ||
endif() | ||
|
||
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) | ||
message(WARNING "CMAKE_CUDA_ARCHITECTURES not defined. Defaulting to 70") | ||
set(CMAKE_CUDA_ARCHITECTURES 70 CACHE STRING "Select compile target CUDA Compute Capabilities") | ||
endif() | ||
|
||
if(NOT DEFINED MATX_FETCH_REMOTE) | ||
message(WARNING "MATX_FETCH_REMOTE not defined. Defaulting to OFF, will use local MatX repo") | ||
set(MATX_FETCH_REMOTE OFF CACHE BOOL "Set MatX repo fetch location") | ||
endif() | ||
|
||
project(SAMPLE_MATX_PYTHON LANGUAGES CUDA CXX) | ||
find_package(CUDAToolkit 12.6 REQUIRED) | ||
|
||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON) | ||
|
||
# Must enable pybind11 support | ||
set(MATX_EN_PYBIND11 ON) | ||
|
||
# Use this section if you want to configure other MatX options | ||
#set(MATX_EN_VISUALIZATION ON) # Uncomment to enable visualizations | ||
#set(MATX_EN_FILEIO ON) # Uncomment to file IO | ||
|
||
if(MATX_FETCH_REMOTE) | ||
include(FetchContent) | ||
FetchContent_Declare( | ||
MatX | ||
GIT_REPOSITORY https://github.com/NVIDIA/MatX.git | ||
GIT_TAG main | ||
) | ||
else() | ||
include(FetchContent) | ||
FetchContent_Declare( | ||
MatX | ||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../ | ||
) | ||
endif() | ||
FetchContent_MakeAvailable(MatX) | ||
|
||
add_library(matxutil MODULE matxutil.cu) | ||
target_link_libraries(matxutil PRIVATE matx::matx) | ||
set_target_properties(matxutil PROPERTIES SUFFIX ".so" PREFIX "") | ||
|
||
configure_file( | ||
${CMAKE_CURRENT_SOURCE_DIR}/mypythonlib.py | ||
${CMAKE_BINARY_DIR} | ||
COPYONLY | ||
) | ||
|
||
configure_file( | ||
${CMAKE_CURRENT_SOURCE_DIR}/example_matxutil.py | ||
${CMAKE_BINARY_DIR} | ||
COPYONLY | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import cupy as cp | ||
import sys | ||
sys.path.append('.') | ||
|
||
import matxutil | ||
|
||
# Demonstrate dlpack consumption invalidates it for future use | ||
def dlp_usage_error(): | ||
a = cp.empty((3,3), dtype=cp.float32) | ||
dlp = a.toDlpack() | ||
assert(matxutil.check_dlpack_status(dlp) == 0) | ||
a2 = cp.from_dlpack(dlp) # causes dlp to become unused | ||
assert(matxutil.check_dlpack_status(dlp) != 0) | ||
return dlp | ||
|
||
# Demonstrate cupy array stays in scope when returning valid dlp | ||
def scope_okay(): | ||
a = cp.empty((3,3), dtype=cp.float32) | ||
a[1,1] = 2 | ||
dlp = a.toDlpack() | ||
assert(matxutil.check_dlpack_status(dlp) == 0) | ||
return dlp | ||
|
||
print("Demonstrate dlpack consumption invalidates it for future use:") | ||
dlp = dlp_usage_error() | ||
assert(matxutil.check_dlpack_status(dlp) != 0) | ||
print(f" dlp capsule name is: {matxutil.get_capsule_name(dlp)}") | ||
print() | ||
|
||
print("Demonstrate cupy array stays in scope when returning valid dlpack:") | ||
dlp = scope_okay() | ||
assert(matxutil.check_dlpack_status(dlp) == 0) | ||
print(f" dlp capsule name is: {matxutil.get_capsule_name(dlp)}") | ||
print() | ||
|
||
print("Print info about the dlpack:") | ||
matxutil.print_dlpack_info(dlp) | ||
print() | ||
|
||
print("Use MatX to print the tensor:") | ||
matxutil.print_float_2D(dlp) | ||
print() | ||
|
||
print("Print current memory usage info:") | ||
gpu_mempool = cp.get_default_memory_pool() | ||
pinned_mempool = cp.get_default_pinned_memory_pool() | ||
print(f" GPU mempool used bytes {gpu_mempool.used_bytes()}") | ||
print(f" Pinned mempool n_free_blocks {pinned_mempool.n_free_blocks()}") | ||
print() | ||
|
||
print("Demonstrate python to C++ to python to C++ calling chain (uses mypythonlib.py):") | ||
# This function calls back into python and executes a from_dlpack, consuming the dlp | ||
matxutil.call_python_example(dlp) | ||
assert(matxutil.check_dlpack_status(dlp) != 0) | ||
|
||
# Other things to try | ||
# (Done) Check dltensor still valid in C++ before use | ||
# (Done) Assign dltensor to MatX tensor, how to determine number of dimensions | ||
# Pass stream from cp to C++ | ||
# (Done) Negative case where pointer goes out of scope before C++ finishes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
//////////////////////////////////////////////////////////////////////////////// | ||
// BSD 3-Clause License | ||
// | ||
// Copyright (c) 2024, NVIDIA Corporation | ||
// All rights reserved. | ||
// | ||
// Redistribution and use in source and binary forms, with or without | ||
// modification, are permitted provided that the following conditions are met: | ||
// | ||
// 1. Redistributions of source code must retain the above copyright notice, this | ||
// list of conditions and the following disclaimer. | ||
// | ||
// 2. Redistributions in binary form must reproduce the above copyright notice, | ||
// this list of conditions and the following disclaimer in the documentation | ||
// and/or other materials provided with the distribution. | ||
// | ||
// 3. Neither the name of the copyright holder nor the names of its | ||
// contributors may be used to endorse or promote products derived from | ||
// this software without specific prior written permission. | ||
// | ||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | ||
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | ||
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | ||
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | ||
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | ||
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
///////////////////////////////////////////////////////////////////////////////// | ||
|
||
#include <pybind11/pybind11.h> | ||
#include <pybind11/numpy.h> | ||
#include <iostream> | ||
#include <stdio.h> | ||
#include <matx.h> | ||
#include <matx/core/dlpack.h> | ||
|
||
namespace py = pybind11; | ||
|
||
const char* get_capsule_name(py::capsule capsule) | ||
{ | ||
return capsule.name(); | ||
} | ||
|
||
typedef DLManagedTensor* PTR_DLManagedTensor; | ||
int attempt_unpack_dlpack(py::capsule dlpack_capsule, PTR_DLManagedTensor& p_dlpack) | ||
{ | ||
if (p_dlpack == nullptr) | ||
{ | ||
return -1; | ||
} | ||
|
||
const char* capsule_name = dlpack_capsule.name(); | ||
|
||
if (strncmp(capsule_name,"dltensor",8) != 0) | ||
{ | ||
return -2; | ||
} | ||
|
||
p_dlpack = static_cast<PTR_DLManagedTensor>(dlpack_capsule.get_pointer()); | ||
|
||
if (p_dlpack == nullptr) { | ||
return -3; | ||
} | ||
|
||
return 0; | ||
} | ||
|
||
int check_dlpack_status(py::capsule dlpack_capsule) | ||
{ | ||
PTR_DLManagedTensor unused; | ||
return attempt_unpack_dlpack(dlpack_capsule, unused); | ||
} | ||
|
||
const char* dlpack_device_type_to_string(DLDeviceType device_type) | ||
{ | ||
switch(device_type) | ||
{ | ||
case kDLCPU: return "kDLCPU"; | ||
case kDLCUDA: return "kDLCUDA"; | ||
case kDLCUDAHost: return "kDLCUDAHost"; | ||
case kDLOpenCL: return "kDLOpenCL"; | ||
case kDLVulkan: return "kDLVulkan"; | ||
case kDLMetal: return "kDLMetal"; | ||
case kDLVPI: return "kDLVPI"; | ||
case kDLROCM: return "kDLROCM"; | ||
case kDLROCMHost: return "kDLROCMHost"; | ||
case kDLExtDev: return "kDLExtDev"; | ||
case kDLCUDAManaged: return "kDLCUDAManaged"; | ||
case kDLOneAPI: return "kDLOneAPI"; | ||
case kDLWebGPU: return "kDLWebGPU"; | ||
case kDLHexagon: return "kDLHexagon"; | ||
default: return "Unknown DLDeviceType"; | ||
} | ||
} | ||
|
||
const char* dlpack_code_to_string(uint8_t code) | ||
{ | ||
switch(code) | ||
{ | ||
case kDLInt: return "kDLInt"; | ||
case kDLUInt: return "kDLUInt"; | ||
case kDLFloat: return "kDLFloat"; | ||
case kDLOpaqueHandle: return "kDLOpaqueHandle"; | ||
case kDLBfloat: return "kDLBfloat"; | ||
case kDLComplex: return "kDLComplex"; | ||
case kDLBool: return "kDLBool"; | ||
default: return "Unknown DLDataTypeCode"; | ||
} | ||
} | ||
|
||
void print_dlpack_info(py::capsule dlpack_capsule) { | ||
PTR_DLManagedTensor p_tensor; | ||
if (attempt_unpack_dlpack(dlpack_capsule, p_tensor)) | ||
{ | ||
fprintf(stderr,"Error: capsule not valid dlpack"); | ||
return; | ||
} | ||
|
||
printf(" data: %p\n",p_tensor->dl_tensor.data); | ||
printf(" device: device_type %s, device_id %d\n", | ||
dlpack_device_type_to_string(p_tensor->dl_tensor.device.device_type), | ||
p_tensor->dl_tensor.device.device_id | ||
); | ||
printf(" ndim: %d\n",p_tensor->dl_tensor.ndim); | ||
printf(" dtype: code %s, bits %u, lanes %u\n", | ||
dlpack_code_to_string(p_tensor->dl_tensor.dtype.code), | ||
p_tensor->dl_tensor.dtype.bits, | ||
p_tensor->dl_tensor.dtype.lanes | ||
); | ||
printf(" shape: "); | ||
for (int k=0; k<p_tensor->dl_tensor.ndim; k++) | ||
{ | ||
printf("%ld, ",p_tensor->dl_tensor.shape[k]); | ||
} | ||
printf("\n"); | ||
printf(" strides: "); | ||
for (int k=0; k<p_tensor->dl_tensor.ndim; k++) | ||
{ | ||
printf("%ld, ",p_tensor->dl_tensor.strides[k]); | ||
} | ||
printf("\n"); | ||
printf(" byte_offset: %lu\n",p_tensor->dl_tensor.byte_offset); | ||
} | ||
|
||
template<typename T, int RANK> | ||
void print(py::capsule dlpack_capsule) | ||
{ | ||
PTR_DLManagedTensor p_tensor; | ||
if (attempt_unpack_dlpack(dlpack_capsule, p_tensor)) | ||
{ | ||
fprintf(stderr,"Error: capsule not valid dlpack"); | ||
return; | ||
} | ||
|
||
matx::tensor_t<T, RANK> a; | ||
matx::make_tensor(a, *p_tensor); | ||
matx::print(a); | ||
} | ||
|
||
void call_python_example(py::capsule dlpack_capsule) | ||
{ | ||
PTR_DLManagedTensor p_tensor; | ||
if (attempt_unpack_dlpack(dlpack_capsule, p_tensor)) | ||
{ | ||
fprintf(stderr,"Error: capsule not valid dlpack"); | ||
return; | ||
} | ||
|
||
matx::tensor_t<float, 2> a; | ||
matx::make_tensor(a, *p_tensor); | ||
|
||
auto pb = matx::detail::MatXPybind{}; | ||
|
||
// Example use of python's print | ||
pybind11::print(" Example use of python's print function from C++: ", 1, 2.0, "three"); | ||
pybind11::print(" The dlpack_capsule is a ", dlpack_capsule); | ||
|
||
auto mypythonlib = pybind11::module_::import("mypythonlib"); | ||
mypythonlib.attr("my_func")(dlpack_capsule); | ||
} | ||
|
||
PYBIND11_MODULE(matxutil, m) { | ||
m.def("get_capsule_name", &get_capsule_name, "Returns PyCapsule name"); | ||
m.def("print_dlpack_info", &print_dlpack_info, "Print the DLPack tensor metadata"); | ||
m.def("check_dlpack_status", &check_dlpack_status, "Returns 0 if DLPack is valid, negative error code otherwise"); | ||
m.def("print_float_2D", &print<float,2>, "Prints a float32 2D tensor"); | ||
m.def("call_python_example", &call_python_example, "Example C++ function that calls python code"); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import cupy as cp | ||
import sys | ||
sys.path.append('.') | ||
import matxutil | ||
|
||
def my_func(dlp): | ||
print(f" type(dlp) before cp.from_dlpack(): {type(dlp)}") | ||
print(f" dlp capsule name is: {matxutil.get_capsule_name(dlp)}") | ||
a = cp.from_dlpack(dlp) | ||
print(f" type(dlp) after cp.from_dlpack(): {type(dlp)}") | ||
print(f" dlp capsule name is: {matxutil.get_capsule_name(dlp)}") | ||
print(f" type(cp.from_dlPack(dlp)): {type(a)}") | ||
print() | ||
print("Finally, print the tensor we received from MatX using python:") | ||
print(a) |