Skip to content

Commit

Permalink
update bench
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Barron committed Oct 23, 2024
1 parent 1d0d438 commit 42a638f
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions benchmarks/python/sdpa_vector_bench.py
Original file line number Diff line number Diff line change
def sdpa(q, k, v):
    """Run the fused SDPA kernel on full-precision K/V (unit scale, no mask)."""
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)


def quant_sdpa(q, k, v, bits=4):
    """Run the fused quantized-SDPA kernel.

    `k` and `v` are (weights, scales, biases) tuples as returned by
    ``mx.quantize``; they are splatted into the fused call. `bits` must match
    the quantization width used to produce them.
    """
    quantized_kv = (*k, *v)
    return mx.fast.quantized_scaled_dot_product_attention(
        q, *quantized_kv, scale=1.0, mask=None, bits=bits
    )


def quant_attention(q, k, v, bits=4):
    """Reference quantized attention built from unfused primitives.

    Reshapes `q` to expose the GQA group dimension, broadcasts the quantized
    K/V tuples against it, then does quantized matmul -> softmax -> quantized
    matmul. `bits` must match the width `k`/`v` were quantized with.
    """
    B, Hq, L, D = q.shape
    Hk = k[0].shape[1]

    # Split query heads into (kv_heads, group) so K/V broadcast across groups.
    q = q.reshape((B, Hk, Hq // Hk, L, D))
    unsqueeze = lambda x: mx.expand_dims(x, axis=2)
    k = tree_map(unsqueeze, k)
    v = tree_map(unsqueeze, v)

    scores = mx.softmax(
        mx.quantized_matmul(q, *k, transpose=True, bits=bits), axis=-1
    )
    out = mx.quantized_matmul(scores, *v, transpose=False, bits=bits)
    return out.reshape((B, Hq, L, D))

def time_self_attention_sdpa(q, k, v):
    """Time the fused full-precision SDPA kernel."""
    time_fn(sdpa, q, k, v)


def time_self_attention_quant_sdpa(q, k, v, bits=4):
    """Time the fused quantized-SDPA kernel at the given bit width."""
    time_fn(quant_sdpa, q, k, v, bits)


def time_self_attention_quant_primitives(q, k, v, bits=4):
    """Time the unfused quantized-attention reference at the given bit width.

    Bug fix: the previous body dropped `bits` and always timed
    ``quant_attention`` at its default width, so passing a non-default
    `bits` here was silently ignored. Forward it, matching
    ``time_self_attention_quant_sdpa``.
    """
    time_fn(quant_attention, q, k, v, bits)


Expand All @@ -70,11 +70,12 @@ def time_self_attention_quant_primitives(q, k, v):
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)

k_quant = mx.quantize(k)
v_quant = mx.quantize(v)
bits = 4
k_quant = mx.quantize(k, bits=bits)
v_quant = mx.quantize(v, bits=bits)
mx.eval(k_quant, v_quant)

time_self_attention_sdpa(q, k, v)
time_self_attention_quant_sdpa(q, k_quant, v_quant)
time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
time_self_attention_primitives(q, k, v)
time_self_attention_quant_primitives(q, k_quant, v_quant)
time_self_attention_quant_primitives(q, k_quant, v_quant, bits)

0 comments on commit 42a638f

Please sign in to comment.