@@ -2567,20 +2567,28 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
25672567 const dim_t work_amount = input_d.nelems ();
25682568
// extract_half_byte: returns one 4-bit half (nibble) of the byte `val`.
//   val       - byte that holds two packed 4-bit values.
//   high_half - true selects bits 7..4, false selects bits 3..0.
// The result is always in the range 0x00..0x0F (upper four bits are zero);
// no sign extension is performed here, so a caller that needs a signed
// 4-bit value has to sign-extend the returned nibble itself.
// Both the removed (`-`) and the added (`+`) variants below yield the same
// value: `val` is a uint8_t, so `val >> 4` already fits in four bits and the
// extra `& 0x000F` mask of the removed variant did not change the result.
25692569 auto extract_half_byte = [&](uint8_t val, bool high_half) -> uint8_t {
2570- uint8_t shift = high_half ? 4 : 0 ;
2571-
2572- return (uint8_t )((val >> shift) & 0x000F );
2570+ if (high_half) {
2571+ return (uint8_t )(val >> 4 );
2572+ }
2573+ return (uint8_t )(val & 0x0F );
25732574 };
25742575
25752576 parallel (0 , [&](const int ithr, const int nthr) {
25762577 dim_t start {0 }, end {0 };
25772578 balance211 (work_amount, nthr, ithr, start, end);
2579+ const auto * u8_input = reinterpret_cast <const uint8_t *>(input);
25782580 if (utils::one_of (type_i, dnnl_s4, dnnl_u4)) {
25792581 PRAGMA_OMP_SIMD ()
25802582 for (dim_t idx = start; idx < end; idx++) {
25812583 const auto i_off = input_d.off_l (idx);
25822584 const auto o_off = output_d.off_l (idx);
2583- const int8_t src_val = extract_half_byte (input[i_off / 2 ], i_off % 2 );
2585+ const uint8_t extracted = extract_half_byte (u8_input[i_off / 2 ], i_off % 2 );
2586+
2587+ int8_t src_val = extracted;
2588+ if (type_i == dnnl_s4) {
2589+ // Sign extension for s4: if bit 3 is set, extend with 1s
2590+ src_val = (extracted & 0x08 ) ? (extracted | 0xF0 ) : extracted;
2591+ }
25842592 output[o_off] = _qz_a1b0<dnnl_s8, type_o>()(src_val);
25852593 }
25862594 } else {
@@ -2605,7 +2613,7 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
26052613 for (dim_t idx = start; idx < end; idx++) {
26062614 const auto i_off = input_d.off_l (idx);
26072615 const auto o_off = output_d.off_l (idx);
2608- const uint8_t idx_val = extract_half_byte (input [i_off / 2 ], i_off % 2 );
2616+ const uint8_t idx_val = extract_half_byte (u8_input [i_off / 2 ], i_off % 2 );
26092617 output[o_off] = lookup[idx_val];
26102618 }
26112619 }
0 commit comments