Skip to content

Commit

Permalink
Implement array reshape for CUDA (#47)
Browse files Browse the repository at this point in the history
- Added a common util function for linking usage: `link_to_library_functions()`.
- Newly-added tests.

---------

Co-authored-by: Graham Markall <[email protected]>
  • Loading branch information
dlee992 and gmarkall authored Dec 31, 2024
1 parent 22bb24d commit 7396cf9
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 14 deletions.
46 changes: 32 additions & 14 deletions numba_cuda/numba/cuda/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
'hrcp', 'hrint',
'htrunc', 'hdiv']

reshape_funcs = ['nocopy_empty_reshape', 'numba_attempt_nocopy_reshape']


class _Kernel(serialize.ReduceMixin):
'''
Expand Down Expand Up @@ -117,25 +119,43 @@ def __init__(self, py_func, argtypes, link=None, debug=False,
if not link:
link = []

asm = lib.get_asm_str()

# A kernel needs cooperative launch if grid_sync is being used.
self.cooperative = 'cudaCGGetIntrinsicHandle' in lib.get_asm_str()
self.cooperative = 'cudaCGGetIntrinsicHandle' in asm
# We need to link against cudadevrt if grid sync is being used.
if self.cooperative:
lib.needs_cudadevrt = True

basedir = os.path.dirname(os.path.abspath(__file__))
asm = lib.get_asm_str()
def link_to_library_functions(library_functions, library_path,
prefix=None):
"""
Dynamically links to library functions by searching for their names
in the specified library and linking to the corresponding source
file.
"""
if prefix is not None:
library_functions = [f"{prefix}{fn}" for fn in
library_functions]

res = [fn for fn in cuda_fp16_math_funcs
if (f'__numba_wrapper_{fn}' in asm)]
found_functions = [fn for fn in library_functions
if f'{fn}' in asm]

if res:
# Path to the source containing the foreign function
functions_cu_path = os.path.join(basedir,
'cpp_function_wrappers.cu')
link.append(functions_cu_path)
if found_functions:
basedir = os.path.dirname(os.path.abspath(__file__))
source_file_path = os.path.join(basedir, library_path)
link.append(source_file_path)

link = self.maybe_link_nrt(link, tgt_ctx, asm)
return found_functions

# Link to the helper library functions if needed
link_to_library_functions(reshape_funcs, 'reshape_funcs.cu')
# Link to the CUDA FP16 math library functions if needed
link_to_library_functions(cuda_fp16_math_funcs,
'cpp_function_wrappers.cu',
'__numba_wrapper_')

self.maybe_link_nrt(link, tgt_ctx, asm)

for filepath in link:
lib.add_linking_file(filepath)
Expand All @@ -160,7 +180,7 @@ def __init__(self, py_func, argtypes, link=None, debug=False,

def maybe_link_nrt(self, link, tgt_ctx, asm):
if not tgt_ctx.enable_nrt:
return link
return

all_nrt = "|".join(self.NRT_functions)
pattern = (
Expand All @@ -175,8 +195,6 @@ def maybe_link_nrt(self, link, tgt_ctx, asm):
nrt_path = os.path.join(basedir, 'runtime', 'nrt.cu')
link.append(nrt_path)

return link

@property
def library(self):
return self._codelibrary
Expand Down
151 changes: 151 additions & 0 deletions numba_cuda/numba/cuda/reshape_funcs.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/*
* Handle reshaping of zero-sized array.
* See numba_attempt_nocopy_reshape() below.
*/
#define NPY_MAXDIMS 32

typedef long long int npy_intp;

extern "C" __device__ int
nocopy_empty_reshape(npy_intp nd, const npy_intp *dims, const npy_intp *strides,
npy_intp newnd, const npy_intp *newdims,
npy_intp *newstrides, npy_intp itemsize,
int is_f_order)
{
int i;
/* Just make the strides vaguely reasonable
* (they can have any value in theory).
*/
for (i = 0; i < newnd; i++)
newstrides[i] = itemsize;
return 1; /* reshape successful */
}

/*
* Straight from Numpy's _attempt_nocopy_reshape()
* (np/core/src/multiarray/shape.c).
* Attempt to reshape an array without copying data
*
* This function should correctly handle all reshapes, including
* axes of length 1. Zero strides should work but are untested.
*
* If a copy is needed, returns 0
* If no copy is needed, returns 1 and fills `npy_intp *newstrides`
* with appropriate strides
*/
extern "C" __device__ int
numba_attempt_nocopy_reshape(npy_intp nd, const npy_intp *dims, const npy_intp *strides,
npy_intp newnd, const npy_intp *newdims,
npy_intp *newstrides, npy_intp itemsize,
int is_f_order)
{
int oldnd;
npy_intp olddims[NPY_MAXDIMS];
npy_intp oldstrides[NPY_MAXDIMS];
npy_intp np, op, last_stride;
int oi, oj, ok, ni, nj, nk;

oldnd = 0;
/*
* Remove axes with dimension 1 from the old array. They have no effect
* but would need special cases since their strides do not matter.
*/
for (oi = 0; oi < nd; oi++) {
if (dims[oi]!= 1) {
olddims[oldnd] = dims[oi];
oldstrides[oldnd] = strides[oi];
oldnd++;
}
}

np = 1;
for (ni = 0; ni < newnd; ni++) {
np *= newdims[ni];
}
op = 1;
for (oi = 0; oi < oldnd; oi++) {
op *= olddims[oi];
}
if (np != op) {
/* different total sizes; no hope */
return 0;
}

if (np == 0) {
/* the Numpy code does not handle 0-sized arrays */
return nocopy_empty_reshape(nd, dims, strides,
newnd, newdims, newstrides,
itemsize, is_f_order);
}

/* oi to oj and ni to nj give the axis ranges currently worked with */
oi = 0;
oj = 1;
ni = 0;
nj = 1;
while (ni < newnd && oi < oldnd) {
np = newdims[ni];
op = olddims[oi];

while (np != op) {
if (np < op) {
/* Misses trailing 1s, these are handled later */
np *= newdims[nj++];
} else {
op *= olddims[oj++];
}
}

/* Check whether the original axes can be combined */
for (ok = oi; ok < oj - 1; ok++) {
if (is_f_order) {
if (oldstrides[ok+1] != olddims[ok]*oldstrides[ok]) {
/* not contiguous enough */
return 0;
}
}
else {
/* C order */
if (oldstrides[ok] != olddims[ok+1]*oldstrides[ok+1]) {
/* not contiguous enough */
return 0;
}
}
}

/* Calculate new strides for all axes currently worked with */
if (is_f_order) {
newstrides[ni] = oldstrides[oi];
for (nk = ni + 1; nk < nj; nk++) {
newstrides[nk] = newstrides[nk - 1]*newdims[nk - 1];
}
}
else {
/* C order */
newstrides[nj - 1] = oldstrides[oj - 1];
for (nk = nj - 1; nk > ni; nk--) {
newstrides[nk - 1] = newstrides[nk]*newdims[nk];
}
}
ni = nj++;
oi = oj++;
}

/*
* Set strides corresponding to trailing 1s of the new shape.
*/
if (ni >= 1) {
last_stride = newstrides[ni - 1];
}
else {
last_stride = itemsize;
}
if (is_f_order) {
last_stride *= newdims[ni - 1];
}
for (nk = ni; nk < newnd; nk++) {
newstrides[nk] = last_stride;
}

return 1;
}
73 changes: 73 additions & 0 deletions numba_cuda/numba/cuda/tests/cudapy/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,31 @@
cuda.pinned_array_like)


def array_reshape1d(arr, newshape, got):
y = arr.reshape(newshape)
for i in range(y.shape[0]):
got[i] = y[i]


def array_reshape2d(arr, newshape, got):
y = arr.reshape(newshape)
for i in range(y.shape[0]):
for j in range(y.shape[1]):
got[i, j] = y[i, j]


def array_reshape3d(arr, newshape, got):
y = arr.reshape(newshape)
for i in range(y.shape[0]):
for j in range(y.shape[1]):
for k in range(y.shape[2]):
got[i, j, k] = y[i, j, k]


def array_reshape(arr, newshape):
return arr.reshape(newshape)


class TestCudaArray(CUDATestCase):
def test_gpu_array_zero_length(self):
x = np.arange(0)
Expand Down Expand Up @@ -255,6 +280,54 @@ def func(A, out):

self.assertEqual(1, len(func.overloads))

def test_array_reshape(self):
def check(pyfunc, kernelfunc, arr, shape):
kernel = cuda.jit(kernelfunc)
expected = pyfunc(arr, shape)
got = np.zeros(expected.shape, dtype=arr.dtype)
kernel[1, 1](arr, shape, got)
self.assertPreciseEqual(got, expected)

def check_only_shape(kernelfunc, arr, shape, expected_shape):
kernel = cuda.jit(kernelfunc)
got = np.zeros(expected_shape, dtype=arr.dtype)
kernel[1, 1](arr, shape, got)
self.assertEqual(got.shape, expected_shape)
self.assertEqual(got.size, arr.size)

# 0-sized arrays
def check_empty(arr):
check(array_reshape, array_reshape1d, arr, 0)
check(array_reshape, array_reshape1d, arr, (0,))
check(array_reshape, array_reshape3d, arr, (1, 0, 2))
check_only_shape(array_reshape2d, arr, (0, -1), (0, 0))
check_only_shape(array_reshape2d, arr, (4, -1), (4, 0))
check_only_shape(array_reshape3d, arr, (-1, 0, 4), (0, 0, 4))

# C-contiguous
arr = np.arange(24)
check(array_reshape, array_reshape1d, arr, (24,))
check(array_reshape, array_reshape2d, arr, (4, 6))
check(array_reshape, array_reshape2d, arr, (8, 3))
check(array_reshape, array_reshape3d, arr, (8, 1, 3))

arr = np.arange(24).reshape((1, 8, 1, 1, 3, 1))
check(array_reshape, array_reshape1d, arr, (24,))
check(array_reshape, array_reshape2d, arr, (4, 6))
check(array_reshape, array_reshape2d, arr, (8, 3))
check(array_reshape, array_reshape3d, arr, (8, 1, 3))

# Test negative shape value
arr = np.arange(25).reshape(5,5)
check(array_reshape, array_reshape1d, arr, -1)
check(array_reshape, array_reshape1d, arr, (-1,))
check(array_reshape, array_reshape2d, arr, (-1, 5))
check(array_reshape, array_reshape3d, arr, (5, -1, 5))
check(array_reshape, array_reshape3d, arr, (5, 5, -1))

arr = np.array([])
check_empty(arr)


if __name__ == '__main__':
unittest.main()

0 comments on commit 7396cf9

Please sign in to comment.