Skip to content

Commit fb8d4ab

Browse files
[NFC][SYCL] Cleanup marray/vector includes (#20181)
Highlights:
* No `<exception>` include for device compilation.
* Limit `half`/`bfloat16` to forward declarations.
* `is_device_copyable.hpp` is lightweight, but it triggers a front-end (FE) bug when using PCH for device code, so it is limited to a forward declaration for now too.
* Remove duplicate/redundant includes.
1 parent cbed756 commit fb8d4ab

File tree

5 files changed

+19
-39
lines changed

5 files changed

+19
-39
lines changed

sycl/include/sycl/detail/vector_arith.hpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@
1010

1111
#include <sycl/aliases.hpp>
1212
#include <sycl/detail/generic_type_traits.hpp>
13-
#include <sycl/detail/type_traits.hpp>
14-
#include <sycl/detail/type_traits/vec_marray_traits.hpp>
15-
#include <sycl/ext/oneapi/bfloat16.hpp>
1613

1714
#include <functional>
1815

sycl/include/sycl/detail/vector_convert.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,16 @@
5454

5555
#pragma once
5656

57-
#include <sycl/detail/generic_type_traits.hpp> // for is_sigeninteger, is_s...
58-
#include <sycl/exception.hpp> // for errc
57+
#include <sycl/detail/generic_type_traits.hpp>
5958

6059
#include <sycl/detail/memcpy.hpp>
6160
#include <sycl/ext/oneapi/bfloat16.hpp>
6261
#include <sycl/half_type.hpp>
6362
#include <sycl/vector.hpp>
6463

6564
#ifndef __SYCL_DEVICE_ONLY__
66-
#include <cfenv> // for fesetround, fegetround
65+
#include <cfenv>
66+
#include <sycl/exception.hpp>
6767
#endif
6868

6969
#include <type_traits>

sycl/include/sycl/ext/oneapi/experimental/cuda/builtins.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ ldg(const T *ptr) {
367367
} else if constexpr (std::is_same_v<T, sycl::vec<half, 2>>) {
368368
typedef __fp16 h2 ATTRIBUTE_EXT_VEC_TYPE(2);
369369
auto rv = __nvvm_ldg_h2(reinterpret_cast<const h2 *>(ptr));
370-
sycl::vec<half, 2> ret;
370+
T ret;
371371
ret.x() = rv[0];
372372
ret.y() = rv[1];
373373
return ret;
@@ -376,7 +376,7 @@ ldg(const T *ptr) {
376376
h2 rv_2 = __nvvm_ldg_h2(reinterpret_cast<const h2 *>(ptr));
377377
auto rv = __nvvm_ldg_h(reinterpret_cast<const __fp16 *>(
378378
std::next(reinterpret_cast<const h2 *>(ptr))));
379-
sycl::vec<half, 3> ret;
379+
T ret;
380380
ret.x() = rv_2[0];
381381
ret.y() = rv_2[1];
382382
ret.z() = rv;
@@ -385,7 +385,7 @@ ldg(const T *ptr) {
385385
typedef __fp16 h2 ATTRIBUTE_EXT_VEC_TYPE(2);
386386
auto rv1 = __nvvm_ldg_h2(reinterpret_cast<const h2 *>(ptr));
387387
auto rv2 = __nvvm_ldg_h2(std::next(reinterpret_cast<const h2 *>(ptr)));
388-
sycl::vec<half, 4> ret;
388+
T ret;
389389
ret.x() = rv1[0];
390390
ret.y() = rv1[1];
391391
ret.z() = rv2[0];

sycl/include/sycl/marray.hpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,15 @@
1010

1111
#include <sycl/aliases.hpp>
1212
#include <sycl/detail/common.hpp>
13-
#include <sycl/detail/is_device_copyable.hpp>
14-
#include <sycl/half_type.hpp>
15-
16-
#include <array>
17-
#include <cstddef>
18-
#include <cstdint>
19-
#include <type_traits>
20-
#include <utility>
13+
#include <sycl/detail/fwd/half.hpp>
2114

2215
namespace sycl {
2316
inline namespace _V1 {
2417

2518
template <typename DataT, std::size_t N> class marray;
2619

20+
template <typename T> struct is_device_copyable;
21+
2722
namespace detail {
2823

2924
// Helper trait for counting the aggregate number of arguments in a type list,

sycl/include/sycl/vector.hpp

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,30 +31,16 @@
3131
#error "SYCL device compiler is built without ext_vector_type support"
3232
#endif
3333

34-
#include <sycl/access/access.hpp> // for decorated, address_space
35-
#include <sycl/aliases.hpp> // for half, cl_char, cl_int
36-
#include <sycl/detail/common.hpp> // for ArrayCreator
37-
#include <sycl/detail/defines_elementary.hpp> // for __SYCL2020_DEPRECATED
38-
#include <sycl/detail/fwd/accessor.hpp>
39-
#include <sycl/detail/generic_type_traits.hpp> // for is_sigeninteger, is_s...
40-
#include <sycl/detail/memcpy.hpp> // for memcpy
4134
#include <sycl/detail/named_swizzles_mixin.hpp>
42-
#include <sycl/detail/type_traits.hpp> // for is_floating_point
4335
#include <sycl/detail/vector_arith.hpp>
44-
#include <sycl/half_type.hpp> // for StorageT, half, Vec16...
4536

46-
#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16
37+
#include <sycl/detail/common.hpp>
38+
#include <sycl/detail/fwd/accessor.hpp>
39+
#include <sycl/detail/fwd/half.hpp>
40+
#include <sycl/detail/memcpy.hpp>
4741

48-
#include <algorithm> // for std::min
49-
#include <array> // for array
50-
#include <cassert> // for assert
51-
#include <cstddef> // for size_t, NULL, byte
52-
#include <cstdint> // for uint8_t, int16_t, int...
53-
#include <functional> // for divides, multiplies
54-
#include <iterator> // for pair
55-
#include <ostream> // for operator<<, basic_ost...
56-
#include <type_traits> // for enable_if_t, is_same
57-
#include <utility> // for index_sequence, make_...
42+
#include <algorithm>
43+
#include <functional>
5844

5945
namespace sycl {
6046

@@ -63,6 +49,9 @@ namespace sycl {
6349
enum class rounding_mode { automatic = 0, rte = 1, rtz = 2, rtp = 3, rtn = 4 };
6450

6551
inline namespace _V1 {
52+
namespace ext::oneapi {
53+
class bfloat16;
54+
}
6655

6756
struct elem {
6857
static constexpr int x = 0;
@@ -512,8 +501,7 @@ class __SYCL_EBO vec :
512501
#endif
513502
bool, /*->*/ std::uint8_t, //
514503
sycl::half, /*->*/ sycl::detail::half_impl::StorageT, //
515-
sycl::ext::oneapi::bfloat16,
516-
/*->*/ sycl::ext::oneapi::bfloat16::Bfloat16StorageT, //
504+
sycl::ext::oneapi::bfloat16, /*->*/ uint16_t, //
517505
char, /*->*/ detail::ConvertToOpenCLType_t<char>, //
518506
DataT, /*->*/ DataT //
519507
>::type;

0 commit comments

Comments
 (0)