diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py
index 936514644b..ac58572cf6 100644
--- a/benchmarks/python/sdpa_vector_bench.py
+++ b/benchmarks/python/sdpa_vector_bench.py
@@ -3,7 +3,7 @@
 from mlx.utils import tree_map
 from time_utils import time_fn
 
-L = 16384
+L = 65536
 H = 32
 H_k = H // 4
 D = 128
@@ -27,13 +27,13 @@ def sdpa(q, k, v):
     return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
 
 
-def quant_sdpa(q, k, v):
+def quant_sdpa(q, k, v, bits=4):
     return mx.fast.quantized_scaled_dot_product_attention(
-        q, *k, *v, scale=1.0, mask=None, bits=8
+        q, *k, *v, scale=1.0, mask=None, bits=bits
     )
 
 
-def quant_attention(q, k, v):
+def quant_attention(q, k, v, bits=4):
     B, Hq, L, D = q.shape
     Hk = k[0].shape[1]
 
@@ -41,10 +41,10 @@ def quant_attention(q, k, v):
     k = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
     v = tree_map(lambda x: mx.expand_dims(x, axis=2), v)
 
-    scores = mx.quantized_matmul(q, *k, transpose=True)
+    scores = mx.quantized_matmul(q, *k, transpose=True, bits=bits)
     scores = mx.softmax(scores, axis=-1)
 
-    out = mx.quantized_matmul(scores, *v, transpose=False)
+    out = mx.quantized_matmul(scores, *v, transpose=False, bits=bits)
     out = out.reshape((B, Hq, L, D))
     return out
 
@@ -57,12 +57,12 @@ def time_self_attention_sdpa(q, k, v):
     time_fn(sdpa, q, k, v)
 
 
-def time_self_attention_quant_sdpa(q, k, v):
-    time_fn(quant_sdpa, q, k, v)
+def time_self_attention_quant_sdpa(q, k, v, bits=4):
+    time_fn(quant_sdpa, q, k, v, bits)
 
 
-def time_self_attention_quant_primitives(q, k, v):
-    time_fn(quant_attention, q, k, v)
+def time_self_attention_quant_primitives(q, k, v, bits=4):
+    time_fn(quant_attention, q, k, v, bits)
 
 
 if __name__ == "__main__":
@@ -72,11 +72,12 @@ def time_self_attention_quant_primitives(q, k, v):
     v = mx.random.uniform(shape=(1, H_k, L, D))
     mx.eval(q, k, v)
 
-    k_quant = mx.quantize(k)
-    v_quant = mx.quantize(v)
+    bits = 4
+    k_quant = mx.quantize(k, bits=bits)
+    v_quant = mx.quantize(v, bits=bits)
     mx.eval(k_quant, v_quant)
 
     time_self_attention_sdpa(q, k, v)
-    time_self_attention_quant_sdpa(q, k_quant, v_quant)
+    time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
     time_self_attention_primitives(q, k, v)
-    time_self_attention_quant_primitives(q, k_quant, v_quant)
+    time_self_attention_quant_primitives(q, k_quant, v_quant, bits)
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index c66b33a002..4965dacb4f 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -306,7 +306,7 @@ void quant_sdpa_vector(
 void ScaledDotProductAttention::eval_gpu(
     const std::vector<array>& inputs,
     array& out) {
-assert(inputs.size() == 3);
+  assert(inputs.size() == 3);
 
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -353,7 +353,7 @@ assert(inputs.size() == 3);
 
   // We are in vector mode ie single query
   if (q_pre.shape(2) == 1) {
-      auto q = copy_unless(is_contiguous, q_pre);
+    auto q = copy_unless(is_contiguous, q_pre);
 
     // Donate the query if possible
     if (q.is_donatable()) {
@@ -380,9 +380,20 @@ assert(inputs.size() == 3);
       auto v_biases = copy_unless(is_contiguous_except_seq_len, v_biases_pre);
 
       quant_sdpa_vector(
-          s, d, q, k, k_scales, k_biases, v, v_scales, v_biases, o, scale_, group_size_, bits_);
-    }
-    else {
+          s,
+          d,
+          q,
+          k,
+          k_scales,
+          k_biases,
+          v,
+          v_scales,
+          v_biases,
+          o,
+          scale_,
+          group_size_,
+          bits_);
+    } else {
       auto& k_pre = inputs[1];
       auto& v_pre = inputs[2];
 
@@ -396,15 +407,10 @@ assert(inputs.size() == 3);
       } else {
        sdpa_vector(s, d, q, k, v, o, scale_);
      }
-    }
-
-    // We route to the 2 pass fused attention if
-    // - The device is large and the sequence length long
-    // - The sequence length is even longer and we have gqa
 
   }
 
-  // Non-quantized
+  // Full attention mode
   else {
     auto& v_pre = inputs[2];
 
@@ -430,30 +436,7 @@ assert(inputs.size() == 3);
         {str_oB, str_oH, str_oL, str_oD},
         flags);
 
-    // We are in vector mode ie single query
-    if (q_pre.shape(2) == 1) {
-      auto q = copy_unless(is_contiguous, q_pre);
-      auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
-      auto v = copy_unless(is_contiguous_except_seq_len, v_pre);
-
-      // Donate the query if possible
-      if (q.is_donatable()) {
-        o.move_shared_buffer(q);
-      } else {
-        o.set_data(allocator::malloc_or_wait(o.nbytes()));
-      }
-
-      sdpa_vector(s, d, q, k, v, o, scale_);
-    }
-    // Full attention mode
-    else {
-      auto q = copy_unless(is_matrix_contiguous, q_pre);
-      auto k = copy_unless(is_matrix_contiguous, k_pre);
-      auto v = copy_unless(is_matrix_contiguous, v_pre);
-      o.set_data(allocator::malloc_or_wait(o.nbytes()));
-
-      sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
-    }
+    sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
   }
 
   d.add_temporaries(std::move(copies), s.index);
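
For reference, a minimal sketch (not part of the patch) of how the bits parameter threads through the quantized path this benchmark exercises, using only the mx.quantize and mx.fast.quantized_scaled_dot_product_attention calls shown in the diff; shapes mirror the benchmark's, with a shorter sequence length (my choice) so the example runs quickly:

    import mlx.core as mx

    # Benchmark-style shapes, reduced L so the sketch is cheap to run.
    B, H, H_k, D = 1, 32, 8, 128
    L = 4096
    bits = 4

    q = mx.random.uniform(shape=(B, H, 1, D))    # single-query (decode) step
    k = mx.random.uniform(shape=(B, H_k, L, D))  # key cache
    v = mx.random.uniform(shape=(B, H_k, L, D))  # value cache

    # mx.quantize returns (weights, scales, biases); the benchmark splats the
    # tuples straight into the fused quantized SDPA call with a matching bits.
    k_quant = mx.quantize(k, bits=bits)
    v_quant = mx.quantize(v, bits=bits)

    out = mx.fast.quantized_scaled_dot_product_attention(
        q, *k_quant, *v_quant, scale=1.0, mask=None, bits=bits
    )
    mx.eval(out)
    print(out.shape)  # same shape as q: (1, 32, 1, 128)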