11#pragma once
22#include < cassert>
33#include < climits>
4+ #include < cmath>
45#include < cstdint>
56#include < initializer_list>
67#include < iostream>
@@ -51,6 +52,58 @@ template <> class imf_utils_default_equ<uint64_t> {
5152 }
5253};
5354
/// Equality functor used as the default comparator of the single-type `test`
/// helper. The primary template compares with `operator==`; the
/// specializations for floating-point types compare with a relative tolerance.
template <class Ty> class imf_utils_fp_equ {
public:
  /// Returns true when `x == y`.
  bool operator()(Ty x, Ty y) const { return x == y; }
};

/// fp32 comparison with a relative tolerance of 1e-4.
///
/// Rules, in order:
///  - values that compare equal with `==` are equal (this covers +0/-0,
///    identical subnormals and infinities of the same sign);
///  - a NaN on either side is never equal to anything;
///  - otherwise the values are equal when
///    `|x - y| < 1e-4 * max(|x|, |y|)`.
template <> class imf_utils_fp_equ<float> {
public:
  bool operator()(float x, float y) const {
    // Exact matches must short-circuit: for x == y == 0 the relative test
    // below degenerates to `0 < 0` and would report a mismatch.
    if (x == y)
      return true;
    if (std::isnan(x) || std::isnan(y))
      return false;
    // Simple check for 2 fp32. The comparison is strict so that an infinity
    // versus a finite value (inf < inf) is still reported as different.
    const float relative_eps = 1e-4f;
    return std::fabs(x - y) <
           relative_eps * std::fmax(std::fabs(x), std::fabs(y));
  }
};
73+
74+ template <> class imf_utils_fp_equ <sycl::half> {
75+ public:
76+ bool operator ()(sycl::half x, sycl::half y) {
77+ float xf = static_cast <float >(x);
78+ float yf = static_cast <float >(y);
79+ if ((__builtin_isinf_sign (xf) * __builtin_isinf_sign (yf)) == 1 )
80+ return true ;
81+ if (__builtin_isnan (xf) || __builtin_isnan (yf))
82+ return false ;
83+ // Simple check for 2 fp16
84+ const float relative_eps = 1e-3f ;
85+ return std::fabs (xf - yf) <
86+ relative_eps * std::fmax (std::fabs (xf), std::fabs (yf));
87+ }
88+ };
89+
90+ template <> class imf_utils_fp_equ <sycl::ext::oneapi::bfloat16> {
91+ public:
92+ bool operator ()(sycl::ext::oneapi::bfloat16 x,
93+ sycl::ext::oneapi::bfloat16 y) {
94+ float xf = static_cast <float >(x);
95+ float yf = static_cast <float >(y);
96+ if ((__builtin_isinf_sign (xf) * __builtin_isinf_sign (yf)) == 1 )
97+ return true ;
98+ if (__builtin_isnan (xf) || __builtin_isnan (yf))
99+ return false ;
100+ // Simple check for 2 bf16
101+ const float relative_eps = 1e-3f ;
102+ return std::fabs (xf - yf) <
103+ relative_eps * std::fmax (std::fabs (xf), std::fabs (yf));
104+ }
105+ };
106+
54107// Used to test half precision utils
55108template <class InputTy , class OutputTy , class FuncTy ,
56109 class EquTy = imf_utils_default_equ<OutputTy>>
@@ -72,6 +125,42 @@ void test_host(std::initializer_list<InputTy> Input,
72125 }
73126}
74127
128+ template <class InputTy , class FuncTy , class EquTy = imf_utils_fp_equ<InputTy>>
129+ void test (sycl::queue &q, std::initializer_list<InputTy> Input, FuncTy Func,
130+ int Line = __builtin_LINE()) {
131+ auto Size = Input.size ();
132+ std::vector<InputTy> HostRef (Size);
133+ for (size_t Idx = 0 ; Idx < Size; ++Idx) {
134+ HostRef[Idx] = Func (*(std::begin (Input) + Idx));
135+ }
136+
137+ sycl::buffer<InputTy> InBuf (Size);
138+ {
139+ sycl::host_accessor InAcc (InBuf, sycl::write_only);
140+ int i = 0 ;
141+ for (auto x : Input)
142+ InAcc[i++] = x;
143+ }
144+
145+ sycl::buffer<InputTy> OutBuf (Size);
146+ q.submit ([&](sycl::handler &CGH) {
147+ sycl::accessor InAcc (InBuf, CGH, sycl::read_only);
148+ sycl::accessor OutAcc (OutBuf, CGH, sycl::write_only);
149+ CGH.parallel_for (Size,
150+ [=](sycl::id<1 > Id) { OutAcc[Id] = Func (InAcc[Id]); });
151+ }).wait ();
152+
153+ sycl::host_accessor Acc (OutBuf, sycl::read_only);
154+ for (size_t Idx = 0 ; Idx < Size; ++Idx) {
155+ if (EquTy ()(HostRef[Idx], Acc[Idx]))
156+ continue ;
157+ std::cout << " Mismatch at line " << Line << " [" << Idx << " ]: " << Acc[Idx]
158+ << " != " << HostRef[Idx] << " , input was "
159+ << *(std::begin (Input) + Idx) << std::endl;
160+ assert (false );
161+ }
162+ }
163+
75164template <class InputTy , class OutputTy , class FuncTy ,
76165 class EquTy = imf_utils_default_equ<OutputTy>>
77166void test (sycl::queue &q, std::initializer_list<InputTy> Input,
0 commit comments