Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

qbits support f4 weight repack #1653

Merged
merged 2 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#pragma once
#include <ATen/core/TensorBody.h>
#include <torch/torch.h>
#include "bestla/bestla_storage.h"
#include "../include/dispatcher_utils.hpp"
#include <string.h>
#include <assert.h>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <chrono>
#include <string>
#include "bestla/bestla_device.h"
#include "bestla/bestla_storage.h"
#include "bestla/bestla_utils.h"
#include "bestla/bestla_parallel.h"
namespace dispatcher_utils {
Expand All @@ -26,6 +27,12 @@ inline bool check_avx_vnni() { return bestla::device::CpuDevice::getInstance()->
inline bool check_avx512f() { return bestla::device::CpuDevice::getInstance()->AVX512F(); }
inline bool check_avx2() { return bestla::device::CpuDevice::getInstance()->AVX2(); }

// Compile-time predicate: does this GemmCore run its compute path in int8?
// True for cores whose ISA is one of the int8 ISAs (AMX-INT8, AVX512-VNNI,
// AVX-VNNI), plus the AVX2 vnni-emulation core, which is int8-compute despite
// reporting a non-VNNI ISA.
template <class GemmCore>
constexpr bool is_int8_cmpt_gemmcore() {
  constexpr bool has_int8_isa = GemmCore::ISA == BTLA_ISA::AMX_INT8 ||
                                GemmCore::ISA == BTLA_ISA::AVX512_VNNI ||
                                GemmCore::ISA == BTLA_ISA::AVX_VNNI;
  constexpr bool is_avx2_vnni_core =
      std::is_same_v<GemmCore, bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>>;
  return has_int8_isa || is_avx2_vnni_core;
}

class qbits_threading {
public:
static bestla::parallel::IThreading* get() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,19 @@
#include "../include/bestla_packq_impl.hpp"

namespace woq {
template <class GemmCore, BTLA_ISA ISA>

template <class proB>
void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
using proB = bestla::prologue_b::gemm::WeightKBlockNInteger<GemmCore, ISA>;
static proB ker;
auto qpackw = ker.createStorage(ctx->n, ctx->k, p->blocksize, wei2bestladt_map.at(p->weight_type),
scale2bestladt_map.at(p->scale_type), BTLA_DTYPE::BF16, p->asym);
using WType = typename proB::StorageWeight;
WType qpackw(0);
if constexpr (std::is_same_v<WType, bestla::storage::gemm::StorageWeightKBlockNInteger>) {
qpackw = ker.createStorage(ctx->n, ctx->k, p->blocksize, wei2bestladt_map.at(p->weight_type),
scale2bestladt_map.at(p->scale_type), BTLA_DTYPE::BF16, p->asym);
} else {
qpackw = ker.createStorage(ctx->n, ctx->k, p->blocksize, wei2bestladt_map.at(p->weight_type),
scale2bestladt_map.at(p->scale_type));
}
if (p->enable_act_shuffle) ker.enableShuffle(&qpackw);
ctx->packw_size = qpackw.mSize;
if (task == WOQ_GET_PACKW_SIZE) return;
Expand All @@ -33,6 +40,20 @@ void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx
p->asym ? ctx->zp->data_ptr<int8_t>() : nullptr, &qpackw, dispatcher_utils::qbits_threading::get());
}

// Dispatch weight repack to the matching prologue-B implementation based on
// the requested weight dtype:
//   - integer weights (int8 / int4_clip / int3_clip / int2_clip)
//       -> WeightKBlockNInteger
//   - float weights (nf4 / fp4_e2m1_bnb / fp4_e2m1), symmetric only
//       -> WeightKBlockNFloat
// Any other weight_type is rejected via TORCH_CHECK.
template <class GemmCore, BTLA_ISA ISA>
void parse_prob(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
  const auto& wtype = p->weight_type;
  const bool integer_weight =
      wtype == "int8" || wtype == "int4_clip" || wtype == "int3_clip" || wtype == "int2_clip";
  if (integer_weight) {
    return execute_qpack<bestla::prologue_b::gemm::WeightKBlockNInteger<GemmCore, ISA>>(p, ctx, task);
  }
  const bool float_weight = wtype == "nf4" || wtype == "fp4_e2m1_bnb" || wtype == "fp4_e2m1";
  if (float_weight) {
    // Float-quantized weights carry no zero-point; asymmetric mode is invalid.
    TORCH_CHECK(!p->asym, "Qbits: float-weight unsupports asym quantization.");
    return execute_qpack<bestla::prologue_b::gemm::WeightKBlockNFloat<GemmCore, ISA>>(p, ctx, task);
  }
  TORCH_CHECK(false, "Qbits: unsupported bestla packq config, compute_type: " + p->compute_type +
                         " weight_type: " + p->weight_type);
}

std::string get_dtype_str(BTLA_DTYPE dtype) {
switch (dtype) {
case BTLA_DTYPE::F32:
Expand Down Expand Up @@ -183,40 +204,38 @@ torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T) {
}

void bestla_packq(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
// TODO(zhe): elegant impl.
TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" ||
p->weight_type == "int2_clip",
"Qbits: only support Integer WOQ in PACKQ");

if (p->compute_type == "int8") {
TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" ||
p->weight_type == "int2_clip",
"Qbits: only support Integer weight-type with int8 compute-type");
if (dispatcher_utils::check_amx() && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>, BTLA_ISA::AMX_INT8>(p, ctx, task);
return parse_prob<bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>, BTLA_ISA::AMX_INT8>(p, ctx, task);
}
if (dispatcher_utils::check_avx512_vnni() &&
p->blocksize % bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>, BTLA_ISA::AVX512_VNNI>(p, ctx, task);
return parse_prob<bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>, BTLA_ISA::AVX512_VNNI>(p, ctx, task);
}
if (dispatcher_utils::check_avx_vnni() && p->blocksize % bestla::gemm::ICoreRowNAvxvnniKBlock<24, 2>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAvxvnniKBlock<24, 2>, BTLA_ISA::AVX_VNNI>(p, ctx, task);
return parse_prob<bestla::gemm::ICoreRowNAvxvnniKBlock<24, 2>, BTLA_ISA::AVX_VNNI>(p, ctx, task);
}
if (dispatcher_utils::check_avx2() && p->blocksize % bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>, BTLA_ISA::AVX2>(p, ctx, task);
return parse_prob<bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>, BTLA_ISA::AVX2>(p, ctx, task);
}
TORCH_CHECK(false, "Qbits: Illegal config in int8 compute_type, blocksize:", p->blocksize,
", ISA support avx2:", dispatcher_utils::check_avx2());
}
if (p->compute_type == "fp32") {
if (dispatcher_utils::check_avx512f()) {
return execute_qpack<bestla::gemm::SCoreRowNAvx512f<48, 8>, BTLA_ISA::AVX512F>(p, ctx, task);
return parse_prob<bestla::gemm::SCoreRowNAvx512f<48, 8>, BTLA_ISA::AVX512F>(p, ctx, task);
}
if (dispatcher_utils::check_avx2()) {
return execute_qpack<bestla::gemm::SCoreRowNAvx2<24, 4>, BTLA_ISA::AVX2>(p, ctx, task);
return parse_prob<bestla::gemm::SCoreRowNAvx2<24, 4>, BTLA_ISA::AVX2>(p, ctx, task);
}
TORCH_CHECK(false, "Qbits: device ISA must support BTLA_ISA::AVX2 when compute_type==fp32");
}
if (p->compute_type == "bf16") {
if (dispatcher_utils::check_amx()) {
return execute_qpack<bestla::gemm::HCoreRowNAmxbf16<64, 16>, BTLA_ISA::AMX_BF16>(p, ctx, task);
return parse_prob<bestla::gemm::HCoreRowNAmxbf16<64, 16>, BTLA_ISA::AMX_BF16>(p, ctx, task);
}
TORCH_CHECK(false, "Qbits: device ISA must support AMX-BF16 when compute_type==bf16");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,6 @@ concept quant_PrologueA = requires {
requires !std::is_same_v<T, bestla::utils::bf16>;
};

template <class GemmCore>
constexpr bool is_int8_cmpt_gemmcore() {
return GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI ||
GemmCore::ISA == BTLA_ISA::AVX_VNNI || std::is_same_v<GemmCore, bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>>;
}

template <class Launcher>
void dequantize_packed_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
if (dispatcher_utils::initer.verbose) dispatcher_utils::timer.start();
Expand Down Expand Up @@ -133,7 +127,7 @@ void do_compute(woq_config_param* p, woq_runtime_ctx* ctx, ParamA param_a) {
using StorageWeight = typename Launcher::PrologueB::StorageWeight;
size_t asym_size = 0, shuf_size = 0;
int8_t* tmpbuf = nullptr;
if constexpr (is_int8_cmpt_gemmcore<GemmCore>()) {
if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore<GemmCore>()) {
using Parallel = bestla::parallel::gemm::SchedulerKBlockS<GemmCore>;
bestla::utils::GemmProblem gp(1, ctx->m, ctx->n, ctx->k, p->blocksize);
StorageWeight* packedw = dynamic_cast<StorageWeight*>(ctx->deseries_wei);
Expand Down Expand Up @@ -236,7 +230,7 @@ void execute_task(woq_config_param* p, woq_runtime_ctx* ctx) {
template <WOQ_TASK TASK, class GemmCore, template <class _T, BTLA_ISA> class PrologueB,
template <class _T, BTLA_ISA> class PrologueA, template <BTLA_ISA> class Epilogue>
void parse_launcher(woq_config_param* p, woq_runtime_ctx* ctx) {
if constexpr (is_int8_cmpt_gemmcore<GemmCore>()) {
if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore<GemmCore>()) {
using Launcher = bestla::wrapper::gemm::LauncherIntKBlock<GemmCore::ISA, GemmCore, PrologueA, PrologueB, Epilogue>;
return execute_task<TASK, Launcher>(p, ctx);
} else {
Expand All @@ -260,7 +254,7 @@ template <WOQ_TASK TASK, class GemmCore, template <class _T, BTLA_ISA> class Pro
void parse_activation(woq_config_param* p, woq_runtime_ctx* ctx) {
using namespace bestla::prologue_a::gemm;
if (p->src_dt == dispatcher_utils::QBITS_FP32) {
if constexpr (is_int8_cmpt_gemmcore<GemmCore>()) {
if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore<GemmCore>()) {
return parse_store<TASK, GemmCore, PrologueB, ShuffleActivationKBlockQuantizeF32, dispatcher_utils::QBITS_FP32>(
p, ctx);
} else {
Expand All @@ -269,7 +263,7 @@ void parse_activation(woq_config_param* p, woq_runtime_ctx* ctx) {
}
}
if (p->src_dt == dispatcher_utils::QBITS_BF16) {
if constexpr (is_int8_cmpt_gemmcore<GemmCore>()) {
if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore<GemmCore>()) {
return parse_store<TASK, GemmCore, PrologueB, ShuffleActivationKBlockQuantizeBf16, dispatcher_utils::QBITS_BF16>(
p, ctx);
} else {
Expand All @@ -289,7 +283,7 @@ void parse_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
if (p->weight_type == "nf4" || p->weight_type == "fp4_e2m1_bnb" || p->weight_type == "fp4_e2m1" ||
p->weight_type == "fp8_e4m3" || p->weight_type == "fp8_e5m2") {
TORCH_CHECK(!p->asym, "Qbits: float-weight unsupports asym quantization.");
if constexpr (!is_int8_cmpt_gemmcore<GemmCore>())
if constexpr (!dispatcher_utils::is_int8_cmpt_gemmcore<GemmCore>())
return parse_activation<TASK, GemmCore, WeightKBlockNFloat>(p, ctx);
}
TORCH_CHECK(false,
Expand Down
Loading