
Matmul broadcasting #115

@AntonOresten


cutile-python supports promoting vector operands to matrices and broadcasting batch dimensions. It promotes vectors with NumPy/torch semantics, so (K,) @ (K, N) → (1, K) @ (K, N) → (1, N) → (N,), whereas we would require an explicit (1, K) * (K, N) for that. Under Julia's convention of treating a vector as a column, our vec-mat mul could instead be an outer product: (M,) * (1, N) → (M, 1) * (1, N) → (M, N).
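To make the two vec-mat conventions concrete, here is a sketch on ordinary Julia `Array`s (not cuTile `Tile`s). `numpy_style_vecmat` is a hypothetical helper name used only for illustration, not an existing API.

```julia
using LinearAlgebra  # matrix `*` methods (loaded by default; listed for clarity)

K, N, M = 3, 4, 5

# NumPy/torch-style vec-mat (as in cutile-python), spelled out explicitly:
# (K,) -> (1, K) by prepending a unit dim, multiply to get (1, N), then drop
# the unit dim again to get (N,). `numpy_style_vecmat` is a hypothetical name.
numpy_style_vecmat(v::AbstractVector, B::AbstractMatrix) = vec(reshape(v, 1, :) * B)

v = collect(1.0:K)           # (K,)
B = ones(K, N)               # (K, N)
y = numpy_style_vecmat(v, B) # (N,); each entry is sum(v)

# Julia's own convention treats a vector as a column, (M,) -> (M, 1), so a
# vector times a (1, N) matrix is an outer product of shape (M, N).
u = collect(1.0:M)           # (M,)
R = ones(1, N)               # (1, N)
P = u * R                    # (M, N)
```

Note that plain `v * B` with `B` of size (K, N) throws a `DimensionMismatch` in Julia, since `v` is read as a (K, 1) column.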

For mat-vec, they have (M, K) @ (K,) → (M, K) @ (K, 1) → (M, 1) → (M,), i.e. a unit dim is appended on the right of the right operand, mirroring vec-mat, where a unit dim is inserted on the left of the left operand. This case already works naturally with Julia's semantics of always appending dims on the right.
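A quick check on ordinary Julia `Array`s (not `Tile`s) that the NumPy mat-vec promotion and Julia's built-in behavior agree. Spelling out the (K,) → (K, 1) step is only for illustration.

```julia
using LinearAlgebra  # matrix `*` methods (loaded by default; listed for clarity)

M, K = 2, 3
A = [1.0 2.0 3.0;
     4.0 5.0 6.0]     # (M, K)
x = [1.0, 1.0, 1.0]   # (K,)

# Julia: A * x already returns a vector of shape (M,).
y = A * x

# NumPy-style promotion spelled out: append a unit dim to the right operand,
# (K,) -> (K, 1), multiply to get (M, 1), then drop the trailing unit dim.
y_explicit = vec(A * reshape(x, :, 1))
```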

For batch dimensions, cutile-python broadcasts the batch dims (all dims before the last two) using `_matmul_broadcast_shape` and `_promote_and_broadcast_to` before emitting the `MmaFOp`. The verifier requires batch dims to match exactly, so the broadcast has to happen at the frontend level. We currently don't broadcast batch dims: they must match exactly, and they also happen to be the leading dims at the moment:

```julia
# 3D batched matmul: (B, M, K) × (B, K, N) → (B, M, N)
@inline function _matmul(a::Tile{T1}, b::Tile, ::Val{3}) where {T1}
    # Batch size of the result; a and b are not expanded, so their batch dims must match
    B = max(size(a, 1), size(b, 1))
    M = size(a, 2)
    N = size(b, 3)
    acc = zeros((B, M, N), T1)
    muladd(a, b, acc)
end
```

https://github.com/NVIDIA/cutile-python/blob/eaad4888ea77c2eb470efca2f76a03c100518b63/src/cuda/tile/_ir/ops.py#L2689
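A minimal sketch of frontend-side batch broadcasting, on ordinary Julia `Array`s rather than `Tile`s. It assumes NumPy broadcasting rules and keeps batch dims leading, as in the current 3D `_matmul`. `broadcast_batch`, `batched_matmul_shape` and `promote_and_broadcast_to` are hypothetical names (the last only loosely mirrors cutile-python's `_promote_and_broadcast_to`). After expansion both operands have identical batch dims, which is what an exact-match verifier needs.

```julia
using LinearAlgebra  # matrix `*` methods (loaded by default; listed for clarity)

# Hypothetical: broadcast two batch-shape tuples with NumPy rules. The shorter
# tuple is padded with 1s on the left; each pair must be equal or contain a 1.
function broadcast_batch(ba::Tuple, bb::Tuple)
    n = max(length(ba), length(bb))
    pa = (ntuple(_ -> 1, n - length(ba))..., ba...)
    pb = (ntuple(_ -> 1, n - length(bb))..., bb...)
    return ntuple(n) do i
        x, y = pa[i], pb[i]
        x == y || x == 1 || y == 1 ||
            throw(DimensionMismatch("batch dims $x and $y are incompatible"))
        max(x, y)
    end
end

# Hypothetical: result shape of (batch..., M, K) x (batch..., K, N).
function batched_matmul_shape(sa::Tuple, sb::Tuple)
    (length(sa) >= 2 && length(sb) >= 2) || throw(ArgumentError("need at least 2 dims"))
    M, Ka = sa[end-1], sa[end]
    Kb, N = sb[end-1], sb[end]
    Ka == Kb || throw(DimensionMismatch("contraction dims $Ka and $Kb differ"))
    return (broadcast_batch(sa[1:end-2], sb[1:end-2])..., M, N)
end

# Hypothetical: add leading unit batch dims to `a`, then repeat it along every
# size-1 batch dim so its batch shape equals `batch` exactly.
function promote_and_broadcast_to(a::AbstractArray, batch::Tuple)
    nb = ndims(a) - 2
    padded = reshape(a, (ntuple(_ -> 1, length(batch) - nb)..., size(a)...))
    reps = ntuple(i -> i <= length(batch) ? batch[i] ÷ size(padded, i) : 1, ndims(padded))
    return repeat(padded, reps...)
end

# Usage: batch (1,) broadcasts against batch (3,).
A = rand(1, 4, 8)                                           # (1, M, K)
B = rand(3, 8, 5)                                           # (3, K, N)
batch = broadcast_batch(size(A)[1:end-2], size(B)[1:end-2]) # (3,)
Ab = promote_and_broadcast_to(A, batch)                     # (3, 4, 8)
Bb = promote_and_broadcast_to(B, batch)                     # (3, 8, 5)
C = zeros(batched_matmul_shape(size(A), size(B)))           # (3, 4, 5)
for b in 1:batch[1]
    C[b, :, :] = Ab[b, :, :] * Bb[b, :, :]
end
```

Materializing the repeats mirrors the "broadcast first, then emit an exact-match op" split described above. A real lowering would presumably use a broadcast op rather than copying data.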
