Skip to content

Commit 7928f2d

Browse files
dlee992gmarkall
authored andcommitted
copy from numba PR #8458
1 parent 9ed01c5 commit 7928f2d

File tree

3 files changed

+254
-9
lines changed

3 files changed

+254
-9
lines changed

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
'hrcp', 'hrint',
3737
'htrunc', 'hdiv']
3838

39+
reshape_funcs = ['nocopy_empty_reshape', 'numba_attempt_nocopy_reshape']
40+
3941

4042
class _Kernel(serialize.ReduceMixin):
4143
'''
@@ -105,15 +107,33 @@ def __init__(self, py_func, argtypes, link=None, debug=False,
105107
if self.cooperative:
106108
lib.needs_cudadevrt = True
107109

108-
res = [fn for fn in cuda_fp16_math_funcs
109-
if (f'__numba_wrapper_{fn}' in lib.get_asm_str())]
110-
111-
if res:
112-
# Path to the source containing the foreign function
113-
basedir = os.path.dirname(os.path.abspath(__file__))
114-
functions_cu_path = os.path.join(basedir,
115-
'cpp_function_wrappers.cu')
116-
link.append(functions_cu_path)
110+
def link_to_library_functions(library_functions, library_path,
111+
prefix=None):
112+
"""
113+
Dynamically links to library functions by searching for their names
114+
in the specified library and linking to the corresponding source
115+
file.
116+
"""
117+
if prefix is not None:
118+
library_functions = [f"{prefix}{fn}" for fn in
119+
library_functions]
120+
121+
found_functions = [fn for fn in library_functions
122+
if f'{fn}' in lib.get_asm_str()]
123+
124+
if found_functions:
125+
basedir = os.path.dirname(os.path.abspath(__file__))
126+
source_file_path = os.path.join(basedir, library_path)
127+
link.append(source_file_path)
128+
129+
return found_functions
130+
131+
# Link to the helper library functions if needed
132+
link_to_library_functions(reshape_funcs, 'reshape_funcs.cu')
133+
# Link to the CUDA FP16 math library functions if needed
134+
link_to_library_functions(cuda_fp16_math_funcs,
135+
'cpp_function_wrappers.cu',
136+
'__numba_wrapper_')
117137

118138
for filepath in link:
119139
lib.add_linking_file(filepath)
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
/*
2+
* Handle reshaping of zero-sized array.
3+
* See numba_attempt_nocopy_reshape() below.
4+
*/
5+
#define NPY_MAXDIMS 32
6+
7+
typedef long int npy_intp;
8+
9+
extern "C" __device__ int
10+
nocopy_empty_reshape(npy_intp nd, const npy_intp *dims, const npy_intp *strides,
11+
npy_intp newnd, const npy_intp *newdims,
12+
npy_intp *newstrides, npy_intp itemsize,
13+
int is_f_order)
14+
{
15+
int i;
16+
/* Just make the strides vaguely reasonable
17+
* (they can have any value in theory).
18+
*/
19+
for (i = 0; i < newnd; i++)
20+
newstrides[i] = itemsize;
21+
return 1; /* reshape successful */
22+
}
23+
24+
/*
25+
* Straight from Numpy's _attempt_nocopy_reshape()
26+
* (np/core/src/multiarray/shape.c).
27+
* Attempt to reshape an array without copying data
28+
*
29+
* This function should correctly handle all reshapes, including
30+
* axes of length 1. Zero strides should work but are untested.
31+
*
32+
* If a copy is needed, returns 0
33+
* If no copy is needed, returns 1 and fills `npy_intp *newstrides`
34+
* with appropriate strides
35+
*/
36+
extern "C" __device__ int
37+
numba_attempt_nocopy_reshape(npy_intp nd, const npy_intp *dims, const npy_intp *strides,
38+
npy_intp newnd, const npy_intp *newdims,
39+
npy_intp *newstrides, npy_intp itemsize,
40+
int is_f_order)
41+
{
42+
int oldnd;
43+
npy_intp olddims[NPY_MAXDIMS];
44+
npy_intp oldstrides[NPY_MAXDIMS];
45+
npy_intp np, op, last_stride;
46+
int oi, oj, ok, ni, nj, nk;
47+
48+
oldnd = 0;
49+
/*
50+
* Remove axes with dimension 1 from the old array. They have no effect
51+
* but would need special cases since their strides do not matter.
52+
*/
53+
for (oi = 0; oi < nd; oi++) {
54+
if (dims[oi]!= 1) {
55+
olddims[oldnd] = dims[oi];
56+
oldstrides[oldnd] = strides[oi];
57+
oldnd++;
58+
}
59+
}
60+
61+
np = 1;
62+
for (ni = 0; ni < newnd; ni++) {
63+
np *= newdims[ni];
64+
}
65+
op = 1;
66+
for (oi = 0; oi < oldnd; oi++) {
67+
op *= olddims[oi];
68+
}
69+
if (np != op) {
70+
/* different total sizes; no hope */
71+
return 0;
72+
}
73+
74+
if (np == 0) {
75+
/* the Numpy code does not handle 0-sized arrays */
76+
return nocopy_empty_reshape(nd, dims, strides,
77+
newnd, newdims, newstrides,
78+
itemsize, is_f_order);
79+
}
80+
81+
/* oi to oj and ni to nj give the axis ranges currently worked with */
82+
oi = 0;
83+
oj = 1;
84+
ni = 0;
85+
nj = 1;
86+
while (ni < newnd && oi < oldnd) {
87+
np = newdims[ni];
88+
op = olddims[oi];
89+
90+
while (np != op) {
91+
if (np < op) {
92+
/* Misses trailing 1s, these are handled later */
93+
np *= newdims[nj++];
94+
} else {
95+
op *= olddims[oj++];
96+
}
97+
}
98+
99+
/* Check whether the original axes can be combined */
100+
for (ok = oi; ok < oj - 1; ok++) {
101+
if (is_f_order) {
102+
if (oldstrides[ok+1] != olddims[ok]*oldstrides[ok]) {
103+
/* not contiguous enough */
104+
return 0;
105+
}
106+
}
107+
else {
108+
/* C order */
109+
if (oldstrides[ok] != olddims[ok+1]*oldstrides[ok+1]) {
110+
/* not contiguous enough */
111+
return 0;
112+
}
113+
}
114+
}
115+
116+
/* Calculate new strides for all axes currently worked with */
117+
if (is_f_order) {
118+
newstrides[ni] = oldstrides[oi];
119+
for (nk = ni + 1; nk < nj; nk++) {
120+
newstrides[nk] = newstrides[nk - 1]*newdims[nk - 1];
121+
}
122+
}
123+
else {
124+
/* C order */
125+
newstrides[nj - 1] = oldstrides[oj - 1];
126+
for (nk = nj - 1; nk > ni; nk--) {
127+
newstrides[nk - 1] = newstrides[nk]*newdims[nk];
128+
}
129+
}
130+
ni = nj++;
131+
oi = oj++;
132+
}
133+
134+
/*
135+
* Set strides corresponding to trailing 1s of the new shape.
136+
*/
137+
if (ni >= 1) {
138+
last_stride = newstrides[ni - 1];
139+
}
140+
else {
141+
last_stride = itemsize;
142+
}
143+
if (is_f_order) {
144+
last_stride *= newdims[ni - 1];
145+
}
146+
for (nk = ni; nk < newnd; nk++) {
147+
newstrides[nk] = last_stride;
148+
}
149+
150+
return 1;
151+
}

numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,31 @@
99
from unittest.mock import call, patch
1010

1111

12+
def array_reshape1d(arr, newshape, got):
13+
y = arr.reshape(newshape)
14+
for i in range(y.shape[0]):
15+
got[i] = y[i]
16+
17+
18+
def array_reshape2d(arr, newshape, got):
19+
y = arr.reshape(newshape)
20+
for i in range(y.shape[0]):
21+
for j in range(y.shape[1]):
22+
got[i, j] = y[i, j]
23+
24+
25+
def array_reshape3d(arr, newshape, got):
26+
y = arr.reshape(newshape)
27+
for i in range(y.shape[0]):
28+
for j in range(y.shape[1]):
29+
for k in range(y.shape[2]):
30+
got[i, j, k] = y[i, j, k]
31+
32+
33+
def array_reshape(arr, newshape):
34+
return arr.reshape(newshape)
35+
36+
1237
@skip_on_cudasim('CUDA Array Interface is not supported in the simulator')
1338
class TestCudaArrayInterface(ContextResettingTestCase):
1439
def assertPointersEqual(self, a, b):
@@ -430,6 +455,55 @@ def f(x, y):
430455
# Ensure that synchronize was not called
431456
mock_sync.assert_not_called()
432457

458+
# @skip_unless_cuda_python('NVIDIA Binding needed for NVRTC')
459+
def test_array_reshape(self):
460+
def check(pyfunc, kernelfunc, arr, shape):
461+
kernel = cuda.jit(kernelfunc)
462+
expected = pyfunc(arr, shape)
463+
got = np.zeros(expected.shape, dtype=arr.dtype)
464+
kernel[1, 1](arr, shape, got)
465+
self.assertPreciseEqual(got, expected)
466+
467+
def check_only_shape(kernelfunc, arr, shape, expected_shape):
468+
kernel = cuda.jit(kernelfunc)
469+
got = np.zeros(expected_shape, dtype=arr.dtype)
470+
kernel[1, 1](arr, shape, got)
471+
self.assertEqual(got.shape, expected_shape)
472+
self.assertEqual(got.size, arr.size)
473+
474+
# 0-sized arrays
475+
def check_empty(arr):
476+
check(array_reshape, array_reshape1d, arr, 0)
477+
check(array_reshape, array_reshape1d, arr, (0,))
478+
check(array_reshape, array_reshape3d, arr, (1, 0, 2))
479+
check_only_shape(array_reshape2d, arr, (0, -1), (0, 0))
480+
check_only_shape(array_reshape2d, arr, (4, -1), (4, 0))
481+
check_only_shape(array_reshape3d, arr, (-1, 0, 4), (0, 0, 4))
482+
483+
# C-contiguous
484+
arr = np.arange(24)
485+
check(array_reshape, array_reshape1d, arr, (24,))
486+
check(array_reshape, array_reshape2d, arr, (4, 6))
487+
check(array_reshape, array_reshape2d, arr, (8, 3))
488+
check(array_reshape, array_reshape3d, arr, (8, 1, 3))
489+
490+
arr = np.arange(24).reshape((1, 8, 1, 1, 3, 1))
491+
check(array_reshape, array_reshape1d, arr, (24,))
492+
check(array_reshape, array_reshape2d, arr, (4, 6))
493+
check(array_reshape, array_reshape2d, arr, (8, 3))
494+
check(array_reshape, array_reshape3d, arr, (8, 1, 3))
495+
496+
# Test negative shape value
497+
arr = np.arange(25).reshape(5,5)
498+
check(array_reshape, array_reshape1d, arr, -1)
499+
check(array_reshape, array_reshape1d, arr, (-1,))
500+
check(array_reshape, array_reshape2d, arr, (-1, 5))
501+
check(array_reshape, array_reshape3d, arr, (5, -1, 5))
502+
check(array_reshape, array_reshape3d, arr, (5, 5, -1))
503+
504+
arr = np.array([])
505+
check_empty(arr)
506+
433507

434508
if __name__ == "__main__":
435509
unittest.main()

0 commit comments

Comments
 (0)