Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 059efbc

Browse files
authoredMay 9, 2022
[SYCL] Add half non-assign math operators (#6061)
This PR adds missing half operations +-*/. There are existing operations for the legacy class host_half_impl, however these were not extended to the half class. These operations were being performed as floating point operations via the implicit floating conversion. This results in the output being a float not half type. The template is limited to arithmetic types to prevent ambiguous templating. A minor change to built-in __fract is needed to ensure fmin is not ambiguous. llvm-test-suite PR: intel/llvm-test-suite#1012 Fixes: #6028
1 parent 21d62dd commit 059efbc

File tree

3 files changed

+297
-9
lines changed

3 files changed

+297
-9
lines changed
 

‎sycl/include/CL/sycl/half_type.hpp

Lines changed: 194 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -386,14 +386,200 @@ class half {
386386
operator--();
387387
return ret;
388388
}
389-
constexpr half &operator-() {
390-
Data = -Data;
391-
return *this;
392-
}
393-
constexpr half operator-() const {
394-
half r = *this;
395-
return -r;
396-
}
389+
__SYCL_CONSTEXPR_HALF friend half operator-(const half other) {
390+
return half(-other.Data);
391+
}
392+
393+
// Operator +, -, *, /
394+
#define OP(op, op_eq) \
395+
__SYCL_CONSTEXPR_HALF friend half operator op(const half lhs, \
396+
const half rhs) { \
397+
half rtn = lhs; \
398+
rtn op_eq rhs; \
399+
return rtn; \
400+
} \
401+
__SYCL_CONSTEXPR_HALF friend double operator op(const half lhs, \
402+
const double rhs) { \
403+
double rtn = lhs; \
404+
rtn op_eq rhs; \
405+
return rtn; \
406+
} \
407+
__SYCL_CONSTEXPR_HALF friend double operator op(const double lhs, \
408+
const half rhs) { \
409+
double rtn = lhs; \
410+
rtn op_eq rhs; \
411+
return rtn; \
412+
} \
413+
__SYCL_CONSTEXPR_HALF friend float operator op(const half lhs, \
414+
const float rhs) { \
415+
float rtn = lhs; \
416+
rtn op_eq rhs; \
417+
return rtn; \
418+
} \
419+
__SYCL_CONSTEXPR_HALF friend float operator op(const float lhs, \
420+
const half rhs) { \
421+
float rtn = lhs; \
422+
rtn op_eq rhs; \
423+
return rtn; \
424+
} \
425+
__SYCL_CONSTEXPR_HALF friend half operator op(const half lhs, \
426+
const int rhs) { \
427+
half rtn = lhs; \
428+
rtn op_eq rhs; \
429+
return rtn; \
430+
} \
431+
__SYCL_CONSTEXPR_HALF friend half operator op(const int lhs, \
432+
const half rhs) { \
433+
half rtn = lhs; \
434+
rtn op_eq rhs; \
435+
return rtn; \
436+
} \
437+
__SYCL_CONSTEXPR_HALF friend half operator op(const half lhs, \
438+
const long rhs) { \
439+
half rtn = lhs; \
440+
rtn op_eq rhs; \
441+
return rtn; \
442+
} \
443+
__SYCL_CONSTEXPR_HALF friend half operator op(const long lhs, \
444+
const half rhs) { \
445+
half rtn = lhs; \
446+
rtn op_eq rhs; \
447+
return rtn; \
448+
} \
449+
__SYCL_CONSTEXPR_HALF friend half operator op(const half lhs, \
450+
const long long rhs) { \
451+
half rtn = lhs; \
452+
rtn op_eq rhs; \
453+
return rtn; \
454+
} \
455+
__SYCL_CONSTEXPR_HALF friend half operator op(const long long lhs, \
456+
const half rhs) { \
457+
half rtn = lhs; \
458+
rtn op_eq rhs; \
459+
return rtn; \
460+
} \
461+
__SYCL_CONSTEXPR_HALF friend half operator op(const half &lhs, \
462+
const unsigned int &rhs) { \
463+
half rtn = lhs; \
464+
rtn op_eq rhs; \
465+
return rtn; \
466+
} \
467+
__SYCL_CONSTEXPR_HALF friend half operator op(const unsigned int &lhs, \
468+
const half &rhs) { \
469+
half rtn = lhs; \
470+
rtn op_eq rhs; \
471+
return rtn; \
472+
} \
473+
__SYCL_CONSTEXPR_HALF friend half operator op(const half &lhs, \
474+
const unsigned long &rhs) { \
475+
half rtn = lhs; \
476+
rtn op_eq rhs; \
477+
return rtn; \
478+
} \
479+
__SYCL_CONSTEXPR_HALF friend half operator op(const unsigned long &lhs, \
480+
const half &rhs) { \
481+
half rtn = lhs; \
482+
rtn op_eq rhs; \
483+
return rtn; \
484+
} \
485+
__SYCL_CONSTEXPR_HALF friend half operator op( \
486+
const half &lhs, const unsigned long long &rhs) { \
487+
half rtn = lhs; \
488+
rtn op_eq rhs; \
489+
return rtn; \
490+
} \
491+
__SYCL_CONSTEXPR_HALF friend half operator op(const unsigned long long &lhs, \
492+
const half &rhs) { \
493+
half rtn = lhs; \
494+
rtn op_eq rhs; \
495+
return rtn; \
496+
}
497+
OP(+, +=)
498+
OP(-, -=)
499+
OP(*, *=)
500+
OP(/, /=)
501+
502+
#undef OP
503+
504+
// Operator ==, !=, <, >, <=, >=
505+
#define OP(op) \
506+
__SYCL_CONSTEXPR_HALF friend bool operator op(const half &lhs, \
507+
const half &rhs) { \
508+
return lhs.Data op rhs.Data; \
509+
} \
510+
__SYCL_CONSTEXPR_HALF friend bool operator op(const half &lhs, \
511+
const double &rhs) { \
512+
return lhs.Data op rhs; \
513+
} \
514+
__SYCL_CONSTEXPR_HALF friend bool operator op(const double &lhs, \
515+
const half &rhs) { \
516+
return lhs op rhs.Data; \
517+
} \
518+
__SYCL_CONSTEXPR_HALF friend bool operator op(const half &lhs, \
519+
const float &rhs) { \
520+
return lhs.Data op rhs; \
521+
} \
522+
__SYCL_CONSTEXPR_HALF friend bool operator op(const float &lhs, \
523+
const half &rhs) { \
524+
return lhs op rhs.Data; \
525+
} \
526+
__SYCL_CONSTEXPR_HALF friend bool operator op(const half &lhs, \
527+
const int &rhs) { \
528+
return lhs.Data op rhs; \
529+
} \
530+
__SYCL_CONSTEXPR_HALF friend bool operator op(const int &lhs, \
531+
const half &rhs) { \
532+
return lhs op rhs.Data; \
533+
} \
534+
__SYCL_CONSTEXPR_HALF friend bool operator op(const half &lhs, \
535+
const long &rhs) { \
536+
return lhs.Data op rhs; \
537+
} \
538+
__SYCL_CONSTEXPR_HALF friend bool operator op(const long &lhs, \
539+
const half &rhs) { \
540+
return lhs op rhs.Data; \
541+
} \
542+
__SYCL_CONSTEXPR_HALF friend bool operator op(const half &lhs, \
543+
const long long &rhs) { \
544+
return lhs.Data op rhs; \
545+
} \
546+
__SYCL_CONSTEXPR_HALF friend bool operator op(const long long &lhs, \
547+
const half &rhs) { \
548+
return lhs op rhs.Data; \
549+
} \
550+
__SYCL_CONSTEXPR_HALF friend bool operator op(const half &lhs, \
551+
const unsigned int &rhs) { \
552+
return lhs.Data op rhs; \
553+
} \
554+
__SYCL_CONSTEXPR_HALF friend bool operator op(const unsigned int &lhs, \
555+
const half &rhs) { \
556+
return lhs op rhs.Data; \
557+
} \
558+
__SYCL_CONSTEXPR_HALF friend bool operator op(const half &lhs, \
559+
const unsigned long &rhs) { \
560+
return lhs.Data op rhs; \
561+
} \
562+
__SYCL_CONSTEXPR_HALF friend bool operator op(const unsigned long &lhs, \
563+
const half &rhs) { \
564+
return lhs op rhs.Data; \
565+
} \
566+
__SYCL_CONSTEXPR_HALF friend bool operator op( \
567+
const half &lhs, const unsigned long long &rhs) { \
568+
return lhs.Data op rhs; \
569+
} \
570+
__SYCL_CONSTEXPR_HALF friend bool operator op(const unsigned long long &lhs, \
571+
const half &rhs) { \
572+
return lhs op rhs.Data; \
573+
}
574+
OP(==)
575+
OP(!=)
576+
OP(<)
577+
OP(>)
578+
OP(<=)
579+
OP(>=)
580+
581+
#undef OP
582+
397583
// Operator float
398584
__SYCL_CONSTEXPR_HALF operator float() const {
399585
return static_cast<float>(Data);

‎sycl/source/detail/builtins_math.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ template <typename T> inline T __cospi(T x) { return std::cos(M_PI * x); }
4242
template <typename T> T inline __fract(T x, T *iptr) {
4343
T f = std::floor(x);
4444
*(iptr) = f;
45-
return std::fmin(x - f, nextafter(T(1.0), T(0.0)));
45+
return std::fmin(x - f, std::nextafter(T(1.0), T(0.0)));
4646
}
4747

4848
template <typename T> inline T __lgamma_r(T x, s::cl_int *signp) {
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// RUN: %clangxx -fsycl %s -o %t.out
2+
//==-------------- type_traits.cpp - SYCL type_traits test -----------------==//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#include <CL/sycl.hpp>
11+
using namespace std;
12+
13+
template <typename T1, typename T_rtn> void math_operator_helper() {
14+
static_assert(
15+
is_same_v<decltype(declval<T1>() + declval<sycl::half>()), T_rtn>);
16+
static_assert(
17+
is_same_v<decltype(declval<T1>() - declval<sycl::half>()), T_rtn>);
18+
static_assert(
19+
is_same_v<decltype(declval<T1>() * declval<sycl::half>()), T_rtn>);
20+
static_assert(
21+
is_same_v<decltype(declval<T1>() / declval<sycl::half>()), T_rtn>);
22+
23+
static_assert(
24+
is_same_v<decltype(declval<sycl::half>() + declval<T1>()), T_rtn>);
25+
static_assert(
26+
is_same_v<decltype(declval<sycl::half>() - declval<T1>()), T_rtn>);
27+
static_assert(
28+
is_same_v<decltype(declval<sycl::half>() * declval<T1>()), T_rtn>);
29+
static_assert(
30+
is_same_v<decltype(declval<sycl::half>() / declval<T1>()), T_rtn>);
31+
}
32+
33+
template <typename T1> void logical_operator_helper() {
34+
static_assert(
35+
is_same_v<decltype(declval<T1>() == declval<sycl::half>()), bool>);
36+
static_assert(
37+
is_same_v<decltype(declval<T1>() != declval<sycl::half>()), bool>);
38+
static_assert(
39+
is_same_v<decltype(declval<T1>() > declval<sycl::half>()), bool>);
40+
static_assert(
41+
is_same_v<decltype(declval<T1>() < declval<sycl::half>()), bool>);
42+
static_assert(
43+
is_same_v<decltype(declval<T1>() <= declval<sycl::half>()), bool>);
44+
static_assert(
45+
is_same_v<decltype(declval<T1>() >= declval<sycl::half>()), bool>);
46+
47+
static_assert(
48+
is_same_v<decltype(declval<sycl::half>() == declval<T1>()), bool>);
49+
static_assert(
50+
is_same_v<decltype(declval<sycl::half>() != declval<T1>()), bool>);
51+
static_assert(
52+
is_same_v<decltype(declval<sycl::half>() > declval<T1>()), bool>);
53+
static_assert(
54+
is_same_v<decltype(declval<sycl::half>() < declval<T1>()), bool>);
55+
static_assert(
56+
is_same_v<decltype(declval<sycl::half>() <= declval<T1>()), bool>);
57+
static_assert(
58+
is_same_v<decltype(declval<sycl::half>() >= declval<T1>()), bool>);
59+
}
60+
61+
template <typename T1, typename T_rtn>
62+
void check_half_math_operator_types(sycl::queue &Queue) {
63+
64+
// Test on host
65+
math_operator_helper<T1, T_rtn>();
66+
67+
// Test on device
68+
Queue.submit([&](sycl::handler &cgh) {
69+
cgh.single_task([=] { math_operator_helper<T1, T_rtn>(); });
70+
});
71+
}
72+
73+
template <typename T1>
74+
void check_half_logical_operator_types(sycl::queue &Queue) {
75+
76+
// Test on host
77+
logical_operator_helper<T1>();
78+
79+
// Test on device
80+
Queue.submit([&](sycl::handler &cgh) {
81+
cgh.single_task([=] { logical_operator_helper<T1>(); });
82+
});
83+
}
84+
85+
int main() {
86+
87+
sycl::queue Queue;
88+
89+
check_half_math_operator_types<sycl::half, sycl::half>(Queue);
90+
check_half_math_operator_types<double, double>(Queue);
91+
check_half_math_operator_types<float, float>(Queue);
92+
check_half_math_operator_types<int, sycl::half>(Queue);
93+
check_half_math_operator_types<long, sycl::half>(Queue);
94+
check_half_math_operator_types<long long, sycl::half>(Queue);
95+
96+
check_half_logical_operator_types<sycl::half>(Queue);
97+
check_half_logical_operator_types<double>(Queue);
98+
check_half_logical_operator_types<float>(Queue);
99+
check_half_logical_operator_types<int>(Queue);
100+
check_half_logical_operator_types<long>(Queue);
101+
check_half_logical_operator_types<long long>(Queue);
102+
}

0 commit comments

Comments
 (0)
Please sign in to comment.