Implement I-quants (IQ4XS, IQ4NL) #2785

Open · wants to merge 8 commits into main
Changes from all commits
8 changes: 4 additions & 4 deletions candle-core/benches/benchmarks/reduce.rs
@@ -44,12 +44,12 @@ fn run_reduce<T: candle_core::FloatDType>(
let k = 1024;

let a = if strided {
-        Tensor::rand(lo, up, (b, m, k), &device)
+        Tensor::rand(lo, up, (b, m, k), device)
.unwrap()
.transpose(0, 2)
.unwrap()
} else {
-        Tensor::rand(lo, up, (b, m, k), &device).unwrap()
+        Tensor::rand(lo, up, (b, m, k), device).unwrap()
};

let flops = b * m * k * T::DTYPE.size_in_bytes();
@@ -105,12 +105,12 @@ fn run_arg_reduce<T: candle_core::FloatDType>(
let k = 1024;

let a = if strided {
-        Tensor::rand(lo, up, (b, m, k), &device)
+        Tensor::rand(lo, up, (b, m, k), device)
.unwrap()
.transpose(0, 2)
.unwrap()
} else {
-        Tensor::rand(lo, up, (b, m, k), &device).unwrap()
+        Tensor::rand(lo, up, (b, m, k), device).unwrap()
};

let flops = b * m * k * T::DTYPE.size_in_bytes();
2 changes: 1 addition & 1 deletion candle-core/benches/benchmarks/where_cond.rs
@@ -22,7 +22,7 @@ const M: usize = 1024;
const K: usize = 1024;
const SIZE: usize = B * M * K;

-const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
+static DATA: [u8; SIZE] = create_cond_arr::<SIZE>();

fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap();
115 changes: 115 additions & 0 deletions candle-core/src/quantized/avx.rs
@@ -1,3 +1,4 @@
use super::iq_quants::{BlockIQ4nl, BlockIQ4xs, KVALUES_IQ4NL, QK4_NL};
use super::k_quants::{
BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
};
@@ -47,6 +48,12 @@ pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
mul_sum_us8_pairs_float(ax, sy)
}

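// Multiplies the signed i8 lanes of x and y and sums adjacent pairs into i16 lanes.
// _mm256_maddubs_epi16 requires an unsigned first operand, so |x| is used and the
// sign of x is moved onto y (the same sign trick as in mul_sum_i8_pairs_float above).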
pub(crate) unsafe fn mul_add_epi8(x: __m256i, y: __m256i) -> __m256i {
let ax = _mm256_sign_epi8(x, x);
let sy = _mm256_sign_epi8(y, x);
_mm256_maddubs_epi16(ax, sy)
}

#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
@@ -665,3 +672,111 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res
Ok(hsum_float_8(acc))
}
}

#[inline(always)]
pub(crate) fn vec_dot_iq4_nl_q8_0(n: usize, xs: &[BlockIQ4nl], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK4_NL;
if n % qk != 0 {
crate::bail!("vec_dot_iq4_nl_q8_0: {n} is not divisible by {qk}")
}

unsafe {
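        // KVALUES_IQ4NL is the 16-entry non-linear codebook; it is loaded once so that
        // _mm_shuffle_epi8 can act as a lookup table over the 4-bit quant indices.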
let values128 = _mm_loadu_si128(KVALUES_IQ4NL.as_ptr() as *const __m128i);
let m4b = _mm_set1_epi8(0x0f);
let mone = _mm256_set1_epi16(1);

let mut accum1 = _mm256_setzero_ps();
let mut accum2 = _mm256_setzero_ps();
let nb = n / QK4_NL;
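        // Blocks are processed two at a time; this assumes `nb` is even (a trailing
        // odd block would be skipped).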
for ib in (0..nb - 1).step_by(2) {
let q4bits_1 = _mm_loadu_si128(xs[ib].qs.as_ptr() as *const __m128i);
let q4bits_2 = _mm_loadu_si128(xs[ib + 1].qs.as_ptr() as *const __m128i);
let q8b_1 = _mm256_loadu_si256(ys[ib].qs.as_ptr() as *const __m256i);
let q8b_2 = _mm256_loadu_si256(ys[ib + 1].qs.as_ptr() as *const __m256i);
let q4b_1 = mm256_set_m128i(
_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)),
);
let q4b_2 = mm256_set_m128i(
_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)),
);
let p16_1 = mul_add_epi8(q4b_1, q8b_1);
let p16_2 = mul_add_epi8(q4b_2, q8b_2);
let p_1 = _mm256_madd_epi16(p16_1, mone);
let p_2 = _mm256_madd_epi16(p16_2, mone);
accum1 = _mm256_fmadd_ps(
_mm256_set1_ps(ys[ib].d.to_f32() * xs[ib].d.to_f32()),
_mm256_cvtepi32_ps(p_1),
accum1,
);
accum2 = _mm256_fmadd_ps(
_mm256_set1_ps(ys[ib + 1].d.to_f32() * xs[ib + 1].d.to_f32()),
_mm256_cvtepi32_ps(p_2),
accum2,
);
}
Ok(hsum_float_8(_mm256_add_ps(accum1, accum2)))
}
}

#[inline(always)]
pub(crate) fn vec_dot_iq4_xs_q8k(n: usize, xs: &[BlockIQ4xs], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % qk != 0 {
crate::bail!("vec_dot_iq4_xs_q8k: {n} is not divisible by {qk}")
}

unsafe {
let values128 = _mm_loadu_si128(KVALUES_IQ4NL.as_ptr() as *const __m128i);
let m4b = _mm_set1_epi8(0x0f);

let mut accum = _mm256_setzero_ps();

let nb = n / QK_K;
for ibl in 0..nb {
let mut qs = xs[ibl].qs.as_ptr();
let mut q8 = ys[ibl].qs.as_ptr();
let mut sh = xs[ibl].scales_h;

let mut sumi1 = _mm256_setzero_si256();
let mut sumi2 = _mm256_setzero_si256();
for ib in (0..QK_K / 32).step_by(2) {
let q4bits_1 = _mm_loadu_si128(qs as *const __m128i);
qs = qs.add(16);
let q4bits_2 = _mm_loadu_si128(qs as *const __m128i);
qs = qs.add(16);
let q8b_1 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8b_2 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q4b_1 = mm256_set_m128i(
_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)),
);
let q4b_2 = mm256_set_m128i(
_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)),
);
let p16_1 = mul_add_epi8(q4b_1, q8b_1);
let p16_2 = mul_add_epi8(q4b_2, q8b_2);
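                // Decode the 6-bit block scales: the low 4 bits come from scales_l,
                // the top 2 bits from scales_h, with a bias of -32.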
let ls1 =
((xs[ibl].scales_l[ib / 2] as i16 & 0xf) | ((sh << 4) & 0x30) as i16) - 32;
let ls2 = ((xs[ibl].scales_l[ib / 2] as i16 >> 4) | ((sh << 2) & 0x30) as i16) - 32;
sh >>= 4;
let p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1));
let p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2));
sumi1 = _mm256_add_epi32(p_1, sumi1);
sumi2 = _mm256_add_epi32(p_2, sumi2);
}

accum = _mm256_fmadd_ps(
_mm256_set1_ps(xs[ibl].d.to_f32() * ys[ibl].d),
_mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)),
accum,
);
}

Ok(hsum_float_8(accum))
}
}
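
For reference, here is a minimal scalar sketch of the dot product the AVX2 IQ4_NL kernel above computes. It is an illustration only, not part of this PR; the field layouts (BlockIQ4nl { d: f16, qs: [u8; 16] }, BlockQ8_0 { d: f16, qs: [i8; 32] }) are assumed from how the SIMD code reads them.

fn vec_dot_iq4_nl_q8_0_scalar(xs: &[BlockIQ4nl], ys: &[BlockQ8_0]) -> f32 {
    let mut sumf = 0.0f32;
    for (x, y) in xs.iter().zip(ys.iter()) {
        let mut sumi = 0i32;
        for j in 0..16 {
            // Each qs byte packs two 4-bit indices into the KVALUES_IQ4NL codebook:
            // low nibbles pair with q8 bytes 0..16, high nibbles with bytes 16..32,
            // mirroring the mm256_set_m128i ordering in the SIMD path.
            let lo = KVALUES_IQ4NL[(x.qs[j] & 0x0f) as usize] as i32;
            let hi = KVALUES_IQ4NL[(x.qs[j] >> 4) as usize] as i32;
            sumi += lo * y.qs[j] as i32 + hi * y.qs[j + 16] as i32;
        }
        sumf += x.d.to_f32() * y.d.to_f32() * sumi as f32;
    }
    sumf
}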
3 changes: 3 additions & 0 deletions candle-core/src/quantized/cuda.rs
@@ -435,6 +435,9 @@ impl QCudaStorage {
GgmlDType::Q5K => deq::<crate::quantized::BlockQ5K>(&buffer, block_len, &mut out)?,
GgmlDType::Q6K => deq::<crate::quantized::BlockQ6K>(&buffer, block_len, &mut out)?,
GgmlDType::Q8K => deq::<crate::quantized::BlockQ8K>(&buffer, block_len, &mut out)?,
GgmlDType::Iq4Xs | GgmlDType::Iq4Nl => {
todo!("CUDA does not support Iq4Xs or Iq4Nl yet.")
}
}

self.device