Description
cutile-python supports promoting vector operands to matrices and broadcasting batch dimensions. It promotes vectors with numpy/torch semantics, so (K,) @ (K, N) → (1, K) @ (K, N) → (1, N) → (N,), whereas we would require an explicit (1, K) * (K, N) for that. With Julia's append-on-the-right promotion, however, our vec-mat mul could instead be an outer product: (M,) @ (1, N) → (M, 1) @ (1, N) → (M, N).
For mat-vec, they have (M, K) @ (K,) → (M, K) @ (K, 1) → (M, 1) → (M,), i.e. the right operand gets a dim appended on the right, symmetric with vec-mat, where the left operand gets a dim prepended on the left. This case works naturally with Julia's semantics of always appending on the right, though.
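The promotion rules above can be sketched at the shape level in plain Julia. This is only an illustration of the numpy/torch convention described here, not cutile-python's or our actual implementation; `numpy_matmul_shape` is a hypothetical helper name, and it only handles 1D and 2D operands.

```julia
# Hypothetical helper: result shape of a numpy-style `a @ b` for 1D/2D operands.
function numpy_matmul_shape(a::Dims, b::Dims)
    # A 1D left operand gets a unit dim prepended: (K,) -> (1, K)
    a2 = length(a) == 1 ? (1, a[1]) : a
    # A 1D right operand gets a unit dim appended: (K,) -> (K, 1)
    b2 = length(b) == 1 ? (b[1], 1) : b
    length(a2) == 2 && length(b2) == 2 ||
        throw(ArgumentError("only 1D and 2D operands are handled in this sketch"))
    a2[2] == b2[1] ||
        throw(DimensionMismatch("inner dims differ: $(a2[2]) vs $(b2[1])"))
    out = (a2[1], b2[2])
    # Drop the unit dims that promotion inserted
    length(a) == 1 && (out = out[2:end])
    length(b) == 1 && (out = out[1:end-1])
    return out
end

numpy_matmul_shape((4,), (4, 5))    # vec-mat: (5,)
numpy_matmul_shape((3, 4), (4,))    # mat-vec: (3,)
numpy_matmul_shape((3, 4), (4, 5))  # mat-mat: (3, 5)
```

Under the append-on-the-right convention, the vec-mat case would instead promote (M,) to (M, 1), which is why only the mat-vec case lines up between the two conventions.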
For batch dimensions, cutile-python broadcasts batch dims (all dims before the last two) using _matmul_broadcast_shape and _promote_and_broadcast_to before emitting the MmaFOp. The verifier requires the batch dims to match exactly, so the broadcast has to be done at the frontend level. We currently don't broadcast batch dims; they must match exactly, and they also happen to be leading at the moment:
# 3D batched matmul: (B, M, K) × (B, K, N) → (B, M, N)
@inline function _matmul(a::Tile{T1}, b::Tile, ::Val{3}) where {T1}
    # Only sizes the accumulator's batch dim; a and b themselves are not broadcast
    B = max(size(a, 1), size(b, 1))
    M = size(a, 2)
    N = size(b, 3)
    acc = zeros((B, M, N), T1)
    muladd(a, b, acc)
end
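As a rough sketch of what frontend-level batch broadcasting would need to compute, the following plain-Julia helpers work on shape tuples only. The names `broadcast_batch_dims` and `batched_matmul_shape` are hypothetical, and the code assumes leading batch dims with numpy-style right alignment of the batch shapes; it is not a port of `_matmul_broadcast_shape`, whose exact behavior is not shown here.

```julia
# Hypothetical helper: broadcast two batch-shape tuples, right-aligned (numpy-style).
function broadcast_batch_dims(a::Dims, b::Dims)
    n = max(length(a), length(b))
    # Left-pad the shorter shape with 1s so both have n dims
    pa = (ntuple(_ -> 1, n - length(a))..., a...)
    pb = (ntuple(_ -> 1, n - length(b))..., b...)
    return ntuple(n) do i
        x, y = pa[i], pb[i]
        x == y || x == 1 || y == 1 ||
            throw(DimensionMismatch("batch dims $x and $y are not broadcastable"))
        max(x, y)
    end
end

# Hypothetical helper: (batch..., M, K) × (batch..., K, N) → (batch..., M, N)
function batched_matmul_shape(a::Dims, b::Dims)
    length(a) >= 2 && length(b) >= 2 ||
        throw(ArgumentError("operands need at least 2 dims"))
    M, K = a[end-1], a[end]
    K2, N = b[end-1], b[end]
    K == K2 || throw(DimensionMismatch("inner dims differ: $K vs $K2"))
    batch = broadcast_batch_dims(a[1:end-2], b[1:end-2])
    return (batch..., M, N)
end

batched_matmul_shape((1, 3, 4), (8, 4, 5))  # (8, 3, 5)
batched_matmul_shape((3, 4), (8, 4, 5))     # (8, 3, 5)
```

Once the target batch shape is known, each operand would still have to be broadcast to it before the op is emitted, since the verifier only accepts exactly matching batch dims.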