Commit

update bench

Alex Barron committed Dec 5, 2024
1 parent 77f9281 commit 1e0a199
Showing 2 changed files with 33 additions and 49 deletions.
29 changes: 15 additions & 14 deletions benchmarks/python/sdpa_vector_bench.py
@@ -3,7 +3,7 @@
 from mlx.utils import tree_map
 from time_utils import time_fn

-L = 16384
+L = 65536
 H = 32
 H_k = H // 4
 D = 128
@@ -27,24 +27,24 @@ def sdpa(q, k, v):
     return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)


-def quant_sdpa(q, k, v):
+def quant_sdpa(q, k, v, bits=4):
     return mx.fast.quantized_scaled_dot_product_attention(
-        q, *k, *v, scale=1.0, mask=None, bits=8
+        q, *k, *v, scale=1.0, mask=None, bits=bits
     )


-def quant_attention(q, k, v):
+def quant_attention(q, k, v, bits=4):
     B, Hq, L, D = q.shape
     Hk = k[0].shape[1]

     q = q.reshape((B, Hk, Hq // Hk, L, D))
     k = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
     v = tree_map(lambda x: mx.expand_dims(x, axis=2), v)

-    scores = mx.quantized_matmul(q, *k, transpose=True)
+    scores = mx.quantized_matmul(q, *k, transpose=True, bits=bits)
     scores = mx.softmax(scores, axis=-1)

-    out = mx.quantized_matmul(scores, *v, transpose=False)
+    out = mx.quantized_matmul(scores, *v, transpose=False, bits=bits)
     out = out.reshape((B, Hq, L, D))
     return out

@@ -57,12 +57,12 @@ def time_self_attention_sdpa(q, k, v):
     time_fn(sdpa, q, k, v)


-def time_self_attention_quant_sdpa(q, k, v):
-    time_fn(quant_sdpa, q, k, v)
+def time_self_attention_quant_sdpa(q, k, v, bits=4):
+    time_fn(quant_sdpa, q, k, v, bits)


-def time_self_attention_quant_primitives(q, k, v):
-    time_fn(quant_attention, q, k, v)
+def time_self_attention_quant_primitives(q, k, v, bits=4):
+    time_fn(quant_attention, q, k, v, bits)


 if __name__ == "__main__":
@@ -72,11 +72,12 @@ def time_self_attention_quant_primitives(q, k, v):
     v = mx.random.uniform(shape=(1, H_k, L, D))
     mx.eval(q, k, v)

-    k_quant = mx.quantize(k)
-    v_quant = mx.quantize(v)
+    bits = 4
+    k_quant = mx.quantize(k, bits=bits)
+    v_quant = mx.quantize(v, bits=bits)
     mx.eval(k_quant, v_quant)

     time_self_attention_sdpa(q, k, v)
-    time_self_attention_quant_sdpa(q, k_quant, v_quant)
+    time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
     time_self_attention_primitives(q, k, v)
-    time_self_attention_quant_primitives(q, k_quant, v_quant)
+    time_self_attention_quant_primitives(q, k_quant, v_quant, bits)
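Aside from timing, the two quantized paths in this benchmark can also be cross-checked against each other. The sketch below is not part of the commit: it reuses the benchmark's quant_sdpa and quant_attention helpers, relies on mx.quantize's default group size, and uses small illustrative shapes (a single query position, matching the vector-mode kernels the benchmark targets).

import mlx.core as mx

bits = 4
# Illustrative shapes: one query position, 32 query heads over 8 K/V heads, head_dim 128.
q = mx.random.uniform(shape=(1, 32, 1, 128))
k = mx.random.uniform(shape=(1, 8, 4096, 128))
v = mx.random.uniform(shape=(1, 8, 4096, 128))

# mx.quantize returns a (packed_weights, scales, biases) tuple for each tensor.
k_quant = mx.quantize(k, bits=bits)
v_quant = mx.quantize(v, bits=bits)
mx.eval(q, k_quant, v_quant)

# Fused kernel vs. the same computation composed from quantized_matmul + softmax.
fused = quant_sdpa(q, k_quant, v_quant, bits)
composed = quant_attention(q, k_quant, v_quant, bits)

# Both paths read the same quantized K/V, so any difference comes from
# kernel-level accumulation order and precision.
print(mx.abs(fused - composed).max())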
53 changes: 18 additions & 35 deletions mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -306,7 +306,7 @@ void quant_sdpa_vector(
 void ScaledDotProductAttention::eval_gpu(
     const std::vector<array>& inputs,
     array& out) {
-    assert(inputs.size() == 3);
+  assert(inputs.size() == 3);

   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -353,7 +353,7 @@ assert(inputs.size() == 3);

   // We are in vector mode ie single query
   if (q_pre.shape(2) == 1) {
-      auto q = copy_unless(is_contiguous, q_pre);
+    auto q = copy_unless(is_contiguous, q_pre);

     // Donate the query if possible
     if (q.is_donatable()) {
@@ -380,9 +380,20 @@ assert(inputs.size() == 3);
       auto v_biases = copy_unless(is_contiguous_except_seq_len, v_biases_pre);

       quant_sdpa_vector(
-          s, d, q, k, k_scales, k_biases, v, v_scales, v_biases, o, scale_, group_size_, bits_);
-    }
-    else {
+          s,
+          d,
+          q,
+          k,
+          k_scales,
+          k_biases,
+          v,
+          v_scales,
+          v_biases,
+          o,
+          scale_,
+          group_size_,
+          bits_);
+    } else {
       auto& k_pre = inputs[1];
       auto& v_pre = inputs[2];

@@ -396,15 +407,10 @@ assert(inputs.size() == 3);
     } else {
       sdpa_vector(s, d, q, k, v, o, scale_);
     }
-
   }
-
-    // We route to the 2 pass fused attention if
-    // - The device is large and the sequence length long
-    // - The sequence length is even longer and we have gqa
   }

-  // Non-quantized
+  // Full attention mode
   else {
     auto& v_pre = inputs[2];

@@ -430,30 +436,7 @@ assert(inputs.size() == 3);
         {str_oB, str_oH, str_oL, str_oD},
         flags);

-    // We are in vector mode ie single query
-    if (q_pre.shape(2) == 1) {
-      auto q = copy_unless(is_contiguous, q_pre);
-      auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
-      auto v = copy_unless(is_contiguous_except_seq_len, v_pre);
-
-      // Donate the query if possible
-      if (q.is_donatable()) {
-        o.move_shared_buffer(q);
-      } else {
-        o.set_data(allocator::malloc_or_wait(o.nbytes()));
-      }
-
-      sdpa_vector(s, d, q, k, v, o, scale_);
-    }
-    // Full attention mode
-    else {
-      auto q = copy_unless(is_matrix_contiguous, q_pre);
-      auto k = copy_unless(is_matrix_contiguous, k_pre);
-      auto v = copy_unless(is_matrix_contiguous, v_pre);
-      o.set_data(allocator::malloc_or_wait(o.nbytes()));
-
-      sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
-    }
+    sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
   }

   d.add_temporaries(std::move(copies), s.index);
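For reference, the dispatch that remains after this cleanup keys off the query sequence length and the number of inputs: a single query position routes to the vector kernels (quant_sdpa_vector when K/V arrive as packed weights with scales and biases, sdpa_vector otherwise), while longer queries go to sdpa_full_self_attention_metal. The sketch below mirrors that routing from the Python API; it is not from the commit, the shapes are illustrative, and whether a given call actually hits a fused Metal kernel can still depend on other support checks inside the op.

import mlx.core as mx

bits = 4
k = mx.random.uniform(shape=(1, 8, 512, 128))
v = mx.random.uniform(shape=(1, 8, 512, 128))
q_vec = mx.random.uniform(shape=(1, 32, 1, 128))     # q.shape[2] == 1 -> vector mode
q_full = mx.random.uniform(shape=(1, 32, 512, 128))  # q.shape[2] > 1  -> full attention

# Three inputs (q, k, v): the non-quantized paths above.
out_vec = mx.fast.scaled_dot_product_attention(q_vec, k, v, scale=1.0, mask=None)
out_full = mx.fast.scaled_dot_product_attention(q_full, k, v, scale=1.0, mask=None)

# Seven inputs (q plus packed weights, scales, biases for K and V): quant_sdpa_vector.
k_quant = mx.quantize(k, bits=bits)
v_quant = mx.quantize(v, bits=bits)
out_quant = mx.fast.quantized_scaled_dot_product_attention(
    q_vec, *k_quant, *v_quant, scale=1.0, mask=None, bits=bits
)

mx.eval(out_vec, out_full, out_quant)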
