
Adjust float type of kernel functions used in propagate methods #70

Draft: wants to merge 2 commits into base: main
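This draft reworks the float-type adjustment in ComponentBase: the signature rewrite is factored out of `_adjust_propagate_type` into a new `_change_kernel_floattype` helper, and `_adjust_propagate_type` now also applies it to every `DeviceFunction` found in `propagate`'s module globals, so helper kernels called from `propagate` are rebuilt with the component's float type as well. A minimal, CPU-only sketch of the argument rewrite itself, assuming only numba is installed (the function name here is illustrative, not part of the patch):

import numba
from numba.core.types import Array, Float

def rewrite_float_args(args, floattype="float32"):
    # Rebuild a numba argument-type list, swapping every float scalar and
    # every float-array dtype for the requested precision; this mirrors the
    # loop inside _change_kernel_floattype, minus the DeviceFunction rebuild.
    newargs = []
    for arg in args:
        if isinstance(arg, Array) and isinstance(arg.dtype, Float):
            newargs.append(arg.copy(dtype=getattr(numba, floattype)))
        elif isinstance(arg, Float):
            newargs.append(Float(name=floattype))
        else:
            newargs.append(arg)
    return newargs

# e.g. (float64, float64[:]) -> (float32, float32[:])
print(rewrite_float_args([numba.float64, numba.float64[:]]))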
69 changes: 49 additions & 20 deletions acc/components/ComponentBase.py
@@ -129,24 +129,30 @@ def register_propagate_method(cls, propagate):
         return new_propagate

     @classmethod
-    def _adjust_propagate_type(cls, propagate):
+    def _change_kernel_floattype(cls, kernel_func):
+        '''
+        Changes the floattype of kernel_func based on the Component floattype
+        kernel_func should be a DeviceFunction object returned by cuda.jit
+        Returns a new DeviceFunction
+        '''
         # disable float switching if in cudasim mode
         if config.ENABLE_CUDASIM:
-            return propagate
+            return kernel_func

-        if not isinstance(propagate, DeviceFunction):
+        if not isinstance(kernel_func, DeviceFunction):
             raise RuntimeError(
-                "invalid propagate function ({}, {}) registered, ".format(
-                    propagate, type(propagate))
-                + "does propagate have a signature defined?")
+                "invalid kernel function ({}, {}), ".format(
+                    kernel_func, type(kernel_func))
+                + "does the function have a signature defined?")

-        args = propagate.args
+        args = kernel_func.args

         # reconstruct the numba args with the correct floattype
         newargs = []
         for arg in args:
             if isinstance(arg, Array) and isinstance(arg.dtype, Float):
-                newargs.append(arg.copy(dtype=getattr(numba, cls._floattype)))
+                newargs.append(
+                    arg.copy(dtype=getattr(numba, cls._floattype)))
             elif isinstance(arg, Float):
                 newargs.append(Float(name=cls._floattype))
             else:
@@ -156,19 +162,42 @@ def _adjust_propagate_type(cls, propagate):

         # DeviceFunction in Numba < 0.54.1 does not have a lineinfo property
         if int(numba.__version__.split(".")[1]) < 54:
-            new_propagate = DeviceFunction(pyfunc=propagate.py_func,
-                                           return_type=propagate.return_type,
-                                           args=newargs,
-                                           inline=propagate.inline,
-                                           debug=propagate.debug)
+            new_func = DeviceFunction(pyfunc=kernel_func.py_func,
+                                      return_type=kernel_func.return_type,
+                                      args=newargs,
+                                      inline=kernel_func.inline,
+                                      debug=kernel_func.debug)
         else:
-            new_propagate = DeviceFunction(pyfunc=propagate.py_func,
-                                           return_type=propagate.return_type,
-                                           args=newargs,
-                                           inline=propagate.inline,
-                                           debug=propagate.debug,
-                                           lineinfo=propagate.lineinfo)
-        #cls.print_kernel_info(new_propagate)
+            new_func = DeviceFunction(pyfunc=kernel_func.py_func,
+                                      return_type=kernel_func.return_type,
+                                      args=newargs,
+                                      inline=kernel_func.inline,
+                                      debug=kernel_func.debug,
+                                      lineinfo=kernel_func.lineinfo)
+
+        return new_func
+
+    @classmethod
+    def _adjust_propagate_type(cls, propagate):
+        # disable float switching if in cudasim mode
+        if config.ENABLE_CUDASIM:
+            return propagate
+
+        if not isinstance(propagate, DeviceFunction):
+            raise RuntimeError(
+                "invalid propagate function ({}, {}) registered, ".format(
+                    propagate, type(propagate))
+                + "does propagate have a signature defined?")
+
+        # adjust float types of any device function that propagate calls
+        for func in propagate.py_func.__globals__:
+            if isinstance(propagate.py_func.__globals__[func], DeviceFunction):
+                propagate.py_func.__globals__[func] = \
+                    cls._change_kernel_floattype(
+                        propagate.py_func.__globals__[func])
+
+        new_propagate = cls._change_kernel_floattype(propagate)

         return new_propagate

     @classmethod
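The new `_adjust_propagate_type` reaches helper kernels by walking `propagate.py_func.__globals__` and rebinding every `DeviceFunction` entry to a retyped copy before retyping `propagate` itself; the global-kernel tests below rely on this, since it is the module-level `global_kernel` and `nested_global_kernel` objects that get replaced. A hypothetical, CUDA-free sketch of that rebinding mechanism (all names here are illustrative, not part of the patch):

def helper(x):
    return x + 1.0

def compute(x):
    # `helper` is looked up in the module globals on every call
    return helper(x) * 2.0

def swap_global(func, name, replacement):
    # Rebind a name in the module globals that `func` sees, the same way
    # _adjust_propagate_type swaps DeviceFunction globals for retyped copies.
    func.__globals__[name] = replacement

print(compute(1.0))                              # 4.0
swap_global(compute, "helper", lambda x: x - 1)  # rebind the helper
print(compute(1.0))                              # 0.0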
121 changes: 120 additions & 1 deletion tests/components/test_ComponentBase.py
@@ -3,12 +3,23 @@
 import pytest

 import numba
-from numba import cuda, void, float32
+from numba import cuda, void, float32, float64
 from numba.core.types import Array, Float
 from mcvine.acc.components.ComponentBase import ComponentBase
 #from mcvine.acc.neutron import absorb
 from mcvine.acc import test


+@cuda.jit(float64(float64), device=True)
+def global_kernel(x):
+    return x
+
+
+@cuda.jit(float64(float64), device=True)
+def nested_global_kernel(x):
+    y = 1.0 + global_kernel(x)
+    return y
+
+
 def test_no_propagate_raises():
     with pytest.raises(TypeError):
         # check creating a component with no propagate method raises error
@@ -80,3 +91,111 @@ def propagate(x, y):
     assert isinstance(args[1], Array)
     assert args[1].dtype == float32

+
+@pytest.mark.skipif(not test.USE_CUDA, reason='No CUDA')
+def test_propagate_global_function_changed():
+    # check that propagate arguments are changed from float64 -> float32
+    NB_FLOAT = getattr(numba, "float64")
+    class Component(ComponentBase):
+        def __init__(self, **kwargs):
+            return
+
+        @cuda.jit(void(NB_FLOAT, NB_FLOAT[:]), device=True)
+        def propagate(x, y):
+            y[0] = global_kernel(x)
+
+    component = Component()
+    Component.change_floattype("float32")
+    assert component.floattype == "float32"
+
+    # check that the class wide attributes are changed
+    assert Component.get_floattype() == "float32"
+    assert Component.process_kernel is not None
+    args = Component.propagate.args
+    assert len(args) == 2
+
+    assert isinstance(args[0], Float)
+    assert args[0].bitwidth == 32
+    assert isinstance(args[1], Array)
+    assert args[1].dtype == float32
+
+    # check that the global kernel function args are changed
+    args = global_kernel.args
+    assert len(args) == 1
+
+    assert isinstance(args[0], Float)
+    assert args[0].bitwidth == 32
+
+
+@pytest.mark.skipif(not test.USE_CUDA, reason='No CUDA')
+def test_propagate_nested_global_function_changed():
+    # check that propagate arguments are changed from float64 -> float32
+    NB_FLOAT = getattr(numba, "float64")
+    class Component(ComponentBase):
+        def __init__(self, **kwargs):
+            return
+
+        @cuda.jit(void(NB_FLOAT, NB_FLOAT[:]), device=True)
+        def propagate(x, y):
+            y[0] = nested_global_kernel(x)
+
+    component = Component()
+    Component.change_floattype("float32")
+    assert component.floattype == "float32"
+
+    # check that the class wide attributes are changed
+    assert Component.get_floattype() == "float32"
+    assert Component.process_kernel is not None
+    args = Component.propagate.args
+    assert len(args) == 2
+
+    assert isinstance(args[0], Float)
+    assert args[0].bitwidth == 32
+    assert isinstance(args[1], Array)
+    assert args[1].dtype == float32
+
+    # check that the nested kernel function args are changed
+    args = nested_global_kernel.args
+    assert len(args) == 1
+
+    assert isinstance(args[0], Float)
+    assert args[0].bitwidth == 32
+
+
+@pytest.mark.skipif(not test.USE_CUDA, reason='No CUDA')
+def test_propagate_local_function_changed():
+    # check that propagate arguments are changed from float64 -> float32
+    NB_FLOAT = getattr(numba, "float64")
+    @cuda.jit(NB_FLOAT(NB_FLOAT, NB_FLOAT), device=True)
+    def helper_kernel(x, y):
+        return x * y
+
+    class Component(ComponentBase):
+        def __init__(self, **kwargs):
+            return
+
+        @cuda.jit(void(NB_FLOAT, NB_FLOAT[:]), device=True)
+        def propagate(x, y):
+            y[0] = helper_kernel(x, x)
+
+    component = Component()
+    Component.change_floattype("float32")
+    assert component.floattype == "float32"
+
+    # check that the class wide attributes are changed
+    assert Component.get_floattype() == "float32"
+    assert Component.process_kernel is not None
+    args = Component.propagate.args
+    assert len(args) == 2
+
+    assert isinstance(args[0], Float)
+    assert args[0].bitwidth == 32
+    assert isinstance(args[1], Array)
+    assert args[1].dtype == float32
+
+    # check that the local kernel function args are changed
+    args = helper_kernel.args
+    assert len(args) == 2
+    for arg in args:
+        assert isinstance(arg, Float)
+        assert arg.bitwidth == 32