Commit

Merge pull request #38 from arhik/main
[docs] update ops/matmul.jl
arhik authored Apr 13, 2024
2 parents 5621e39 + 090f0c0 commit 467fc60
Showing 1 changed file with 34 additions and 2 deletions.
36 changes: 34 additions & 2 deletions src/ops/matmul.jl
@@ -1,5 +1,10 @@
export naive_matmul_kernel, matmul

"""
matmul_heuristics(x, y)
Computes workgroup size and workgroup count heuristics for the given inputs.
The result is used to launch `naive_matmul_kernel`.
"""
function matmul_heuristics(x, y)
aSize = size(x)
bSize = size(y)
@@ -9,6 +14,12 @@ function matmul_heuristics(x, y)
return (outSize, outSize, (1, 1))
end
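
For illustration, a minimal sketch of what the heuristic returns. The middle of the function is collapsed in this diff, so the output size is assumed here to be `(first(size(x)), last(size(y)))`, matching the tiled variant further below; the expected values are an assumption, not part of the commit.

# Hypothetical illustration; plain Arrays suffice since only `size` is inspected.
x = rand(Float32, 4, 8)
y = rand(Float32, 8, 5)
(outSize, wgSize, wgCount) = matmul_heuristics(x, y)
# Expected (assumed): outSize == (4, 5), wgSize == (4, 5), wgCount == (1, 1),
# i.e. a single workgroup sized to the full output.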

"""
naive_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuArray{T, N}) where {T, N}
Naive matrix multiplication kernel. It is not meant to be called as a regular
Julia function; instead it must be passed to `@wgpukernel`, which transforms it into
`WGSL`-compatible shader code.
"""
function naive_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuArray{T, N}) where {T, N}
gIdx = globalId.x
gIdy = globalId.y
@@ -23,14 +34,24 @@ function naive_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuAr
out[gId] = sum
end

"""
matmul(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
Wrapper function for end users that performs matrix multiplication using the naive
`naive_matmul_kernel` kernel.
"""
function matmul(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
(outSize, wgSize, wgCount) = matmul_heuristics(x, y)
out = WgpuArray{eltype(x), ndims(x)}(undef, outSize)
@wgpukernel launch=true workgroupSizes=wgSize workgroupCount=wgCount shmem=() naive_matmul_kernel(x, y, out)
return out
end
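
A minimal usage sketch of the wrapper above. The `WgpuArray`-from-host-array constructor shown here is an assumption; the exact construction API may differ.

a = WgpuArray(rand(Float32, 8, 8))   # assumed host-to-device constructor
b = WgpuArray(rand(Float32, 8, 8))
c = matmul(a, b)                     # launches naive_matmul_kernel; c is an 8×8 WgpuArray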


"""
tiled_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuArray{T, N}) where {T, N}
Compute kernel that carries out tiled matrix multiplication of the input `WgpuArray`s.
It is not meant to be called as a regular Julia function; instead it must be passed to the
`@wgpukernel` macro inside a wrapper function.
"""
function tiled_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuArray{T, N}) where {T, N}
# set the output matrix to zero
gId = xDims.x*globalId.y + globalId.x
@@ -61,18 +82,29 @@ function tiled_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuAr

out[gId] = sum
end
# For now, valid only for square matrices whose sizes are powers of 2, with a base tile size of 16.

"""
tiled_matmul_heuristics(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
Computes the workgroup size and workgroup count for the given inputs of the
`tiled_matmul_kernel` kernel function.
"""
function tiled_matmul_heuristics(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
aSize = size(x)
bSize = size(y)
@assert last(aSize) == first(bSize)
outSize = (first(aSize), last(bSize))
@assert eltype(x) == eltype(y)
# For now, valid only for square matrices whose sizes are powers of 2, with a base tile size of 16.
wgSize = (16, 16) # This can be fixed for now
wgCount = div.((outSize[1], outSize[2]), 16, RoundUp)
return (outSize, wgSize, wgCount)
end
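
A worked example of the tiled heuristic; the returned values follow directly from the code above, while the `WgpuArray` construction is only illustrative.

x = WgpuArray(rand(Float32, 64, 64))   # assumed host-to-device constructor
y = WgpuArray(rand(Float32, 64, 64))
(outSize, wgSize, wgCount) = tiled_matmul_heuristics(x, y)
# outSize == (64, 64)
# wgSize  == (16, 16)    fixed tile size
# wgCount == (4, 4)      div.((64, 64), 16, RoundUp)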

"""
tiled_matmul(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
End-user matrix multiplication function that carries out tiled matrix multiplication of
the input `WgpuArray` arguments.
"""
function tiled_matmul(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
(outSize, wgSize, wgCount) = tiled_matmul_heuristics(x, y)
out = WgpuArray{eltype(x), ndims(x)}(undef, outSize)
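
The tail of `tiled_matmul` is truncated in this view. Assuming it mirrors the naive `matmul` wrapper and returns the output array, usage would look like the following sketch; the constructor is again an assumption.

a = WgpuArray(rand(Float32, 32, 32))   # assumed host-to-device constructor
b = WgpuArray(rand(Float32, 32, 32))
c = tiled_matmul(a, b)                 # tiled kernel launch via @wgpukernel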
