Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for cache-hinted load and store operations #51

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions numba_cuda/numba/cuda/api_util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from numba import types
from numba.core import cgutils
import numpy as np


Expand Down Expand Up @@ -28,3 +30,26 @@ def _fill_stride_by_order(shape, dtype, order):
else:
raise ValueError('must be either C/F order')
return tuple(strides)


def normalize_indices(context, builder, indty, inds, aryty, valty):
    """
    Convert integer indices into tuple of intp

    Parameters: `indty` / `inds` are the numba type and LLVM value of the
    index (a single integer or a uniform tuple of integers); `aryty` is the
    array type being indexed and `valty` the element value type expected by
    the caller.

    Returns the normalized index type (always a UniTuple) and a list of
    index values cast to intp.

    Raises TypeError when `valty` differs from the array dtype, or when the
    number of indices does not match the array rank.
    """
    if indty in types.integer_domain:
        # Promote a lone integer index to a 1-tuple so the code below can
        # treat both cases uniformly.
        indty = types.UniTuple(dtype=indty, count=1)
        indices = [inds]
    else:
        indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
    # Cast every index to intp for use in pointer arithmetic.
    indices = [context.cast(builder, i, t, types.intp)
               for t, i in zip(indty, indices)]

    dtype = aryty.dtype
    if dtype != valty:
        raise TypeError("expect %s but got %s" % (dtype, valty))

    if aryty.ndim != len(indty):
        raise TypeError("indexing %d-D array with %d-D index" %
                        (aryty.ndim, len(indty)))

    return indty, indices
241 changes: 241 additions & 0 deletions numba_cuda/numba/cuda/cache_hints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
from llvmlite import ir
from numba import types
from numba.core import cgutils
from numba.core.extending import intrinsic, overload
from numba.core.errors import NumbaTypeError
from numba.cuda.api_util import normalize_indices

# Docs references:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#load-functions-using-cache-hints


def ldca(array, i):
    """Generate a `ld.global.ca` instruction for element `i` of an array.

    `.ca` hints caching at all levels. Python-level stub; the CUDA-target
    implementation is registered via `@overload` below.
    """


def ldcg(array, i):
    """Generate a `ld.global.cg` instruction for element `i` of an array.

    `.cg` hints caching at the global (L2) level, bypassing L1. Python-level
    stub; the CUDA-target implementation is registered via `@overload` below.
    """


def ldcs(array, i):
    """Generate a `ld.global.cs` instruction for element `i` of an array.

    `.cs` hints a streaming access likely used only once. Python-level stub;
    the CUDA-target implementation is registered via `@overload` below.
    """


def ldlu(array, i):
    """Generate a `ld.global.lu` instruction for element `i` of an array.

    `.lu` hints a "last use" load. Python-level stub; the CUDA-target
    implementation is registered via `@overload` below.
    """


def ldcv(array, i):
    """Generate a `ld.global.cv` instruction for element `i` of an array.

    `.cv` treats cached data as stale and fetches again (volatile-like).
    Python-level stub; the CUDA-target implementation is registered via
    `@overload` below.
    """


def stcg(array, i, value):
    """Generate a `st.global.cg` instruction for element `i` of an array.

    `.cg` hints caching at the global (L2) level. Python-level stub; the
    CUDA-target implementation is registered via `@overload` below.
    """


def stcs(array, i, value):
    """Generate a `st.global.cs` instruction for element `i` of an array.

    `.cs` hints a streaming store. Python-level stub; the CUDA-target
    implementation is registered via `@overload` below.
    """


def stwb(array, i, value):
    """Generate a `st.global.wb` instruction for element `i` of an array.

    `.wb` hints a write-back store through coherent cache levels. Python-level
    stub; the CUDA-target implementation is registered via `@overload` below.
    """


def stwt(array, i, value):
    """Generate a `st.global.wt` instruction for element `i` of an array.

    `.wt` hints a write-through store. Python-level stub; the CUDA-target
    implementation is registered via `@overload` below.
    """


# See
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#restricted-use-of-sub-word-sizes
# for background on the choice of "r" for 8-bit operands - there is
# no constraint for 8-bit operands, but the operand for loads and
# stores is permitted to be greater than 8 bits.
#
# Maps an element bitwidth to the PTX inline-assembly register constraint
# letter used for the loaded/stored operand in the asm templates below.
CONSTRAINT_MAP = {
    1: "b",
    8: "r",
    16: "h",
    32: "r",
    64: "l",
    128: "q"
}


def _validate_arguments(instruction, array, index):
    """Typing-time validation shared by all cache-hint load/store intrinsics.

    Parameters
    ----------
    instruction : str
        Name used in error messages (e.g. "ldcg", "stwb").
    array : numba type
        Must be a types.Array.
    index : numba type
        Either an integer (only valid for 1-D arrays) or a uniform tuple of
        integers whose length matches ``array.ndim``.

    Raises NumbaTypeError for any invalid combination; returns None on
    success.
    """
    if not isinstance(array, types.Array):
        msg = f"{instruction} operates on arrays. Got type {array}"
        raise NumbaTypeError(msg)

    valid_index = False

    if isinstance(index, types.Integer):
        if array.ndim != 1:
            msg = f"Expected {array.ndim} indices, got a scalar"
            raise NumbaTypeError(msg)
        valid_index = True

    if isinstance(index, types.UniTuple):
        if index.count != array.ndim:
            msg = f"Expected {array.ndim} indices, got {index.count}"
            raise NumbaTypeError(msg)

        # A UniTuple is homogeneous, so its single element type (``dtype``)
        # is checked directly. The previous code iterated ``index.dtype``
        # (``for t in index.dtype``), but the element type is a scalar numba
        # type and is not iterable, which raised at typing time for every
        # tuple index.
        if isinstance(index.dtype, types.Integer):
            valid_index = True

    if not valid_index:
        raise NumbaTypeError(f"{index} is not a valid index")


def ld_cache_operator(operator):
    """Build an intrinsic emitting ``ld.global.<operator>`` for one element.

    ``operator`` is a PTX load cache hint ("ca", "cg", "cs", "lu" or "cv").
    The returned intrinsic takes ``(array, index)`` and yields the loaded
    element, using inline PTX so the cache hint is preserved.
    """
    @intrinsic
    def impl(typingctx, array, index):
        # Raises NumbaTypeError on a non-array or badly-shaped index.
        _validate_arguments(f"ld{operator}", array, index)

        # Need to validate bitwidth
        # (only bitwidths present in CONSTRAINT_MAP can be lowered below).

        signature = array.dtype(array, index)

        def codegen(context, builder, sig, args):
            array_type, index_type = sig.args
            loaded_type = context.get_value_type(array_type.dtype)
            ptr_type = loaded_type.as_pointer()
            # Inline-asm function type: element = asm(pointer).
            ldcs_type = ir.FunctionType(loaded_type, [ptr_type])

            array, indices = args

            # Normalize a scalar or tuple index to a tuple of intp values.
            index_type, indices = normalize_indices(context, builder,
                                                    index_type, indices,
                                                    array_type,
                                                    array_type.dtype)
            array_struct = context.make_array(array_type)(context, builder,
                                                          value=array)
            ptr = cgutils.get_item_pointer(context, builder, array_type,
                                           array_struct, indices,
                                           wraparound=True)

            bitwidth = array_type.dtype.bitwidth
            inst = f"ld.global.{operator}.b{bitwidth}"
            # "=<reg>" output constraint for the element; "l" (64-bit
            # register) for the address operand.
            constraints = f"={CONSTRAINT_MAP[bitwidth]},l"
            # NOTE(review): this asm is not marked side_effect, so LLVM may
            # CSE two identical loads — which would defeat .cv/.cs
            # semantics. Confirm whether side_effect=True is needed here.
            ldcs = ir.InlineAsm(ldcs_type, f"{inst} $0, [$1];", constraints)
            return builder.call(ldcs, [ptr])

        return signature, codegen

    return impl


# Concrete load intrinsics, one per PTX load cache hint. These are what the
# @overload implementations below forward to.
ldca_intrinsic = ld_cache_operator("ca")
ldcg_intrinsic = ld_cache_operator("cg")
ldcs_intrinsic = ld_cache_operator("cs")
ldlu_intrinsic = ld_cache_operator("lu")
ldcv_intrinsic = ld_cache_operator("cv")


def st_cache_operator(operator):
    """Build an intrinsic emitting ``st.global.<operator>`` for one element.

    ``operator`` is a PTX store cache hint ("cg", "cs", "wb" or "wt"). The
    returned intrinsic takes ``(array, index, value)`` and stores ``value``
    (cast to the array dtype) into the indexed element, using inline PTX so
    the cache hint is preserved.
    """
    @intrinsic
    def impl(typingctx, array, index, value):
        # Raises NumbaTypeError on a non-array or badly-shaped index.
        _validate_arguments(f"st{operator}", array, index)

        # Need to validate bitwidth
        # (only bitwidths present in CONSTRAINT_MAP can be lowered below).

        signature = types.void(array, index, value)

        def codegen(context, builder, sig, args):
            array_type, index_type, value_type = sig.args
            stored_type = context.get_value_type(array_type.dtype)
            ptr_type = stored_type.as_pointer()
            # Inline-asm function type: void = asm(pointer, element).
            stcs_type = ir.FunctionType(ir.VoidType(), [ptr_type, stored_type])

            array, indices, value = args

            # Normalize a scalar or tuple index to a tuple of intp values.
            index_type, indices = normalize_indices(context, builder,
                                                    index_type, indices,
                                                    array_type,
                                                    array_type.dtype)
            array_struct = context.make_array(array_type)(context, builder,
                                                          value=array)
            ptr = cgutils.get_item_pointer(context, builder, array_type,
                                           array_struct, indices,
                                           wraparound=True)

            casted_value = context.cast(builder, value, value_type,
                                        array_type.dtype)

            bitwidth = array_type.dtype.bitwidth
            inst = f"st.global.{operator}.b{bitwidth}"
            # Address in an "l" (64-bit) register, value per CONSTRAINT_MAP,
            # plus a memory clobber.
            constraints = f"l,{CONSTRAINT_MAP[bitwidth]},~{{memory}}"
            # side_effect=True is required: the asm writes memory and
            # produces no SSA result, so without it LLVM may treat the call
            # as dead and delete the store.
            stcs = ir.InlineAsm(stcs_type, f"{inst} [$0], $1;", constraints,
                                side_effect=True)
            builder.call(stcs, [ptr, casted_value])

        return signature, codegen

    return impl


# Concrete store intrinsics, one per PTX store cache hint. These are what
# the @overload implementations below forward to.
stcg_intrinsic = st_cache_operator("cg")
stcs_intrinsic = st_cache_operator("cs")
stwb_intrinsic = st_cache_operator("wb")
stwt_intrinsic = st_cache_operator("wt")


@overload(ldca, target='cuda')
def ol_ldca(array, i):
    # CUDA-target overload: forward ldca() calls to the ldca intrinsic.
    def impl(array, i):
        return ldca_intrinsic(array, i)
    return impl


@overload(ldcg, target='cuda')
def ol_ldcg(array, i):
    # CUDA-target overload: forward ldcg() calls to the ldcg intrinsic.
    def impl(array, i):
        return ldcg_intrinsic(array, i)
    return impl


@overload(ldcs, target='cuda')
def ol_ldcs(array, i):
    # CUDA-target overload: forward ldcs() calls to the ldcs intrinsic.
    def impl(array, i):
        return ldcs_intrinsic(array, i)
    return impl


@overload(ldlu, target='cuda')
def ol_ldlu(array, i):
    # CUDA-target overload: forward ldlu() calls to the ldlu intrinsic.
    def impl(array, i):
        return ldlu_intrinsic(array, i)
    return impl


@overload(ldcv, target='cuda')
def ol_ldcv(array, i):
    # CUDA-target overload: forward ldcv() calls to the ldcv intrinsic.
    def impl(array, i):
        return ldcv_intrinsic(array, i)
    return impl


@overload(stcg, target='cuda')
def ol_stcg(array, i, value):
    # CUDA-target overload: forward stcg() calls to the stcg intrinsic.
    def impl(array, i, value):
        return stcg_intrinsic(array, i, value)
    return impl


@overload(stcs, target='cuda')
def ol_stcs(array, i, value):
    # CUDA-target overload: forward stcs() calls to the stcs intrinsic.
    def impl(array, i, value):
        return stcs_intrinsic(array, i, value)
    return impl


@overload(stwb, target='cuda')
def ol_stwb(array, i, value):
    # CUDA-target overload: forward stwb() calls to the stwb intrinsic.
    def impl(array, i, value):
        return stwb_intrinsic(array, i, value)
    return impl


@overload(stwt, target='cuda')
def ol_stwt(array, i, value):
    # CUDA-target overload: forward stwt() calls to the stwt intrinsic.
    def impl(array, i, value):
        return stwt_intrinsic(array, i, value)
    return impl
32 changes: 5 additions & 27 deletions numba_cuda/numba/cuda/cudaimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from numba.np.npyimpl import register_ufuncs
from .cudadrv import nvvm
from numba import cuda
from numba.cuda.api_util import normalize_indices
from numba.cuda import nvvmutils, stubs, errors
from numba.cuda.types import dim3, CUDADispatcher

Expand Down Expand Up @@ -692,38 +693,15 @@ def impl(context, builder, sig, args):
lower(math.degrees, types.f8)(gen_deg_rad(_rad2deg))


def _normalize_indices(context, builder, indty, inds, aryty, valty):
"""
Convert integer indices into tuple of intp
"""
if indty in types.integer_domain:
indty = types.UniTuple(dtype=indty, count=1)
indices = [inds]
else:
indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
indices = [context.cast(builder, i, t, types.intp)
for t, i in zip(indty, indices)]

dtype = aryty.dtype
if dtype != valty:
raise TypeError("expect %s but got %s" % (dtype, valty))

if aryty.ndim != len(indty):
raise TypeError("indexing %d-D array with %d-D index" %
(aryty.ndim, len(indty)))

return indty, indices


def _atomic_dispatcher(dispatch_fn):
def imp(context, builder, sig, args):
# The common argument handling code
aryty, indty, valty = sig.args
ary, inds, val = args
dtype = aryty.dtype

indty, indices = _normalize_indices(context, builder, indty, inds,
aryty, valty)
indty, indices = normalize_indices(context, builder, indty, inds,
aryty, valty)

lary = context.make_array(aryty)(context, builder, ary)
ptr = cgutils.get_item_pointer(context, builder, aryty, lary, indices,
Expand Down Expand Up @@ -917,8 +895,8 @@ def ptx_atomic_cas(context, builder, sig, args):
aryty, indty, oldty, valty = sig.args
ary, inds, old, val = args

indty, indices = _normalize_indices(context, builder, indty, inds, aryty,
valty)
indty, indices = normalize_indices(context, builder, indty, inds, aryty,
valty)

lary = context.make_array(aryty)(context, builder, ary)
ptr = cgutils.get_item_pointer(context, builder, aryty, lary, indices,
Expand Down
2 changes: 2 additions & 0 deletions numba_cuda/numba/cuda/device_init.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Re export
import sys
from numba.cuda import cg
from numba.cuda.cache_hints import (ldca, ldcg, ldcs, ldlu, ldcv, stcg, stcs,
stwb, stwt)
from .stubs import (threadIdx, blockIdx, blockDim, gridDim, laneid, warpsize,
syncwarp, shared, local, const, atomic,
shfl_sync_intrinsic, vote_sync_intrinsic, match_any_sync,
Expand Down
Loading