Skip to content

Commit a4ed4a7

Browse files
committed
[FORK][FIX] Fix int4 simple reorder
1 parent c43dd46 commit a4ed4a7

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

src/cpu/reorder/simple_reorder.hpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2567,20 +2567,28 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
25672567
const dim_t work_amount = input_d.nelems();
25682568

25692569
auto extract_half_byte = [&](uint8_t val, bool high_half) -> uint8_t {
2570-
uint8_t shift = high_half ? 4 : 0;
2571-
2572-
return (uint8_t)((val >> shift) & 0x000F);
2570+
if (high_half) {
2571+
return (uint8_t)(val >> 4);
2572+
}
2573+
return (uint8_t)(val & 0x0F);
25732574
};
25742575

25752576
parallel(0, [&](const int ithr, const int nthr) {
25762577
dim_t start {0}, end {0};
25772578
balance211(work_amount, nthr, ithr, start, end);
2579+
const auto* u8_input = reinterpret_cast<const uint8_t *>(input);
25782580
if (utils::one_of(type_i, dnnl_s4, dnnl_u4)) {
25792581
PRAGMA_OMP_SIMD()
25802582
for (dim_t idx = start; idx < end; idx++) {
25812583
const auto i_off = input_d.off_l(idx);
25822584
const auto o_off = output_d.off_l(idx);
2583-
const int8_t src_val = extract_half_byte(input[i_off / 2], i_off % 2);
2585+
const uint8_t extracted = extract_half_byte(u8_input[i_off / 2], i_off % 2);
2586+
2587+
int8_t src_val = extracted;
2588+
if (type_i == dnnl_s4) {
2589+
// Sign extension for s4: if bit 3 is set, extend with 1s
2590+
src_val = (extracted & 0x08) ? (extracted | 0xF0) : extracted;
2591+
}
25842592
output[o_off] = _qz_a1b0<dnnl_s8, type_o>()(src_val);
25852593
}
25862594
} else {
@@ -2605,7 +2613,7 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
26052613
for (dim_t idx = start; idx < end; idx++) {
26062614
const auto i_off = input_d.off_l(idx);
26072615
const auto o_off = output_d.off_l(idx);
2608-
const uint8_t idx_val = extract_half_byte(input[i_off / 2], i_off % 2);
2616+
const uint8_t idx_val = extract_half_byte(u8_input[i_off / 2], i_off % 2);
26092617
output[o_off] = lookup[idx_val];
26102618
}
26112619
}

0 commit comments

Comments
 (0)