Skip to content

Commit f232cfd

Browse files
authored
[ESIMD] Enforce compile-time channel mask restrictions for rgba write APIs. (#6137)
Signed-off-by: Konstantin S Bobrovsky <[email protected]>
1 parent 90ac3ee commit f232cfd

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

sycl/include/sycl/ext/intel/esimd/memory.hpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -529,12 +529,23 @@ gather_rgba(const Tx *p, simd<uint32_t, N> offsets, simd_mask<N> mask = 1) {
529529
return __esimd_svm_gather4_scaled<T, N, Mask>(addrs.data(), mask.data());
530530
}
531531

532+
namespace detail {
533+
template <rgba_channel_mask M> static void validate_rgba_write_channel_mask() {
534+
using CM = rgba_channel_mask;
535+
static_assert(
536+
(M == CM::ABGR || M == CM::BGR || M == CM::GR || M == CM::R) &&
537+
"Only ABGR, BGR, GR, R channel masks are valid in write operations");
538+
}
539+
} // namespace detail
540+
532541
/// @anchor usm_scatter_rgba
533542
/// Transpose and scatter pixels to given memory locations defined by the base
534543
/// pointer \c p and \c offsets. Up to 4 32-bit data elements may be accessed at
535544
/// each address depending on the channel mask \c Mask template parameter. Each
536545
/// pixel's address must be 4 byte aligned. This is basically an inverse
537-
/// operation for gather_rgba.
546+
/// operation for gather_rgba. Unlike \c gather_rgba, this function imposes
547+
/// restrictions on possible \c Mask template argument values. It can only be
548+
/// one of the following: \c ABGR, \c BGR, \c GR, \c R.
538549
///
539550
/// @tparam Tx Element type of the returned vector. Must be 4 bytes in size.
540551
/// @tparam N Number of pixels to access (matches the size of the \c offsets
@@ -553,6 +564,7 @@ __ESIMD_API std::enable_if_t<(N == 8 || N == 16 || N == 32) && (sizeof(T) == 4)>
553564
scatter_rgba(Tx *p, simd<uint32_t, N> offsets,
554565
simd<Tx, N * get_num_channels_enabled(Mask)> vals,
555566
simd_mask<N> mask = 1) {
567+
detail::validate_rgba_write_channel_mask<Mask>();
556568
simd<uint64_t, N> offsets_i = convert<uint64_t>(offsets);
557569
simd<uint64_t, N> addrs(reinterpret_cast<uint64_t>(p));
558570
addrs = addrs + offsets_i;
@@ -875,7 +887,7 @@ slm_gather_rgba(simd<uint32_t, N> offsets, simd_mask<N> mask = 1) {
875887
}
876888

877889
/// Gather data from the Shared Local Memory at specified \c offsets and return
878-
/// it as simd vector. See @ref usm_gather_rgba for information about the
890+
/// it as simd vector. See @ref usm_scatter_rgba for information about the
879891
/// operation semantics and parameter restrictions/interdependencies.
880892
/// @tparam T The element type of the returned vector.
881893
/// @tparam N The number of elements to access.
@@ -889,6 +901,7 @@ __ESIMD_API std::enable_if_t<(N == 8 || N == 16 || N == 32) && (sizeof(T) == 4)>
889901
slm_scatter_rgba(simd<uint32_t, N> offsets,
890902
simd<T, N * get_num_channels_enabled(Mask)> vals,
891903
simd_mask<N> mask = 1) {
904+
detail::validate_rgba_write_channel_mask<Mask>();
892905
const auto si = __ESIMD_GET_SURF_HANDLE(detail::LocalAccessorMarker());
893906
constexpr int16_t Scale = 0;
894907
constexpr int global_offset = 0;

sycl/test/esimd/gather_scatter_rgba.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
// RUN: %clangxx -fsycl -fsyntax-only -Wno-unused-command-line-argument %s
1+
// RUN: %clangxx -fsycl -fsycl-device-only -fsyntax-only -Xclang -verify %s
22

3-
// This test checks compilation of ESIMD slm gather_rgba/scatter_rgba APIs.
3+
// This test checks that device compiler can:
4+
// - successfully compile gather_rgba/scatter_rgba APIs
5+
// - emit an error if some of the restrictions on template parameters are
6+
// violated
47

58
#include <CL/sycl.hpp>
69
#include <limits>
@@ -20,3 +23,15 @@ void kernel(int *ptr) SYCL_ESIMD_FUNCTION {
2023

2124
scatter_rgba<int, 32, rgba_channel_mask::ABGR>(ptr, offsets, v0);
2225
}
26+
27+
constexpr int AGR_N_CHANNELS = 3;
28+
29+
void kernel1(int *ptr, simd<int, 32 * AGR_N_CHANNELS> v) SYCL_ESIMD_FUNCTION {
30+
simd<uint32_t, 32> offsets(0, sizeof(int) * 4);
31+
// only 1, 2, 3, 4-element masks covering consequitive channels starting from
32+
// R are supported
33+
// expected-error-re@* {{static_assert failed{{.*}}Only ABGR, BGR, GR, R channel masks are valid in write operations}}
34+
// expected-note@* {{in instantiation }}
35+
// expected-note@+1 {{in instantiation }}
36+
scatter_rgba<int, 32, rgba_channel_mask::AGR>(ptr, offsets, v);
37+
}

0 commit comments

Comments
 (0)