-
Notifications
You must be signed in to change notification settings - Fork 88
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a self-contained integration example of python calling MatX (python → C++ → python → C++ calling chain)
- Loading branch information
1 parent
db87e08
commit b044cee
Showing
5 changed files
with
398 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# This is a cmake project showing how to build a python importable library
# using pybind11, how to pass tensors between MatX and python, and
# how to call MatX operators from python

cmake_minimum_required(VERSION 3.26)

# Default the build type when the user did not choose one (single-config generators).
if(NOT DEFINED CMAKE_BUILD_TYPE)
  message(WARNING "CMAKE_BUILD_TYPE not defined. Defaulting to release.")
  set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type: Debug;Release;MinSizeRel;RelWithDebInfo")
endif()

# Default the CUDA target architecture when the user did not choose one.
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
  message(WARNING "CMAKE_CUDA_ARCHITECTURES not defined. Defaulting to 70")
  set(CMAKE_CUDA_ARCHITECTURES 70 CACHE STRING "Select compile target CUDA Compute Capabilities")
endif()

# OFF uses the local MatX checkout two directories up; ON fetches MatX from GitHub.
if(NOT DEFINED MATX_FETCH_REMOTE)
  message(WARNING "MATX_FETCH_REMOTE not defined. Defaulting to OFF, will use local MatX repo")
  set(MATX_FETCH_REMOTE OFF CACHE BOOL "Set MatX repo fetch location")
endif()

project(SAMPLE_MATX_PYTHON LANGUAGES CUDA CXX)
find_package(CUDAToolkit 12.6 REQUIRED)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Must enable pybind11 support
set(MATX_EN_PYBIND11 ON)

# Use this section if you want to configure other MatX options
#set(MATX_EN_VISUALIZATION ON) # Uncomment to enable visualizations
#set(MATX_EN_FILEIO ON)        # Uncomment to enable file IO

# Skip recursive MatX fetch when this directory is built as part of the MatX
# examples tree (MatX is already available there).
if(NOT MATX_BUILD_EXAMPLES)
  include(FetchContent)
  if(MATX_FETCH_REMOTE)
    FetchContent_Declare(
      MatX
      GIT_REPOSITORY https://github.com/NVIDIA/MatX.git
      GIT_TAG main
    )
  else()
    FetchContent_Declare(
      MatX
      SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../
    )
  endif()
  FetchContent_MakeAvailable(MatX)
endif()

# Build the python-importable module. Drop the "lib" prefix and force a ".so"
# suffix so the CPython interpreter can `import matxutil` directly.
add_library(matxutil MODULE matxutil.cu)
target_link_libraries(matxutil PRIVATE matx::matx CUDA::cudart)
set_target_properties(matxutil PROPERTIES SUFFIX ".so" PREFIX "")

# Copy the example python sources next to the built module so the example can
# be run from the build directory.
configure_file(
  ${CMAKE_CURRENT_SOURCE_DIR}/mypythonlib.py
  ${CMAKE_BINARY_DIR}
  COPYONLY
)

configure_file(
  ${CMAKE_CURRENT_SOURCE_DIR}/example_matxutil.py
  ${CMAKE_BINARY_DIR}
  COPYONLY
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import cupy as cp
import sys

# Add path . if we built as a stand-alone project
sys.path.append('.')

# Add path examples/python_integration_sample/ if we built as part of MatX examples
sys.path.append('examples/python_integration_sample/')

# pybind11 module built by the accompanying CMake project (matxutil.so)
import matxutil
# Demonstrate dlpack consumption invalidates it for future use
def dlp_usage_error():
    """Show that consuming a DLPack capsule invalidates it.

    Returns the (now consumed, invalid) capsule so the caller can verify
    its status via matxutil.check_dlpack_status.
    """
    a = cp.empty((3,3), dtype=cp.float32)
    dlp = a.toDlpack()
    # Freshly exported capsule is valid (status 0)
    assert(matxutil.check_dlpack_status(dlp) == 0)
    a2 = cp.from_dlpack(dlp) # causes dlp to become unused
    # After consumption the capsule is no longer usable (non-zero status)
    assert(matxutil.check_dlpack_status(dlp) != 0)
    return dlp
# Demonstrate cupy array stays in scope when returning valid dlp
def scope_okay():
    """Export a 3x3 cupy array as a DLPack capsule and return it.

    The capsule keeps the underlying cupy allocation alive after this
    function returns, so the returned capsule remains valid.
    """
    a = cp.empty((3,3), dtype=cp.float32)
    a[1,1] = 2
    dlp = a.toDlpack()
    # Unconsumed capsule must report valid status
    assert(matxutil.check_dlpack_status(dlp) == 0)
    return dlp
#Do all cupy work using the "with stream" context manager
stream = cp.cuda.stream.Stream(non_blocking=True)
with stream:
    print("Demonstrate dlpack consumption invalidates it for future use:")
    dlp = dlp_usage_error()
    assert(matxutil.check_dlpack_status(dlp) != 0)
    print(f"  dlp capsule name is: {matxutil.get_capsule_name(dlp)}")
    print()

    print("Demonstrate cupy array stays in scope when returning valid dlpack:")
    dlp = scope_okay()
    assert(matxutil.check_dlpack_status(dlp) == 0)
    print(f"  dlp capsule name is: {matxutil.get_capsule_name(dlp)}")
    print()

    print("Print info about the dlpack:")
    matxutil.print_dlpack_info(dlp)
    print()

    print("Use MatX to print the tensor:")
    matxutil.print_float_2D(dlp)
    print()

    print("Print current memory usage info:")
    gpu_mempool = cp.get_default_memory_pool()
    pinned_mempool = cp.get_default_pinned_memory_pool()
    print(f"  GPU mempool used bytes {gpu_mempool.used_bytes()}")
    print(f"  Pinned mempool n_free_blocks {pinned_mempool.n_free_blocks()}")
    print()

    print("Demonstrate python to C++ to python to C++ calling chain (uses mypythonlib.py):")
    # This function calls back into python and executes a from_dlpack, consuming the dlp
    matxutil.call_python_example(dlp)
    assert(matxutil.check_dlpack_status(dlp) != 0)
    del dlp

    print("Demonstrate adding two tensors together using MatX:")
    a = cp.array([[1,2,3],[4,5,6],[7,8,9]], dtype=cp.float32)
    b = cp.array([[1,2,3],[4,5,6],[7,8,9]], dtype=cp.float32)
    c = cp.empty(b.shape, dtype=b.dtype)

    # Export all three tensors as DLPack capsules and let MatX compute c = a + b
    # on the same cupy stream, then synchronize before reading the result.
    c_dlp = c.toDlpack()
    a_dlp = a.toDlpack()
    b_dlp = b.toDlpack()
    matxutil.add_float_2D(c_dlp, a_dlp, b_dlp, stream.ptr)
    stream.synchronize()
    print(f"Tensor a {a}")
    print(f"Tensor b {b}")
    print(f"Tensor c=a+b {c}")
Oops, something went wrong.