Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 96d2966

Browse files
committed
s4=>i4
1 parent 696820f commit 96d2966

File tree

8 files changed

+44
-43
lines changed

8 files changed

+44
-43
lines changed

include/common/core/common_types.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };
2828
enum class mem_layout : uint8_t { row_major = 0, col_major = 1 };
2929

3030
enum class quant_mode : uint8_t {
31-
S4_ASYM = 0,
32-
S4_FULLRANGE_NO_ZP = 1,
33-
INT4_ASYM_FP_ZERO = 2
31+
I4_ASYM = 0,
32+
I4_FULLRANGE_NO_ZP = 1,
33+
I4_ASYM_FP_ZERO = 2
3434
};
3535

3636
struct quant_info {

include/experimental/group/gemm/impl/int4_dequantize_xe.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ class gemm_t<
102102
std::is_same<remove_const_t<dtype_b>, remove_const_t<int4x8>>::value,
103103
"this is for 4bit matB ");
104104
static_assert(
105-
quant_info_.quant_mode == quant_mode::INT4_ASYM_FP_ZERO
105+
quant_info_.quant_mode == quant_mode::I4_ASYM_FP_ZERO
106106
? std::is_same_v<
107107
remove_const_t<dtype_zero_pt>,
108108
remove_const_t<dtype_a>>
@@ -291,7 +291,7 @@ class gemm_t<
291291

292292
// compress int4 along N dimensions
293293
using zero_pt_tile_desc_t = std::conditional_t<
294-
quant_info_.quant_mode != quant_mode::INT4_ASYM_FP_ZERO,
294+
quant_info_.quant_mode != quant_mode::I4_ASYM_FP_ZERO,
295295
subgroup::tile_desc_t<
296296
(tile_size_x_b + pack_ratio - 1) / pack_ratio,
297297
tile_size_y_zero_pt,
@@ -535,7 +535,7 @@ class gemm_t<
535535
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
536536
scale_prefetch_payload);
537537
if constexpr (
538-
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
538+
compute_policy::quant_mode != quant_mode::I4_FULLRANGE_NO_ZP) {
539539
// TODO 1D prefetch need pack to U32/U64
540540
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
541541
zero_pt_prefetch_payload);
@@ -549,7 +549,7 @@ class gemm_t<
549549
scale_prefetch_payload.template update_tdesc<update_dir_b>(
550550
scale_t::tile_size_y);
551551
if constexpr (
552-
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
552+
compute_policy::quant_mode != quant_mode::I4_FULLRANGE_NO_ZP) {
553553
zero_pt_prefetch_payload
554554
.template update_tdesc<tdesc_update_dir::y_dir>(
555555
zero_pt_t::tile_size_y);
@@ -579,7 +579,7 @@ class gemm_t<
579579
subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
580580
scale, scale_payload);
581581
if constexpr (
582-
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
582+
compute_policy::quant_mode != quant_mode::I4_FULLRANGE_NO_ZP) {
583583
subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
584584
zero_pt, zero_pt_payload);
585585
}
@@ -594,7 +594,7 @@ class gemm_t<
594594
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
595595
scale_prefetch_payload);
596596
if constexpr (
597-
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
597+
compute_policy::quant_mode != quant_mode::I4_FULLRANGE_NO_ZP) {
598598
// TODO 1D prefetch need pack to U32/U64
599599
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
600600
zero_pt_prefetch_payload);
@@ -608,7 +608,7 @@ class gemm_t<
608608
scale_payload.template update_tdesc<update_dir_b>(scale_t::tile_size_y);
609609
}
610610
if constexpr (
611-
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
611+
compute_policy::quant_mode != quant_mode::I4_FULLRANGE_NO_ZP) {
612612
if (tile_k_idx % zero_pt_addr_update_freq == 0) {
613613
zero_pt_payload.template update_tdesc<tdesc_update_dir::y_dir>(
614614
zero_pt_t::tile_size_y);
@@ -623,7 +623,7 @@ class gemm_t<
623623
scale_prefetch_payload.template update_tdesc<tdesc_update_dir::y_dir>(
624624
scale_t::tile_size_y);
625625
if constexpr (
626-
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
626+
compute_policy::quant_mode != quant_mode::I4_FULLRANGE_NO_ZP) {
627627
zero_pt_prefetch_payload
628628
.template update_tdesc<tdesc_update_dir::y_dir>(
629629
zero_pt_t::tile_size_y);

include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ class gemm_universal_t<
159159
/// @brief GEMM arguments.
160160
/// This is the interface for users to pass the application-related runtime
161161
/// variables.
162-
template <quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP>
162+
template <quant_mode quant_mode = quant_mode::I4_FULLRANGE_NO_ZP>
163163
struct arguments_t {
164164
/// @brief Is the size of the m dimension of the matrix multiplication (m x
165165
/// k x n).
@@ -295,7 +295,7 @@ class gemm_universal_t<
295295
}
296296
};
297297
template <>
298-
struct arguments_t<quant_mode::S4_FULLRANGE_NO_ZP> {
298+
struct arguments_t<quant_mode::I4_FULLRANGE_NO_ZP> {
299299
/// @brief Is the size of the m dimension of the matrix multiplication (m x
300300
/// k x n).
301301
uint32_t matrix_m;
@@ -570,7 +570,7 @@ class gemm_universal_t<
570570
// check for int4x2
571571
implementable &=
572572
((args.matB_ld % pack_ratio == 0) && (args.matrix_n % pack_ratio == 0));
573-
if constexpr (gemm_t::compute_policy::quant_mode == quant_mode::S4_ASYM) {
573+
if constexpr (gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM) {
574574
implementable &= (args.zero_pt_ld % pack_ratio == 0);
575575
}
576576

@@ -622,7 +622,7 @@ class gemm_universal_t<
622622
int start_y_scale = start_k / dequant_s;
623623

624624
int start_x_zero_pt =
625-
gemm_t::compute_policy::quant_mode == quant_mode::INT4_ASYM_FP_ZERO
625+
gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM_FP_ZERO
626626
? start_n
627627
: start_n / pack_ratio;
628628
int start_y_zero_pt = start_k / dequant_s;
@@ -671,15 +671,15 @@ class gemm_universal_t<
671671
uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride;
672672
gemm_args_t gemm_args;
673673
if constexpr (
674-
gemm_t::compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
674+
gemm_t::compute_policy::quant_mode == quant_mode::I4_FULLRANGE_NO_ZP) {
675675
gemm_args = gemm_args_t(
676676
mem_desc_a,
677677
mem_desc_b,
678678
inner_loop_start,
679679
inner_loop_count,
680680
mem_desc_scale);
681681
} else if constexpr (
682-
gemm_t::compute_policy::quant_mode == quant_mode::S4_ASYM) {
682+
gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM) {
683683
mem_desc_zero_pt_t mem_desc_zero_pt(
684684
args.zero_pt_base,
685685
{(args.matrix_n + pack_ratio - 1) / pack_ratio,
@@ -694,7 +694,7 @@ class gemm_universal_t<
694694
mem_desc_scale,
695695
mem_desc_zero_pt);
696696
} else if constexpr (
697-
gemm_t::compute_policy::quant_mode == quant_mode::INT4_ASYM_FP_ZERO) {
697+
gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
698698
mem_desc_zero_pt_t mem_desc_zero_pt(
699699
args.zero_pt_base,
700700
{args.matrix_n,

include/subgroup/tile/impl/tile_op_functor.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ struct dequant_int4_weight_t {
130130
(offset_y_in_tile) / dequant_s * scale_t::block_size_x +
131131
offset_x_in_tile;
132132

133-
if constexpr (quant_mode == quant_mode::S4_ASYM) {
133+
if constexpr (quant_mode == quant_mode::I4_ASYM) {
134134
uint32_t zero_pt_idx =
135135
offset_y_in_tile / dequant_s * zero_pt_t::block_size_x +
136136
offset_x_in_tile / pack_ratio;
@@ -150,16 +150,16 @@ struct dequant_int4_weight_t {
150150
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
151151
zero_pt_i8;
152152
} else if constexpr (
153-
quant_mode == quant_mode::S4_FULLRANGE_NO_ZP ||
154-
quant_mode == quant_mode::INT4_ASYM_FP_ZERO) {
153+
quant_mode == quant_mode::I4_FULLRANGE_NO_ZP ||
154+
quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
155155
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
156156
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
157157
int8_t(8);
158158
}
159159
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
160160
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
161161
scale.reg[scale_idx];
162-
if constexpr (quant_mode == quant_mode::INT4_ASYM_FP_ZERO) {
162+
if constexpr (quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
163163
uint32_t zero_pt_idx =
164164
offset_y_in_tile / dequant_s * zero_pt_t::block_size_x +
165165
offset_x_in_tile;

tests/integration/gemm/int4_dequantization/main.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,9 @@ void dequantize_gemm_run(uint32_t iter) {
229229
compute_attr_t<data_type_acc_in, data_type_acc_in, data_type_acc>;
230230
using perf_tuning_knob = xetla::group::
231231
perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval>;
232-
233-
static constexpr quant_info quant_info{quant_mode::S4_ASYM, Test::dequant_s, layout_b};
232+
233+
static constexpr quant_info quant_info{
234+
quant_mode::I4_ASYM, Test::dequant_s, layout_b};
234235

235236
using compute_policy = xetla::group::compute_policy_int4_dequantize<
236237
compute_attr,

tests/integration/gemm/int4_dequantization_bias/main_client.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,7 @@ void dequantize_gemm_run(int iter) {
622622
perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval>;
623623

624624
static constexpr quant_info quant_info{
625-
quant_mode::S4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b};
625+
quant_mode::I4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b};
626626

627627
using compute_policy = xetla::group::compute_policy_int4_dequantize<
628628
compute_attr,
@@ -1043,4 +1043,4 @@ REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_act_shuf_test, esimd);
10431043
INSTANTIATE_TYPED_TEST_SUITE_P(
10441044
dequantize_gemm_act_shuf_test_suite,
10451045
dequantize_gemm_act_shuf_test,
1046-
tests);
1046+
tests);

tests/integration/gemm/int4_dequantization_bias/main_xe.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ void dequantize_gemm_run(int iter) {
388388
using perf_tuning_knob = xetla::group::
389389
perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval>;
390390
static constexpr quant_info quant_info{
391-
quant_mode::S4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b};
391+
quant_mode::I4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b};
392392

393393
using compute_policy = xetla::group::compute_policy_int4_dequantize<
394394
compute_attr,

tests/integration/gemv/int4/main.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ class test_col_major_1 {
3939
static constexpr size_t sg_n = 1;
4040
static constexpr size_t sg_k = 512 / sg_m;
4141
static constexpr size_t dequant_s = 128;
42-
// static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
43-
// static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
44-
static constexpr quant_mode quant_mode = quant_mode::INT4_ASYM_FP_ZERO;
42+
// static constexpr quant_mode quant_mode = quant_mode::I4_ASYM;
43+
// static constexpr quant_mode quant_mode = quant_mode::I4_FULLRANGE_NO_ZP;
44+
static constexpr quant_mode quant_mode = quant_mode::I4_ASYM_FP_ZERO;
4545

4646
static constexpr size_t local_kslicing = 1;
4747
static constexpr size_t global_kslicing = 1;
@@ -121,7 +121,7 @@ int gemm_result_validate(
121121
}
122122

123123
template <
124-
quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP,
124+
quant_mode quant_mode = quant_mode::I4_FULLRANGE_NO_ZP,
125125
typename data_type_acc_in = fp16,
126126
typename data_type_b,
127127
typename data_type_scale,
@@ -133,15 +133,15 @@ std::vector<fp16> convert_int4(
133133
std::vector<fp16> dequant_fp16(sizeof(data_type_b) * 2);
134134

135135
int8_t zero_pt_i8;
136-
if constexpr (quant_mode != quant_mode::INT4_ASYM_FP_ZERO)
136+
if constexpr (quant_mode != quant_mode::I4_ASYM_FP_ZERO)
137137
zero_pt_i8 = zero_pt & 0xf;
138138
for (uint32_t i = 0; i < dequant_fp16.size(); i++) {
139139
int8_t dequant_8bit = data_b & 0xf;
140-
if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
140+
if constexpr (quant_mode == quant_mode::I4_FULLRANGE_NO_ZP) {
141141
dequant_fp16[i] = scale * (dequant_8bit - 8);
142-
} else if constexpr (quant_mode == quant_mode::S4_ASYM) {
142+
} else if constexpr (quant_mode == quant_mode::I4_ASYM) {
143143
dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8);
144-
} else if constexpr (quant_mode == quant_mode::INT4_ASYM_FP_ZERO) {
144+
} else if constexpr (quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
145145
dequant_fp16[i] = scale * (dequant_8bit - 8) + zero_pt;
146146
} else {
147147
assert(0);
@@ -154,7 +154,7 @@ std::vector<fp16> convert_int4(
154154
template <
155155
size_t dequant_s,
156156
mem_layout layout_b = mem_layout::col_major,
157-
quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP,
157+
quant_mode quant_mode = quant_mode::I4_FULLRANGE_NO_ZP,
158158
typename data_type_acc_in = fp16,
159159
typename data_type_b,
160160
typename data_type_scale,
@@ -176,13 +176,13 @@ std::vector<data_type_acc_in> dequantize_weight(
176176
for (uint32_t j = 0; j < width; j += step) {
177177
int start_b_in = i * width + j;
178178
int start_scale_in = start_b_in / step;
179-
int start_zero_pt_in = quant_mode == quant_mode::INT4_ASYM_FP_ZERO
179+
int start_zero_pt_in = quant_mode == quant_mode::I4_ASYM_FP_ZERO
180180
? (j / step) * matrix_n + i
181181
: (j / step) * (matrix_n / pack_radio) + i / pack_radio;
182182
int start_out =
183183
layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio;
184184
data_type_zero_pt zp_value = zero_pt[start_zero_pt_in];
185-
if constexpr (quant_mode != quant_mode::INT4_ASYM_FP_ZERO)
185+
if constexpr (quant_mode != quant_mode::I4_ASYM_FP_ZERO)
186186
zp_value = zp_value >> (4 * (i % pack_radio));
187187
for (uint32_t jj = 0; jj < step; jj++) {
188188
std::vector<fp16> dequant_fp16 = convert_int4<quant_mode>(
@@ -225,7 +225,7 @@ void dequantize_gemv_run(int iter) {
225225
using data_type_b = typename Test::data_type_b;
226226
using data_type_c = typename Test::data_type_c;
227227
using data_type_zero_pt = std::conditional_t<
228-
Test::quant_mode == quant_mode::INT4_ASYM_FP_ZERO,
228+
Test::quant_mode == quant_mode::I4_ASYM_FP_ZERO,
229229
data_type_c,
230230
data_type_b>;
231231
using data_type_scale = fp16;
@@ -246,7 +246,7 @@ void dequantize_gemv_run(int iter) {
246246
constexpr size_t size_zero_pt_k = matrix_k / dequant_s;
247247
constexpr size_t size_zero_pt_n = matrix_n;
248248
constexpr size_t size_zero_pt =
249-
Test::quant_mode != quant_mode::INT4_ASYM_FP_ZERO
249+
Test::quant_mode != quant_mode::I4_ASYM_FP_ZERO
250250
? size_zero_pt_k * size_zero_pt_n / 2
251251
: size_zero_pt_k * size_zero_pt_n;
252252

@@ -490,7 +490,7 @@ void dequantize_gemv_run(int iter) {
490490
// It accepts the base pointer to matrix D, and its dimensions
491491
{bias_d, bias_add_shape}});
492492
typename gemm_op_t::template arguments_t<compute_policy::quant_mode> gemm_arg;
493-
if constexpr (compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
493+
if constexpr (compute_policy::quant_mode == quant_mode::I4_FULLRANGE_NO_ZP) {
494494
gemm_arg =
495495
typename gemm_op_t::template arguments_t<compute_policy::quant_mode>(
496496
matrix_m,
@@ -508,8 +508,8 @@ void dequantize_gemv_run(int iter) {
508508
Cnt_d,
509509
epilogue_args);
510510
} else if constexpr (
511-
compute_policy::quant_mode == quant_mode::S4_ASYM ||
512-
compute_policy::quant_mode == quant_mode::INT4_ASYM_FP_ZERO) {
511+
compute_policy::quant_mode == quant_mode::I4_ASYM ||
512+
compute_policy::quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
513513
gemm_arg =
514514
typename gemm_op_t::template arguments_t<compute_policy::quant_mode>(
515515
matrix_m,

0 commit comments

Comments
 (0)