@@ -116,10 +116,49 @@ broadcast_like(const V& input, const R& other) {
116
116
return broadcast (input, vector_extent_type<R> {});
117
117
}
118
118
119
+ /* *
120
+ * The accurate_policy is designed for computations where maximum accuracy is essential. This policy ensures that all
121
+ * operations are performed without any approximations or optimizations that could potentially alter the precise
122
+ * outcome of the computations
123
+ */
124
+ struct accurate_policy {};
125
+
126
+ /* *
127
+ * The fast_policy is intended for scenarios where performance and execution speed are more critical than achieving
128
+ * the utmost accuracy. This policy leverages optimizations to accelerate computations, which may involve
129
+ * approximations that slightly compromise precision.
130
+ */
131
+ struct fast_policy {};
132
+
133
+ /* *
134
+ * This template policy allows developers to specify a custom degree of approximation for their computations. By
135
+ * adjusting the `Level` parameter, you can fine-tune the balance between accuracy and performance to meet the
136
+ * specific needs of your application. Higher values mean more precision.
137
+ */
138
+ template <int Level = -1 >
139
+ struct approx_level_policy {};
140
+
141
+ /* *
142
+ * The approximate_policy serves as the default approximation policy, providing a standard level of approximation
143
+ * without requiring explicit configuration. It balances accuracy and performance, making it suitable for
144
+ * general-purpose use cases where neither extreme precision nor maximum speed is necessary.
145
+ */
146
+ using approx_policy = approx_level_policy<>;
147
+
148
+ #ifndef KERNEL_FLOAT_POLICY
149
+ #define KERNEL_FLOAT_POLICY accurate_policy;
150
+ #endif
151
+
152
+ /* *
153
+ * The `default_policy` acts as the standard computation policy. It can be configured externally using the
154
+ * `KERNEL_FLOAT_POLICY` macro. If `KERNEL_FLOAT_POLICY` is not defined, it defaults to `accurate_policy`.
155
+ */
156
+ using default_policy = KERNEL_FLOAT_POLICY;
157
+
119
158
namespace detail {
120
159
121
- template <typename F, size_t N, typename Output, typename ... Args>
122
- struct apply_impl {
160
+ template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
161
+ struct apply_base_impl {
123
162
KERNEL_FLOAT_INLINE static void call (F fun, Output* output, const Args*... args) {
124
163
#pragma unroll
125
164
for (size_t i = 0 ; i < N; i++) {
@@ -128,49 +167,31 @@ struct apply_impl {
128
167
}
129
168
};
130
169
131
- template <typename F, size_t N, typename Output, typename ... Args>
132
- struct apply_fastmath_impl : apply_impl<F, N, Output, Args...> {};
133
-
134
- template <int Deg, typename F, size_t N, typename Output, typename ... Args>
135
- struct apply_approx_impl : apply_fastmath_impl<F, N, Output, Args...> {};
136
- } // namespace detail
137
-
138
- struct accurate_policy {
139
- template <typename F, size_t N, typename Output, typename ... Args>
140
- using type = detail::apply_impl<F, N, Output, Args...>;
141
- };
142
-
143
- struct fast_policy {
144
- template <typename F, size_t N, typename Output, typename ... Args>
145
- using type = detail::apply_fastmath_impl<F, N, Output, Args...>;
146
- };
147
-
148
- template <int Degree = -1 >
149
- struct approximate_policy {
150
- template <typename F, size_t N, typename Output, typename ... Args>
151
- using type = detail::apply_approx_impl<Degree, F, N, Output, Args...>;
152
- };
170
+ template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
171
+ struct apply_impl : apply_base_impl<Policy, F, N, Output, Args...> {};
153
172
154
- using default_approximate_policy = approximate_policy<>;
173
+ template <typename F, size_t N, typename Output, typename ... Args>
174
+ struct apply_base_impl <fast_policy, F, N, Output, Args...>:
175
+ apply_impl<accurate_policy, F, N, Output, Args...> {};
155
176
156
- #ifdef KERNEL_FLOAT_POLICY
157
- using default_policy = KERNEL_FLOAT_POLICY;
158
- #else
159
- using default_policy = accurate_policy;
160
- #endif
177
+ template <typename F, size_t N, typename Output, typename ... Args>
178
+ struct apply_base_impl <approx_policy, F, N, Output, Args...>:
179
+ apply_impl<fast_policy, F, N, Output, Args...> {};
161
180
162
- namespace detail {
181
+ template <int Level, typename F, size_t N, typename Output, typename ... Args>
182
+ struct apply_base_impl <approx_level_policy<Level>, F, N, Output, Args...>:
183
+ apply_impl<approx_policy, F, N, Output, Args...> {};
163
184
164
185
template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
165
- struct map_policy_impl {
186
+ struct map_impl {
166
187
static constexpr size_t packet_size = preferred_vector_size<Output>::value;
167
188
static constexpr size_t remainder = N % packet_size;
168
189
169
190
KERNEL_FLOAT_INLINE static void call (F fun, Output* output, const Args*... args) {
170
191
if constexpr (N / packet_size > 0 ) {
171
192
#pragma unroll
172
193
for (size_t i = 0 ; i < N - remainder; i += packet_size) {
173
- Policy:: template type< F, packet_size, Output, Args...>::call (
194
+ apply_impl< Policy, F, packet_size, Output, Args...>::call (
174
195
fun,
175
196
output + i,
176
197
(args + i)...);
@@ -180,14 +201,14 @@ struct map_policy_impl {
180
201
if constexpr (remainder > 0 ) {
181
202
#pragma unroll
182
203
for (size_t i = N - remainder; i < N; i++) {
183
- Policy:: template type< F, 1 , Output, Args...>::call (fun, output + i, (args + i)...);
204
+ apply_impl< Policy, F, 1 , Output, Args...>::call (fun, output + i, (args + i)...);
184
205
}
185
206
}
186
207
}
187
208
};
188
209
189
210
template <typename F, size_t N, typename Output, typename ... Args>
190
- using map_impl = map_policy_impl <default_policy, F, N, Output, Args...>;
211
+ using default_map_impl = map_impl <default_policy, F, N, Output, Args...>;
191
212
192
213
} // namespace detail
193
214
@@ -211,7 +232,7 @@ KERNEL_FLOAT_INLINE map_type<F, Args...> map(F fun, const Args&... args) {
211
232
using E = broadcast_vector_extent_type<Args...>;
212
233
vector_storage<Output, extent_size<E>> result;
213
234
214
- detail::map_policy_impl <Accuracy, F, extent_size<E>, Output, vector_value_type<Args>...>::call (
235
+ detail::map_impl <Accuracy, F, extent_size<E>, Output, vector_value_type<Args>...>::call (
215
236
fun,
216
237
result.data (),
217
238
(detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call (
0 commit comments