Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Barron committed Dec 5, 2024
1 parent a94c06d commit 77f9281
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 65 deletions.
61 changes: 28 additions & 33 deletions benchmarks/python/sdpa_vector_bench.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import mlx.core as mx
import numpy as np
from mlx.utils import tree_map
from time_utils import time_fn

L = 16384
Expand All @@ -11,10 +12,6 @@


def attention(q, k, v):
k = mx.quantize(k)
v = mx.quantize(v)
k = mx.dequantize(*k)
v = mx.dequantize(*v)
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
q = q.reshape(B, Hk, Hq // Hk, L, D)
Expand All @@ -27,21 +24,31 @@ def attention(q, k, v):


def sdpa(q, k, v):
k = mx.quantize(k, bits=8)
v = mx.quantize(v, bits=8)
k = mx.dequantize(*k, bits=8)
v = mx.dequantize(*v, bits=8)
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)


def quant_sdpa(q, k, v):
k = mx.quantize(k, bits=8)
v = mx.quantize(v, bits=8)
return mx.fast.quantized_scaled_dot_product_attention(
q, *k, *v, scale=1.0, mask=None, bits=8
)


def quant_attention(q, k, v):
B, Hq, L, D = q.shape
Hk = k[0].shape[1]

q = q.reshape((B, Hk, Hq // Hk, L, D))
k = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
v = tree_map(lambda x: mx.expand_dims(x, axis=2), v)

scores = mx.quantized_matmul(q, *k, transpose=True)
scores = mx.softmax(scores, axis=-1)

out = mx.quantized_matmul(scores, *v, transpose=False)
out = out.reshape((B, Hq, L, D))
return out


def time_self_attention_primitives(q, k, v):
time_fn(attention, q, k, v)

Expand All @@ -54,34 +61,22 @@ def time_self_attention_quant_sdpa(q, k, v):
time_fn(quant_sdpa, q, k, v)


def time_self_attention_quant_primitives(q, k, v):
time_fn(quant_attention, q, k, v)


if __name__ == "__main__":
mx.random.seed(3)
# q = mx.random.uniform(shape=(1, H, 1, D))
# k = mx.random.uniform(shape=(1, H_k, L, D))
# v = mx.random.uniform(shape=(1, H_k, L, D))
q = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/queries.npy"))
k = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/keys.npy"))
v = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/values.npy"))
print(q.dtype)
print(q.shape, k.shape, v.shape)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)

k_quant = mx.quantize(k)
v_quant = mx.quantize(v)
mx.eval(k_quant, v_quant)

# time_self_attention_sdpa(q, k, v)
# time_self_attention_quant_sdpa(q, k_quant, v_quant)
# time_self_attention_primitives(q, k, v)
q_sdpa = quant_sdpa(q, k, v)
print(q_sdpa)
# o_attention = attention(q, k, v)
# print(o_attention)
# np.testing.assert_allclose(q_sdpa, o_attention, atol=1e-5)
o_sdpa = sdpa(q, k, v)
print(o_sdpa)
np.testing.assert_allclose(q_sdpa, o_sdpa, atol=1e-5)
# print(o_sdpa[..., :64])
# print()
# print(o_attention[..., :64])
# np.testing.assert_allclose(o_sdpa, o_attention)
time_self_attention_sdpa(q, k, v)
time_self_attention_quant_sdpa(q, k_quant, v_quant)
time_self_attention_primitives(q, k, v)
time_self_attention_quant_primitives(q, k_quant, v_quant)
19 changes: 1 addition & 18 deletions mlx/backend/metal/scaled_dot_product_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
#include "mlx/fast_primitives.h"
#include "mlx/utils.h"

#include <iostream>

namespace mlx::core::fast {

namespace {
Expand Down Expand Up @@ -364,21 +362,6 @@ assert(inputs.size() == 3);
o.set_data(allocator::malloc_or_wait(o.nbytes()));
}

quant_sdpa_vector(
s,
d,
q,
k,
k_scales,
k_biases,
v,
v_scales,
v_biases,
o,
scale_,
group_size_,
bits_);

if (quantized_) {
auto& k_scales_pre = inputs[2];
auto& k_biases_pre = inputs[3];
Expand All @@ -397,7 +380,7 @@ assert(inputs.size() == 3);
auto v_biases = copy_unless(is_contiguous_except_seq_len, v_biases_pre);

quant_sdpa_vector(
s, d, q, k, k_scales, k_biases, v, v_scales, v_biases, o, scale_);
s, d, q, k, k_scales, k_biases, v, v_scales, v_biases, o, scale_, group_size_, bits_);
}
else {
auto& k_pre = inputs[1];
Expand Down
2 changes: 0 additions & 2 deletions mlx/fast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
#include "mlx/ops.h"
#include "mlx/transforms.h"

#include <iostream>

namespace mlx::core::fast {

std::vector<array> Custom::vjp(
Expand Down
20 changes: 8 additions & 12 deletions python/src/fast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,26 +180,22 @@ void init_fast(nb::module_& parent_module) {
nb::sig(
"def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: array, v: array, v_scales: array, v_biases: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
Supports:
* `Multi-Head Attention <https://arxiv.org/abs/1706.03762>`_
* `Grouped Query Attention <https://arxiv.org/abs/2305.13245>`_
* `Multi-Query Attention <https://arxiv.org/abs/1911.02150>`_
Note: The softmax operation is performed in ``float32`` regardless of
the input precision.
A fast implementation of multi-head attention where the keys and values are quantized.
Note: For Grouped Query Attention and Multi-Query Attention, the ``k``
and ``v`` inputs should not be pre-tiled to match ``q``.
see :func:`scaled_dot_product_attention` for more details.
Args:
q (array): Input query array.
k (array): Input keys array.
k_scales (array): Scales for the quantized keys array.
k_biases (array): Biases for the quantized keys array.
v (array): Input values array.
v_scales (array): Scales for the quantized values array.
v_biases (array): Biases for the quantized values array.
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
mask (array, optional): An additive mask to apply to the query-key scores.
group_size (int): The group size used in the KV quantization.
bits (int): The bits used in the KV quantization.
Returns:
array: The output array.
)pbdoc");
Expand Down

0 comments on commit 77f9281

Please sign in to comment.