@@ -188,23 +188,26 @@ fn vpdpbusd<'tcx>(
188188 let ( b, b_len) = ecx. project_to_simd ( b) ?;
189189 let ( dest, dest_len) = ecx. project_to_simd ( dest) ?;
190190
191- // fn vpdpbusd(src: i32x16, a: i32x16 , b: i32x16 ) -> i32x16;
192- // fn vpdpbusd256(src: i32x8, a: i32x8 , b: i32x8 ) -> i32x8;
193- // fn vpdpbusd128(src: i32x4, a: i32x4 , b: i32x4 ) -> i32x4;
191+ // fn vpdpbusd(src: i32x16, a: u8x64 , b: i8x64 ) -> i32x16;
192+ // fn vpdpbusd256(src: i32x8, a: u8x32 , b: i8x32 ) -> i32x8;
193+ // fn vpdpbusd128(src: i32x4, a: u8x16 , b: i8x16 ) -> i32x4;
194194 assert_eq ! ( dest_len, src_len) ;
195- assert_eq ! ( dest_len, a_len) ;
196- assert_eq ! ( dest_len , b_len) ;
195+ assert_eq ! ( dest_len * 4 , a_len) ;
196+ assert_eq ! ( a_len , b_len) ;
197197
198198 for i in 0 ..dest_len {
199199 let src = ecx. read_scalar ( & ecx. project_index ( & src, i) ?) ?. to_i32 ( ) ?;
200- let a = ecx. read_scalar ( & ecx. project_index ( & a, i) ?) ?. to_u32 ( ) ?;
201- let b = ecx. read_scalar ( & ecx. project_index ( & b, i) ?) ?. to_u32 ( ) ?;
202200 let dest = ecx. project_index ( & dest, i) ?;
203201
204- let zipped = a. to_le_bytes ( ) . into_iter ( ) . zip ( b. to_le_bytes ( ) ) ;
205- let intermediate_sum: i32 = zipped
206- . map ( |( a, b) | i32:: from ( a) . strict_mul ( i32:: from ( b. cast_signed ( ) ) ) )
207- . fold ( 0 , |x, y| x. strict_add ( y) ) ;
202+ let mut intermediate_sum: i32 = 0 ;
203+ for j in 0 ..4 {
204+ let idx = i. strict_mul ( 4 ) . strict_add ( j) ;
205+ let a = ecx. read_scalar ( & ecx. project_index ( & a, idx) ?) ?. to_u8 ( ) ?;
206+ let b = ecx. read_scalar ( & ecx. project_index ( & b, idx) ?) ?. to_i8 ( ) ?;
207+
208+ let product = i32:: from ( a) . strict_mul ( i32:: from ( b) ) ;
209+ intermediate_sum = intermediate_sum. strict_add ( product) ;
210+ }
208211
209212 // Use `wrapping_add` because `src` is an arbitrary i32 and the addition can overflow.
210213 let res = Scalar :: from_i32 ( intermediate_sum. wrapping_add ( src) ) ;
0 commit comments