Skip to content

Commit ae0e6b1

Browse files
committed
Simplify how policies are implemented internally
1 parent 4fca9cb commit ae0e6b1

File tree

11 files changed

+295
-208
lines changed

11 files changed

+295
-208
lines changed

docs/guides/accuracy.md

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ kf::vec<float, 4> c = kf::fast_rcp(x);
2525
kf::vec<float, 4> d = kf::fast_div(a, b);
2626
```
2727
28-
These functions are only functional for 32-bit and 16-bit floats.
28+
These functions are only functional for 32-bit and 16-bit floats.
2929
For other input types, the operation falls back to the regular version.
3030
3131
## Approximate Math
3232
33-
For 16-bit floats, several approximate functions are provided.
34-
These use approximations (typically low-degree polynomials) to calculate rough estimates of the functions.
33+
For 16-bit floats, several approximate functions are provided.
34+
These use approximations (typically low-degree polynomials) to calculate rough estimates of the functions.
3535
This can be very fast but also less accurate.
3636
3737
@@ -69,14 +69,15 @@ kf::vec<half, 4> a = kf::approx_sin<3>(x);
6969

7070
## Tuning Accuracy Level
7171

72-
Many functions in Kernel Float accept an additional Accuracy option as a template parameter.
72+
Many functions in Kernel Float accept an additional `Accuracy` option as a template parameter.
7373
This allows you to tune the accuracy level without changing the function name.
7474

75-
There are four possible values for this parameter:
75+
There are five possible values for this parameter:
7676

7777
- `kf::accurate_policy`: Use the most accurate version of the function available.
7878
- `kf::fast_policy`: Use the "fast math" version.
79-
- `kf::approx_policy<N>`: Use the approximate version with degree `N`.
79+
- `kf::approx_level_policy<N>`: Use the approximate version with accuracy level `N` (higher is more accurate).
80+
- `kf::approx_policy`: Use the approximate version with a default accuracy level.
8081
- `kf::default_policy`: Use a global default policy (see the next section).
8182

8283
For example, consider this code:
@@ -97,15 +98,19 @@ kf::vec<float, 2> c = kf::cos<kf::accurate_policy>(input);
9798
kf::vec<float, 2> d = kf::cos<kf::fast_policy>(input);
9899

99100
// Use the approximate policy
100-
kf::vec<float, 2> e = kf::cos<kf::approx_policy<3>>(input);
101+
kf::vec<float, 2> e = kf::cos<kf::approx_policy>(input);
102+
103+
// Use the approximate policy with accuracy level 3 (higher is more accurate).
104+
kf::vec<float, 2> f = kf::cos<kf::approx_level_policy<3>>(input);
101105

102106
// You can use aliases to define your own policy
103107
using my_own_policy = kf::fast_policy;
104-
kf::vec<float, 2> f = kf::cos<my_own_policy>(input);
108+
kf::vec<float, 2> g = kf::cos<my_own_policy>(input);
105109
```
106110
107111
## Setting `default_policy`
108112
113+
If no policy is explicitly specified, functions use `kf::default_policy`.
109114
By default, `kf::default_policy` is set to `kf::accurate_policy`.
110115
111116
Set the preprocessor option `KERNEL_FLOAT_FAST_MATH=1` to change the default policy to `kf::fast_policy`.

include/kernel_float/apply.h

Lines changed: 57 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,49 @@ broadcast_like(const V& input, const R& other) {
116116
return broadcast(input, vector_extent_type<R> {});
117117
}
118118

119+
/**
120+
* The accurate_policy is designed for computations where maximum accuracy is essential. This policy ensures that all
121+
* operations are performed without any approximations or optimizations that could potentially alter the precise
122+
* outcome of the computations.
123+
*/
124+
struct accurate_policy {};
125+
126+
/**
127+
* The fast_policy is intended for scenarios where performance and execution speed are more critical than achieving
128+
* the utmost accuracy. This policy leverages optimizations to accelerate computations, which may involve
129+
* approximations that slightly compromise precision.
130+
*/
131+
struct fast_policy {};
132+
133+
/**
134+
* This template policy allows developers to specify a custom degree of approximation for their computations. By
135+
* adjusting the `Level` parameter, you can fine-tune the balance between accuracy and performance to meet the
136+
* specific needs of your application. Higher values mean more precision.
137+
*/
138+
template<int Level = -1>
139+
struct approx_level_policy {};
140+
141+
/**
142+
* The approx_policy serves as the default approximation policy, providing a standard level of approximation
143+
* without requiring explicit configuration. It balances accuracy and performance, making it suitable for
144+
* general-purpose use cases where neither extreme precision nor maximum speed is necessary.
145+
*/
146+
using approx_policy = approx_level_policy<>;
147+
148+
#ifndef KERNEL_FLOAT_POLICY
149+
// No trailing ';' here: the macro must expand to a bare type name, because it is
// consumed as `using default_policy = KERNEL_FLOAT_POLICY;`.
#define KERNEL_FLOAT_POLICY accurate_policy
150+
#endif
151+
152+
/**
153+
* The `default_policy` acts as the standard computation policy. It can be configured externally using the
154+
* `KERNEL_FLOAT_POLICY` macro. If `KERNEL_FLOAT_POLICY` is not defined, it defaults to `accurate_policy`.
155+
*/
156+
using default_policy = KERNEL_FLOAT_POLICY;
157+
119158
namespace detail {
120159

121-
template<typename F, size_t N, typename Output, typename... Args>
122-
struct apply_impl {
160+
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
161+
struct apply_base_impl {
123162
KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
124163
#pragma unroll
125164
for (size_t i = 0; i < N; i++) {
@@ -128,49 +167,31 @@ struct apply_impl {
128167
}
129168
};
130169

131-
template<typename F, size_t N, typename Output, typename... Args>
132-
struct apply_fastmath_impl: apply_impl<F, N, Output, Args...> {};
133-
134-
template<int Deg, typename F, size_t N, typename Output, typename... Args>
135-
struct apply_approx_impl: apply_fastmath_impl<F, N, Output, Args...> {};
136-
} // namespace detail
137-
138-
struct accurate_policy {
139-
template<typename F, size_t N, typename Output, typename... Args>
140-
using type = detail::apply_impl<F, N, Output, Args...>;
141-
};
142-
143-
struct fast_policy {
144-
template<typename F, size_t N, typename Output, typename... Args>
145-
using type = detail::apply_fastmath_impl<F, N, Output, Args...>;
146-
};
147-
148-
template<int Degree = -1>
149-
struct approximate_policy {
150-
template<typename F, size_t N, typename Output, typename... Args>
151-
using type = detail::apply_approx_impl<Degree, F, N, Output, Args...>;
152-
};
170+
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
171+
struct apply_impl: apply_base_impl<Policy, F, N, Output, Args...> {};
153172

154-
using default_approximate_policy = approximate_policy<>;
173+
template<typename F, size_t N, typename Output, typename... Args>
174+
struct apply_base_impl<fast_policy, F, N, Output, Args...>:
175+
apply_impl<accurate_policy, F, N, Output, Args...> {};
155176

156-
#ifdef KERNEL_FLOAT_POLICY
157-
using default_policy = KERNEL_FLOAT_POLICY;
158-
#else
159-
using default_policy = accurate_policy;
160-
#endif
177+
template<typename F, size_t N, typename Output, typename... Args>
178+
struct apply_base_impl<approx_policy, F, N, Output, Args...>:
179+
apply_impl<fast_policy, F, N, Output, Args...> {};
161180

162-
namespace detail {
181+
template<int Level, typename F, size_t N, typename Output, typename... Args>
182+
struct apply_base_impl<approx_level_policy<Level>, F, N, Output, Args...>:
183+
apply_impl<approx_policy, F, N, Output, Args...> {};
163184

164185
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
165-
struct map_policy_impl {
186+
struct map_impl {
166187
static constexpr size_t packet_size = preferred_vector_size<Output>::value;
167188
static constexpr size_t remainder = N % packet_size;
168189

169190
KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
170191
if constexpr (N / packet_size > 0) {
171192
#pragma unroll
172193
for (size_t i = 0; i < N - remainder; i += packet_size) {
173-
Policy::template type<F, packet_size, Output, Args...>::call(
194+
apply_impl<Policy, F, packet_size, Output, Args...>::call(
174195
fun,
175196
output + i,
176197
(args + i)...);
@@ -180,14 +201,14 @@ struct map_policy_impl {
180201
if constexpr (remainder > 0) {
181202
#pragma unroll
182203
for (size_t i = N - remainder; i < N; i++) {
183-
Policy::template type<F, 1, Output, Args...>::call(fun, output + i, (args + i)...);
204+
apply_impl<Policy, F, 1, Output, Args...>::call(fun, output + i, (args + i)...);
184205
}
185206
}
186207
}
187208
};
188209

189210
template<typename F, size_t N, typename Output, typename... Args>
190-
using map_impl = map_policy_impl<default_policy, F, N, Output, Args...>;
211+
using default_map_impl = map_impl<default_policy, F, N, Output, Args...>;
191212

192213
} // namespace detail
193214

@@ -211,7 +232,7 @@ KERNEL_FLOAT_INLINE map_type<F, Args...> map(F fun, const Args&... args) {
211232
using E = broadcast_vector_extent_type<Args...>;
212233
vector_storage<Output, extent_size<E>> result;
213234

214-
detail::map_policy_impl<Accuracy, F, extent_size<E>, Output, vector_value_type<Args>...>::call(
235+
detail::map_impl<Accuracy, F, extent_size<E>, Output, vector_value_type<Args>...>::call(
215236
fun,
216237
result.data(),
217238
(detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call(

include/kernel_float/approx.h

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -359,25 +359,25 @@ KERNEL_FLOAT_DEVICE __bfloat162 exp(__bfloat162 arg) {
359359
#endif
360360
} // namespace approx
361361

362-
#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \
363-
namespace detail { \
364-
template<int Degree> \
365-
struct apply_approx_impl<Deg, ops::FUN<__half>, 2, __half, __half> { \
366-
KERNEL_FLOAT_INLINE static void \
367-
call(ops::FUN<__half> fun, __half* output, const __half* input) { \
368-
__half2 res = approx::FUN<Degree>(__half2 {input[0], input[1]}); \
369-
output[0] = res.x; \
370-
output[1] = res.y; \
371-
} \
372-
}; \
373-
template<> \
374-
struct apply_approx_impl<-1, ops::FUN<__half>, 2, __half, __half>: \
375-
apply_approx_impl<DEG, ops::FUN<__half>, 2, __half, __half> {}; \
376-
} \
377-
\
378-
template<typename V> \
379-
KERNEL_FLOAT_INLINE into_vector_type<V> approx_##FUN(const V& args) { \
380-
return map<approximate_policy<>>(ops::FUN<vector_value_type<V>> {}, args); \
362+
#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \
363+
namespace detail { \
364+
template<int Degree> \
365+
struct apply_impl<approx_level_policy<Degree>, ops::FUN<__half>, 2, __half, __half> { \
366+
KERNEL_FLOAT_INLINE static void \
367+
call(ops::FUN<__half> fun, __half* output, const __half* input) { \
368+
__half2 res = approx::FUN<Degree>(__half2 {input[0], input[1]}); \
369+
output[0] = res.x; \
370+
output[1] = res.y; \
371+
} \
372+
}; \
373+
template<> \
374+
struct apply_impl<approx_policy, ops::FUN<__half>, 2, __half, __half>: \
375+
apply_impl<approx_level_policy<DEG>, ops::FUN<__half>, 2, __half, __half> {}; \
376+
} \
377+
\
378+
template<int Level = -1, typename V> \
379+
KERNEL_FLOAT_INLINE into_vector_type<V> approx_##FUN(const V& args) { \
380+
return map<approx_level_policy<Level>>(ops::FUN<vector_value_type<V>> {}, args); \
381381
}
382382

383383
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sin, sin, 4)

include/kernel_float/bf16.h

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -61,24 +61,24 @@ struct allow_float_fallback<__bfloat16> {
6161
}; // namespace detail
6262

6363
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
64-
#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \
65-
namespace ops { \
66-
template<> \
67-
struct NAME<__bfloat16> { \
68-
KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 input) { \
69-
return FUN1(input); \
70-
} \
71-
}; \
72-
} \
73-
namespace detail { \
74-
template<> \
75-
struct apply_impl<ops::NAME<__bfloat16>, 2, __bfloat16, __bfloat16> { \
76-
KERNEL_FLOAT_INLINE static void \
77-
call(ops::NAME<__bfloat16>, __bfloat16* result, const __bfloat16* a) { \
78-
__bfloat162 r = FUN2(__bfloat162 {a[0], a[1]}); \
79-
result[0] = r.x, result[1] = r.y; \
80-
} \
81-
}; \
64+
#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \
65+
namespace ops { \
66+
template<> \
67+
struct NAME<__bfloat16> { \
68+
KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 input) { \
69+
return FUN1(input); \
70+
} \
71+
}; \
72+
} \
73+
namespace detail { \
74+
template<> \
75+
struct apply_impl<accurate_policy, ops::NAME<__bfloat16>, 2, __bfloat16, __bfloat16> { \
76+
KERNEL_FLOAT_INLINE static void \
77+
call(ops::NAME<__bfloat16>, __bfloat16* result, const __bfloat16* a) { \
78+
__bfloat162 r = FUN2(__bfloat162 {a[0], a[1]}); \
79+
result[0] = r.x, result[1] = r.y; \
80+
} \
81+
}; \
8282
}
8383

8484
KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin)
@@ -115,7 +115,13 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
115115
} \
116116
namespace detail { \
117117
template<> \
118-
struct apply_impl<ops::NAME<__bfloat16>, 2, __bfloat16, __bfloat16, __bfloat16> { \
118+
struct apply_impl< \
119+
accurate_policy, \
120+
ops::NAME<__bfloat16>, \
121+
2, \
122+
__bfloat16, \
123+
__bfloat16, \
124+
__bfloat16> { \
119125
KERNEL_FLOAT_INLINE static void call( \
120126
ops::NAME<__bfloat16>, \
121127
__bfloat16* result, \
@@ -154,7 +160,14 @@ struct fma<__bfloat16> {
154160

155161
namespace detail {
156162
template<>
157-
struct apply_impl<ops::fma<__bfloat16>, 2, __bfloat16, __bfloat16, __bfloat16, __bfloat16> {
163+
struct apply_impl<
164+
accurate_policy,
165+
ops::fma<__bfloat16>,
166+
2,
167+
__bfloat16,
168+
__bfloat16,
169+
__bfloat16,
170+
__bfloat16> {
158171
KERNEL_FLOAT_INLINE static void call(
159172
ops::fma<__bfloat16>,
160173
__bfloat16* result,

include/kernel_float/binops.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
5252

5353
vector_storage<O, extent_size<E>> result;
5454

55-
detail::map_impl<F, extent_size<E>, O, T, T>::call(
55+
detail::default_map_impl<F, extent_size<E>, O, T, T>::call(
5656
fun,
5757
result.data(),
5858
detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call(
@@ -290,21 +290,25 @@ struct multiply<bool> {
290290
}; // namespace ops
291291

292292
namespace detail {
293-
template<typename T, size_t N>
294-
struct apply_fastmath_impl<ops::divide<T>, N, T, T, T> {
293+
template<typename Policy, typename T, size_t N>
294+
struct apply_impl<Policy, ops::divide<T>, N, T, T, T> {
295295
KERNEL_FLOAT_INLINE static void
296296
call(ops::divide<T> fun, T* result, const T* lhs, const T* rhs) {
297297
T rhs_rcp[N];
298298

299299
// Fast way to perform division is to multiply by the reciprocal
300-
apply_fastmath_impl<ops::rcp<T>, N, T, T>::call({}, rhs_rcp, rhs);
301-
apply_fastmath_impl<ops::multiply<T>, N, T, T, T>::call({}, result, lhs, rhs_rcp);
300+
apply_impl<Policy, ops::rcp<T>, N, T, T>::call({}, rhs_rcp, rhs);
301+
apply_impl<Policy, ops::multiply<T>, N, T, T, T>::call({}, result, lhs, rhs_rcp);
302302
}
303303
};
304304

305+
template<typename T, size_t N>
306+
struct apply_impl<accurate_policy, ops::divide<T>, N, T, T, T>:
307+
apply_base_impl<accurate_policy, ops::divide<T>, N, T, T, T> {};
308+
305309
#if KERNEL_FLOAT_IS_DEVICE
306310
template<>
307-
struct apply_fastmath_impl<ops::divide<float>, 1, float, float, float> {
311+
struct apply_impl<fast_policy, ops::divide<float>, 1, float, float, float> {
308312
KERNEL_FLOAT_INLINE static void
309313
call(ops::divide<float> fun, float* result, const float* lhs, const float* rhs) {
310314
*result = __fdividef(*lhs, *rhs);
@@ -319,7 +323,7 @@ fast_divide(const L& left, const R& right) {
319323
using E = broadcast_vector_extent_type<L, R>;
320324
vector_storage<T, extent_size<E>> result;
321325

322-
detail::map_policy_impl<fast_policy, ops::divide<T>, extent_size<E>, T, T, T>::call(
326+
detail::map_impl<fast_policy, ops::divide<T>, extent_size<E>, T, T, T>::call(
323327
ops::divide<T> {},
324328
result.data(),
325329
detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call(

0 commit comments

Comments
 (0)