@@ -4,9 +4,15 @@ use crate::helpers::{
4
4
} ;
5
5
use crate :: traits:: Kernel ;
6
6
use crate :: types:: EvalType ;
7
+ use crate :: RealScalar ;
8
+ use crate :: { ComplexScalar , SimdFor } ;
7
9
use num:: traits:: FloatConst ;
10
+ use num:: One ;
8
11
use num:: Zero ;
12
+ use pulp:: Simd ;
9
13
use rayon:: prelude:: * ;
14
+ use rlst:: c32;
15
+ use rlst:: c64;
10
16
use rlst:: RlstScalar ;
11
17
use std:: marker:: PhantomData ;
12
18
@@ -251,8 +257,8 @@ where
251
257
/// Evaluate Helmholtz kernel for one target
252
258
pub fn evaluate_helmholtz_one_target < T : RlstScalar < Complex = T > > (
253
259
eval_type : EvalType ,
254
- target : & [ < T as RlstScalar > :: Real ] ,
255
- sources : & [ < T as RlstScalar > :: Real ] ,
260
+ target : & [ T :: Real ] ,
261
+ sources : & [ T :: Real ] ,
256
262
charges : & [ T ] ,
257
263
wavenumber : T :: Real ,
258
264
result : & mut [ T ] ,
@@ -273,33 +279,159 @@ pub fn evaluate_helmholtz_one_target<T: RlstScalar<Complex = T>>(
273
279
274
280
match eval_type {
275
281
EvalType :: Value => {
276
- let mut my_result_real = <<T as RlstScalar >:: Real as Zero >:: zero ( ) ;
277
- let mut my_result_imag = <<T as RlstScalar >:: Real as Zero >:: zero ( ) ;
278
- for index in 0 ..nsources {
279
- diff0 = sources0[ index] - target[ 0 ] ;
280
- diff1 = sources1[ index] - target[ 1 ] ;
281
- diff2 = sources2[ index] - target[ 2 ] ;
282
- let diff_norm = ( diff0 * diff0 + diff1 * diff1 + diff2 * diff2) . sqrt ( ) ;
283
- let inv_diff_norm = {
284
- if diff_norm == zero_real {
285
- zero_real
286
- } else {
287
- one_real / diff_norm
282
// Inputs for the SIMD evaluation of the EvalType::Value branch. An instance is
// passed to pulp::Arch::dispatch together with its WithSimd implementation.
+ struct Impl < ' a , T : RlstScalar > {
283
// Wavenumber; it is multiplied with the source-target distance in the kernel.
+ wavenumber : T :: Real ,
284
// Coordinates of the single target point (filled from target[0..3] by the caller).
+ t0 : T :: Real ,
285
+ t1 : T :: Real ,
286
+ t2 : T :: Real ,
287
+
288
// Source coordinates, one slice per coordinate axis.
+ sources0 : & ' a [ T :: Real ] ,
289
+ sources1 : & ' a [ T :: Real ] ,
290
+ sources2 : & ' a [ T :: Real ] ,
291
// Charges, iterated in lockstep with the source slices.
+ charges : & ' a [ T ] ,
292
+ }
293
+
294
// SIMD implementation of the Helmholtz value sum for one target.
// The result is the pair (real part, imaginary part) of
//   sum_j exp(i * k * r_j) / r_j * charge_j,   r_j = |source_j - target|,
// with the 1 / r_j factor replaced by zero when r_j is exactly zero.
// No 1 / (4 pi) scaling is applied inside this block.
+ impl < T : ComplexScalar > pulp:: WithSimd for Impl < ' _ , T > {
295
// (real part, imaginary part) of the accumulated sum.
+ type Output = ( T :: Real , T :: Real ) ;
296
+
297
+ #[ inline( always) ]
298
+ fn with_simd < S : pulp:: Simd > ( self , simd : S ) -> Self :: Output {
299
+ use coe:: Coerce ;
300
+
301
+ let Self {
302
+ wavenumber,
303
+ t0,
304
+ t1,
305
+ t2,
306
+ sources0,
307
+ sources1,
308
+ sources2,
309
+ charges,
310
+ } = self ;
311
+
312
// Split every coordinate slice into a head of SIMD-register-sized elements
// and a tail of leftover scalars.
// NOTE(review): the exact contract of as_simd_slice is defined outside this view - confirm.
+ let ( s0_head, s0_tail) = T :: Real :: as_simd_slice :: < S > ( sources0) ;
313
+ let ( s1_head, s1_tail) = T :: Real :: as_simd_slice :: < S > ( sources1) ;
314
+ let ( s2_head, s2_tail) = T :: Real :: as_simd_slice :: < S > ( sources2) ;
315
+
316
// len: number of SIMD elements in the head.
// n: number of real scalars that fit into one SIMD element.
+ let len = s0_head. len ( ) ;
317
+ let n = std:: mem:: size_of :: < <T :: Real as RealScalar >:: Scalars < S > > ( )
318
+ / std:: mem:: size_of :: < T :: Real > ( ) ;
319
// The first len * n charges belong to the sources in the head. They are
// reinterpreted as pairs of SIMD elements; the remaining charges are
// reinterpreted as pairs of real scalars.
// NOTE(review): presumably each charge is stored as interleaved (re, im) reals,
// which is what the deinterleave call below relies on - confirm.
+ let ( c_head, c_tail) = charges. split_at ( len * n) ;
320
+ let c_head: & [ [ <T :: Real as RealScalar >:: Scalars < S > ; 2 ] ] =
321
+ bytemuck:: cast_slice ( c_head) ;
322
+ let c_tail: & [ [ T :: Real ; 2 ] ] = bytemuck:: cast_slice ( c_tail) ;
323
+
324
// Kernel loop shared by the SIMD head and the scalar tail. It is generic
// over the Simd type S, so the tail reuses it with pulp::Scalar.
// Returns the lane-reduced (real, imaginary) partial sums.
+ #[ inline( always) ]
325
+ fn impl_slice < T : ComplexScalar , S : Simd > (
326
+ simd : S ,
327
+ wavenumber : T :: Real ,
328
+ t0 : T :: Real ,
329
+ t1 : T :: Real ,
330
+ t2 : T :: Real ,
331
+
332
+ sources0 : & [ <T :: Real as RealScalar >:: Scalars < S > ] ,
333
+ sources1 : & [ <T :: Real as RealScalar >:: Scalars < S > ] ,
334
+ sources2 : & [ <T :: Real as RealScalar >:: Scalars < S > ] ,
335
+ charges : & [ [ <T :: Real as RealScalar >:: Scalars < S > ; 2 ] ] ,
336
+ ) -> ( T :: Real , T :: Real ) {
337
+ let simd = SimdFor :: < T :: Real , S > :: new ( simd) ;
338
+
339
// Broadcast the per-target constants to all lanes and zero the accumulators.
+ let t0 = simd. splat ( t0) ;
340
+ let t1 = simd. splat ( t1) ;
341
+ let t2 = simd. splat ( t2) ;
342
+ let zero = simd. splat ( T :: Real :: zero ( ) ) ;
343
+ let wavenumber = simd. splat ( wavenumber) ;
344
+ let mut acc_re = simd. splat ( T :: Real :: zero ( ) ) ;
345
+ let mut acc_im = simd. splat ( T :: Real :: zero ( ) ) ;
346
+
347
+ for ( & s0, & s1, & s2, & c) in
348
+ itertools:: izip!( sources0, sources1, sources2, charges)
349
+ {
350
// Separate the charge data into real parts and imaginary parts.
+ let [ c_re, c_im] = simd. deinterleave ( c) ;
351
+
352
+ let diff0 = simd. sub ( s0, t0) ;
353
+ let diff1 = simd. sub ( s1, t1) ;
354
+ let diff2 = simd. sub ( s2, t2) ;
355
+
356
// Euclidean distance between source and target.
+ let diff_norm = simd. sqrt ( simd. mul_add (
357
+ diff0,
358
+ diff0,
359
+ simd. mul_add ( diff1, diff1, simd. mul ( diff2, diff2) ) ,
360
+ ) ) ;
361
+
362
// Lanes with zero distance get an inverse distance of zero instead of
// the result of the division, so a source that coincides with the
// target contributes nothing.
+ let is_zero = simd. cmp_eq ( diff_norm, zero) ;
363
+ let inv_diff_norm = simd. select (
364
+ is_zero,
365
+ zero,
366
+ simd. div ( simd. splat ( T :: Real :: one ( ) ) , diff_norm) ,
367
+ ) ;
368
+ let kr = simd. mul ( wavenumber, diff_norm) ;
369
+
370
// g = (cos(kr) + i sin(kr)) / r, split into real and imaginary parts.
+ let ( g_re, g_im) = {
371
+ let ( s, c) = simd. sin_cos ( kr) ;
372
+ ( simd. mul ( c, inv_diff_norm) , simd. mul ( s, inv_diff_norm) )
373
+ } ;
374
+
375
// Complex multiply-accumulate acc += g * charge:
//   re += g_re * c_re - g_im * c_im
//   im += g_re * c_im + g_im * c_re
+ acc_re = simd. mul_add (
376
+ g_re,
377
+ c_re,
378
+ simd. mul_add ( simd. neg ( g_im) , c_im, acc_re) ,
379
+ ) ;
380
+ acc_im = simd. mul_add ( g_re, c_im, simd. mul_add ( g_im, c_re, acc_im) ) ;
381
+ }
382
// Sum the lanes of each accumulator into a single scalar.
+ ( simd. reduce_add ( acc_re) , simd. reduce_add ( acc_im) )
288
383
}
289
- } ;
290
384
291
- let kr = wavenumber * diff_norm;
385
// Partial sum over the SIMD head.
+ let ( re0, im0) = impl_slice :: < T , S > (
386
+ simd, wavenumber, t0, t1, t2, s0_head, s1_head, s2_head, c_head,
387
+ ) ;
// Partial sum over the scalar tail, using the same loop with pulp::Scalar.
388
+ let ( re1, im1) = impl_slice :: < T , pulp:: Scalar > (
389
+ pulp:: Scalar :: new ( ) ,
390
+ wavenumber,
391
+ t0,
392
+ t1,
393
+ t2,
394
+ s0_tail. coerce ( ) ,
395
+ s1_tail. coerce ( ) ,
396
+ s2_tail. coerce ( ) ,
397
+ c_tail. coerce ( ) ,
398
+ ) ;
292
399
293
- let g_re = <T :: Real as RlstScalar >:: cos ( kr) * inv_diff_norm;
294
- let g_im = <T :: Real as RlstScalar >:: sin ( kr) * inv_diff_norm;
295
- let charge_re = charges[ index] . re ( ) ;
296
- let charge_im = charges[ index] . im ( ) ;
400
// Combine head and tail contributions.
+ ( re0 + re1, im0 + im1)
401
+ }
402
+ }
297
403
298
- my_result_imag += g_re * charge_im + g_im * charge_re;
299
- my_result_real += g_re * charge_re - g_im * charge_im;
404
+ use coe:: coerce_static as to;
405
+ use coe:: Coerce ;
406
+ if coe:: is_same :: < T , c32 > ( ) {
407
+ let ( re, im) = pulp:: Arch :: new ( ) . dispatch ( Impl :: < ' _ , c32 > {
408
+ wavenumber : to ( wavenumber) ,
409
+ t0 : to ( target[ 0 ] ) ,
410
+ t1 : to ( target[ 1 ] ) ,
411
+ t2 : to ( target[ 2 ] ) ,
412
+ sources0 : sources0. coerce ( ) ,
413
+ sources1 : sources1. coerce ( ) ,
414
+ sources2 : sources2. coerce ( ) ,
415
+ charges : charges. coerce ( ) ,
416
+ } ) ;
417
+ result[ 0 ] += T :: Complex :: complex ( to :: < _ , T :: Real > ( re) , to :: < _ , T :: Real > ( im) )
418
+ . mul_real ( m_inv_4pi) ;
419
+ } else if coe:: is_same :: < T , c64 > ( ) {
420
+ let ( re, im) = pulp:: Arch :: new ( ) . dispatch ( Impl :: < ' _ , c64 > {
421
+ wavenumber : to ( wavenumber) ,
422
+ t0 : to ( target[ 0 ] ) ,
423
+ t1 : to ( target[ 1 ] ) ,
424
+ t2 : to ( target[ 2 ] ) ,
425
+ sources0 : sources0. coerce ( ) ,
426
+ sources1 : sources1. coerce ( ) ,
427
+ sources2 : sources2. coerce ( ) ,
428
+ charges : charges. coerce ( ) ,
429
+ } ) ;
430
+ result[ 0 ] += T :: Complex :: complex ( to :: < _ , T :: Real > ( re) , to :: < _ , T :: Real > ( im) )
431
+ . mul_real ( m_inv_4pi) ;
432
+ } else {
433
+ panic ! ( )
300
434
}
301
- result[ 0 ] += <T :: Complex as RlstScalar >:: complex ( my_result_real, my_result_imag)
302
- . mul_real ( m_inv_4pi) ;
303
435
}
304
436
EvalType :: ValueDeriv => {
305
437
// Cannot simply use an array my_result as this is not
0 commit comments