From 77f92815dc9b750dd728ca8a9bcef40841470814 Mon Sep 17 00:00:00 2001
From: Alex Barron
Date: Tue, 22 Oct 2024 20:13:32 -0700
Subject: [PATCH] clean

---
 benchmarks/python/sdpa_vector_bench.py       | 61 +++++++++----------
 .../metal/scaled_dot_product_attention.cpp   | 19 +----
 mlx/fast.cpp                                 |  2 -
 python/src/fast.cpp                          | 20 +++---
 4 files changed, 37 insertions(+), 65 deletions(-)

diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py
index 8402833fd..936514644 100644
--- a/benchmarks/python/sdpa_vector_bench.py
+++ b/benchmarks/python/sdpa_vector_bench.py
@@ -1,5 +1,6 @@
 import mlx.core as mx
 import numpy as np
+from mlx.utils import tree_map
 from time_utils import time_fn
 
 L = 16384
@@ -11,10 +12,6 @@
 
 
 def attention(q, k, v):
-    k = mx.quantize(k)
-    v = mx.quantize(v)
-    k = mx.dequantize(*k)
-    v = mx.dequantize(*v)
     B, Hq, L, D = q.shape
     _, Hk, S, _ = k.shape
     q = q.reshape(B, Hk, Hq // Hk, L, D)
@@ -27,21 +24,31 @@ def attention(q, k, v):
 
 
 def sdpa(q, k, v):
-    k = mx.quantize(k, bits=8)
-    v = mx.quantize(v, bits=8)
-    k = mx.dequantize(*k, bits=8)
-    v = mx.dequantize(*v, bits=8)
     return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
 
 
 def quant_sdpa(q, k, v):
-    k = mx.quantize(k, bits=8)
-    v = mx.quantize(v, bits=8)
     return mx.fast.quantized_scaled_dot_product_attention(
         q, *k, *v, scale=1.0, mask=None, bits=8
     )
 
 
+def quant_attention(q, k, v):
+    B, Hq, L, D = q.shape
+    Hk = k[0].shape[1]
+
+    q = q.reshape((B, Hk, Hq // Hk, L, D))
+    k = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
+    v = tree_map(lambda x: mx.expand_dims(x, axis=2), v)
+
+    scores = mx.quantized_matmul(q, *k, transpose=True)
+    scores = mx.softmax(scores, axis=-1)
+
+    out = mx.quantized_matmul(scores, *v, transpose=False)
+    out = out.reshape((B, Hq, L, D))
+    return out
+
+
 def time_self_attention_primitives(q, k, v):
     time_fn(attention, q, k, v)
 
@@ -54,34 +61,22 @@ def time_self_attention_quant_sdpa(q, k, v):
     time_fn(quant_sdpa, q, k, v)
 
 
+def time_self_attention_quant_primitives(q, k, v):
+    time_fn(quant_attention, q, k, v)
+
+
 if __name__ == "__main__":
     mx.random.seed(3)
-    # q = mx.random.uniform(shape=(1, H, 1, D))
-    # k = mx.random.uniform(shape=(1, H_k, L, D))
-    # v = mx.random.uniform(shape=(1, H_k, L, D))
-    q = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/queries.npy"))
-    k = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/keys.npy"))
-    v = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/values.npy"))
-    print(q.dtype)
-    print(q.shape, k.shape, v.shape)
+    q = mx.random.uniform(shape=(1, H, 1, D))
+    k = mx.random.uniform(shape=(1, H_k, L, D))
+    v = mx.random.uniform(shape=(1, H_k, L, D))
     mx.eval(q, k, v)
 
     k_quant = mx.quantize(k)
     v_quant = mx.quantize(v)
     mx.eval(k_quant, v_quant)
 
-    # time_self_attention_sdpa(q, k, v)
-    # time_self_attention_quant_sdpa(q, k_quant, v_quant)
-    # time_self_attention_primitives(q, k, v)
-    q_sdpa = quant_sdpa(q, k, v)
-    print(q_sdpa)
-    # o_attention = attention(q, k, v)
-    # print(o_attention)
-    # np.testing.assert_allclose(q_sdpa, o_attention, atol=1e-5)
-    o_sdpa = sdpa(q, k, v)
-    print(o_sdpa)
-    np.testing.assert_allclose(q_sdpa, o_sdpa, atol=1e-5)
-    # print(o_sdpa[..., :64])
-    # print()
-    # print(o_attention[..., :64])
-    # np.testing.assert_allclose(o_sdpa, o_attention)
+    time_self_attention_sdpa(q, k, v)
+    time_self_attention_quant_sdpa(q, k_quant, v_quant)
+    time_self_attention_primitives(q, k, v)
+    time_self_attention_quant_primitives(q, k_quant, v_quant)
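Editorial note, not part of the patch: the cleaned-up benchmark times four paths (unquantized primitives, fused SDPA, fused quantized SDPA, and the new quantized-primitives reference) but drops the correctness assertions from the pre-cleanup script. If a quick sanity check is still wanted, a minimal sketch along these lines could be appended to the __main__ block. It only reuses helpers and variables defined in the script above; the 1e-5 tolerance is carried over from the removed assertions and may need loosening for lower-precision dtypes or longer sequences, and it assumes the bits passed inside quant_sdpa match those used by mx.quantize when building k_quant and v_quant.

    # Optional sanity check (sketch): the fused quantized kernel should agree
    # with the quantized-primitives reference when given the same quantized K/V.
    o_quant_sdpa = quant_sdpa(q, k_quant, v_quant)
    o_quant_ref = quant_attention(q, k_quant, v_quant)
    np.testing.assert_allclose(o_quant_sdpa, o_quant_ref, atol=1e-5)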
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 80b67ecb7..c66b33a00 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -12,8 +12,6 @@
 #include "mlx/fast_primitives.h"
 #include "mlx/utils.h"
 
-#include <iostream>
-
 namespace mlx::core::fast {
 
 namespace {
@@ -364,21 +362,6 @@ assert(inputs.size() == 3);
     o.set_data(allocator::malloc_or_wait(o.nbytes()));
   }
 
-  quant_sdpa_vector(
-      s,
-      d,
-      q,
-      k,
-      k_scales,
-      k_biases,
-      v,
-      v_scales,
-      v_biases,
-      o,
-      scale_,
-      group_size_,
-      bits_);
-
   if (quantized_) {
     auto& k_scales_pre = inputs[2];
     auto& k_biases_pre = inputs[3];
@@ -397,7 +380,7 @@ assert(inputs.size() == 3);
     auto v_biases = copy_unless(is_contiguous_except_seq_len, v_biases_pre);
 
     quant_sdpa_vector(
-        s, d, q, k, k_scales, k_biases, v, v_scales, v_biases, o, scale_);
+        s, d, q, k, k_scales, k_biases, v, v_scales, v_biases, o, scale_, group_size_, bits_);
 
   } else {
     auto& k_pre = inputs[1];
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index bbeeefebc..37a6ec47b 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -10,8 +10,6 @@
 #include "mlx/ops.h"
 #include "mlx/transforms.h"
 
-#include <iostream>
-
 namespace mlx::core::fast {
 
 std::vector<array> Custom::vjp(
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index d2617d28b..188a93b22 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -180,26 +180,22 @@ void init_fast(nb::module_& parent_module) {
       nb::sig(
           "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: array, v: array, v_scales: array, v_biases: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
-        A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
-
-        Supports:
-
-        * `Multi-Head Attention <https://arxiv.org/abs/1706.03762>`_
-        * `Grouped Query Attention <https://arxiv.org/abs/2305.13245>`_
-        * `Multi-Query Attention <https://arxiv.org/abs/1911.02150>`_
-
-        Note: The softmax operation is performed in ``float32`` regardless of
-        the input precision.
+        A fast implementation of multi-head attention where the keys and values are quantized.
 
-        Note: For Grouped Query Attention and Multi-Query Attention, the ``k``
-        and ``v`` inputs should not be pre-tiled to match ``q``.
+        See :func:`scaled_dot_product_attention` for more details.
 
         Args:
             q (array): Input query array.
             k (array): Input keys array.
+            k_scales (array): Scales for the quantized keys array.
+            k_biases (array): Biases for the quantized keys array.
             v (array): Input values array.
+            v_scales (array): Scales for the quantized values array.
+            v_biases (array): Biases for the quantized values array.
             scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
             mask (array, optional): An additive mask to apply to the query-key scores.
+            group_size (int): The group size used in the KV quantization.
+            bits (int): The bits used in the KV quantization.
         Returns:
             array: The output array.
       )pbdoc");
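Editorial note, not part of the patch: the sketch below illustrates how the binding documented above might be called end to end, modeled on the call pattern used in the benchmark. It assumes a build that includes this patch; the shapes, group_size=64, and bits=8 are illustrative choices, mx.quantize returns a (weights, scales, biases) tuple whose elements line up with the k/k_scales/k_biases and v/v_scales/v_biases parameters, and the group_size and bits keywords follow the arguments documented in the docstring above.

import math

import mlx.core as mx

# Illustrative shapes: 32 query heads, 8 KV heads, a single query token
# attending over 4096 cached positions with head dimension 128.
B, n_q_heads, n_kv_heads, S, D = 1, 32, 8, 4096, 128
q = mx.random.uniform(shape=(B, n_q_heads, 1, D))
k = mx.random.uniform(shape=(B, n_kv_heads, S, D))
v = mx.random.uniform(shape=(B, n_kv_heads, S, D))

# mx.quantize returns (weights, scales, biases); the same group_size and
# bits are passed to the attention call below.
k_quant = mx.quantize(k, group_size=64, bits=8)
v_quant = mx.quantize(v, group_size=64, bits=8)

out = mx.fast.quantized_scaled_dot_product_attention(
    q,
    *k_quant,
    *v_quant,
    scale=1.0 / math.sqrt(D),
    mask=None,
    group_size=64,
    bits=8,
)
mx.eval(out)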