Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for MultiBroadcastFusion #1641

Merged
merged 1 commit into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions ext/cuda/data_layouts.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@

import ClimaCore.DataLayouts: AbstractData
import ClimaCore.DataLayouts: FusedMultiBroadcast
import ClimaCore.DataLayouts: IJKFVH, IJFH, VIJFH, VIFH, IFH, IJF, IF, VF, DataF
import ClimaCore.DataLayouts: IJFHStyle, VIJFHStyle, VFStyle, DataFStyle
import ClimaCore.DataLayouts: promote_parent_array_type
import ClimaCore.DataLayouts: parent_array_type
import ClimaCore.DataLayouts: device_from_array_type, isascalar
import ClimaCore.DataLayouts: fused_copyto!
import Adapt
import CUDA

# Any `CUDA.CuArray` type is backed by the ClimaComms CUDA device.
function device_from_array_type(::Type{<:CUDA.CuArray})
    return ClimaComms.CUDADevice()
end

# Return the `CuArray` type with the same first (`T`) and third (`B`) type
# parameters as the input, leaving the dimensionality `N` unconstrained.
function parent_array_type(
    ::Type{<:CUDA.CuArray{T, N, B} where {N}},
) where {T, B}
    return (CUDA.CuArray{T, N, B} where {N})
end

Expand Down Expand Up @@ -180,3 +186,83 @@ function Base.fill!(dest::DataF{S, A}, val) where {S, A <: CUDA.CuArray}
)
return dest
end

# Evaluate a single `dest => bc` pair at index `I` and store the result in
# `dest`. Nothing is written when `v` exceeds the 4th dimension of `dest`.
# Scalar-like broadcasts (see `isascalar`) are read with `bc[]` instead of
# `bc[I]`.
Base.@propagate_inbounds function rcopyto_at!(
    pair::Pair{<:AbstractData, <:Any},
    I,
    v,
)
    dest = pair.first
    bc = pair.second
    # Out-of-range `v`: leave `dest` untouched.
    v <= size(dest, 4) || return nothing
    if isascalar(bc)
        dest[I] = bc[]
    else
        dest[I] = bc[I]
    end
    return nothing
end
# Recursively apply `rcopyto_at!` to every element of a tuple of pairs:
# handle the head, then recurse on the tail.
Base.@propagate_inbounds function rcopyto_at!(pairs::Tuple, I, v)
    head, rest = first(pairs), Base.tail(pairs)
    rcopyto_at!(head, I, v)
    return rcopyto_at!(rest, I, v)
end
# Single-element tuple: forward to the element's own method.
Base.@propagate_inbounds function rcopyto_at!(pairs::Tuple{<:Any}, I, v)
    return rcopyto_at!(first(pairs), I, v)
end
# Empty tuple: nothing to do.
@inline function rcopyto_at!(pairs::Tuple{}, I, v)
    return nothing
end

# CUDA kernel body: build the index `(i, j, 1, v, h)` from this thread's
# thread/block coordinates and apply every fused pair at that index.
function knl_fused_copyto!(fmbc::FusedMultiBroadcast)
    @inbounds begin
        tidx = CUDA.threadIdx()
        bidx = CUDA.blockIdx()
        # `v` combines the block's y coordinate with the thread's z coordinate.
        v = CUDA.blockDim().z * (bidx.y - 1) + tidx.z
        I = CartesianIndex((tidx.x, tidx.y, 1, v, bidx.x))
        rcopyto_at!(fmbc.pairs, I, v)
    end
    return nothing
end

# Launch `knl_fused_copyto!` for a fused multi-broadcast whose first
# destination is a `VIJFH` layout on a CUDA device.
#
# Threads are laid out as `(Nij, Nij, Nv_per_block)` and blocks as
# `(Nh, Nv_blocks)`. Nothing is launched when `Nv` or `Nh` is zero.
# Always returns `nothing`.
function fused_copyto!(
    fmbc::FusedMultiBroadcast,
    dest1::VIJFH{S, Nij},
    ::ClimaComms.CUDADevice,
) where {S, Nij}
    _, _, _, Nv, Nh = size(dest1)
    if Nv > 0 && Nh > 0
        # Budget of 256 threads per block shared between the `Nij * Nij` slab
        # and the `v` dimension. Clamp to at least one `v` index per block:
        # when `Nij * Nij > 256`, `fld` yields 0 and the following
        # `cld(Nv, 0)` would throw a `DivideError`.
        Nv_per_block = min(Nv, max(1, fld(256, Nij * Nij)))
        Nv_blocks = cld(Nv, Nv_per_block)
        auto_launch!(
            knl_fused_copyto!,
            (fmbc,),
            dest1;
            threads_s = (Nij, Nij, Nv_per_block),
            blocks_s = (Nh, Nv_blocks),
        )
    end
    return nothing
end

# Adapt the function slot of a broadcast. Ordinary callables are forwarded to
# `Adapt.adapt`.
function adapt_f(to, f::F) where {F}
    return Adapt.adapt(to, f)
end
# When the "function" is a type (i.e. a constructor), wrap it in a closure
# that forwards all arguments to that type's constructor.
function adapt_f(to, ::Type{F}) where {F}
    return (args...) -> F(args...)
end

# Adapt a `FusedMultiBroadcast` for use inside a CUDA kernel: each
# `dest => bc` pair is rebuilt with an adapted destination and a new
# `Broadcasted` whose function, arguments and axes have been adapted.
function Adapt.adapt_structure(
    to::CUDA.KernelAdaptor,
    fmbc::FusedMultiBroadcast,
)
    adapted_pairs = map(fmbc.pairs) do p
        dest, bc = p.first, p.second
        adapted_dest = Adapt.adapt(to, dest)
        adapted_bc = Base.Broadcast.Broadcasted(
            bc.style,
            adapt_f(to, bc.f),
            Adapt.adapt(to, bc.args),
            Adapt.adapt(to, bc.axes),
        )
        Pair(adapted_dest, adapted_bc)
    end
    return FusedMultiBroadcast(adapted_pairs)
end
8 changes: 8 additions & 0 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ module DataLayouts
import Base: Base, @propagate_inbounds
import StaticArrays: SOneTo, MArray, SArray
import ClimaComms
import MultiBroadcastFusion as MBF
import Adapt

import ..slab, ..slab_args, ..column, ..column_args, ..level
Expand Down Expand Up @@ -1451,4 +1452,11 @@ Adapt.adapt_structure(to, data::VF{S}) where {S} =
# Adapt a `DataF` layout by adapting its parent array and re-wrapping it with
# the same `S` parameter.
function Adapt.adapt_structure(to, data::DataF{S}) where {S}
    adapted_parent = Adapt.adapt(to, parent(data))
    return DataF{S}(adapted_parent)
end

# TODO: Should the DataLayout be device-aware? So that we can
#       determine if we're multi-threaded or not?
# This is only currently used in FusedMultiBroadcast kernels

# Fallback: any `AbstractArray` type maps to the single-threaded CPU device.
function device_from_array_type(::Type{<:AbstractArray})
    return ClimaComms.CPUSingleThreaded()
end

# The device of a data layout is derived from the type of its parent array.
function ClimaComms.device(data::AbstractData)
    return device_from_array_type(typeof(parent(data)))
end

end # module
101 changes: 101 additions & 0 deletions src/DataLayouts/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
import MultiBroadcastFusion as MBF
import MultiBroadcastFusion: fused_direct

# Generate the fused-broadcast container type `FusedMultiBroadcast` and its
# companion macro via https://github.com/CliMA/MultiBroadcastFusion.jl
# NOTE(review): the generated macro appears to be `@fused_direct` (the last
# argument passed to `@make_fused`), not `@fused` — confirm against the
# MultiBroadcastFusion documentation.
MBF.@make_type FusedMultiBroadcast
MBF.@make_fused fused_direct FusedMultiBroadcast fused_direct

# Broadcasting of AbstractData objects
# https://docs.julialang.org/en/v1/manual/interfaces/#Broadcast-Styles

Expand Down Expand Up @@ -587,3 +595,96 @@ function Base.copyto!(
) where {S, Nij, A}
return _serial_copyto!(dest, bc)
end

# ============= FusedMultiBroadcast

# A `Broadcasted` is treated as scalar-like when its style is a
# zero-dimensional array style or the tuple style; anything else is not.
function isascalar(
    bc::Base.Broadcast.Broadcasted{
        <:Union{
            Base.Broadcast.AbstractArrayStyle{0},
            Base.Broadcast.Style{Tuple},
        },
    },
)
    return true
end
isascalar(::Any) = false


# Fused multi-broadcast entry point for DataLayouts
# Entry point for fused multi-broadcasts whose destinations are data layouts.
# Dispatches to `fused_copyto!` on the first destination and its device.
function Base.copyto!(
    fmbc::FusedMultiBroadcast{T},
) where {N, T <: NTuple{N, Pair{<:AbstractData, <:Any}}}
    first_dest = first(fmbc.pairs).first
    # Axes are expected to have been validated by the caller already, so
    # `check_fused_broadcast_axes(fmbc)` is intentionally not repeated here.
    device = ClimaComms.device(first_dest)
    return fused_copyto!(fmbc, first_dest, device)
end

# CPU implementation for `VIJFH` layouts: each `dest => bc` pair is processed
# in turn with a plain nested loop (`h` outermost, `v` innermost). This is
# equivalent in effect to falling back to `Base.copyto!(dest, bc)` per pair.
function fused_copyto!(
    fmbc::FusedMultiBroadcast,
    dest1::VIJFH{S1, Nij},
    ::ClimaComms.AbstractCPUDevice,
) where {S1, Nij}
    dims = size(dest1)
    Nv, Nh = dims[4], dims[5]
    for pair in fmbc.pairs
        dest, bc = pair.first, pair.second
        @inbounds for h in 1:Nh
            for j in 1:Nij
                for i in 1:Nij
                    for v in 1:Nv
                        idx = CartesianIndex(i, j, 1, v, h)
                        # Scalar-like broadcasts are read without an index.
                        val = isascalar(bc) ? bc[] : bc[idx]
                        dest[idx] = convert(eltype(dest), val)
                    end
                end
            end
        end
    end
    return nothing
end

# CPU implementation for `VIFH` layouts: each `dest => bc` pair is copied
# column by column (`h` outermost, then `i`, with `v` innermost).
function fused_copyto!(
    fmbc::FusedMultiBroadcast,
    dest1::VIFH{S, Ni, A},
    ::ClimaComms.AbstractCPUDevice,
) where {S, Ni, A}
    dims = size(dest1)
    Nv, Nh = dims[4], dims[5]
    for pair in fmbc.pairs
        dest, bc = pair.first, pair.second
        @inbounds for h in 1:Nh
            for i in 1:Ni
                for v in 1:Nv
                    idx = CartesianIndex(i, 1, 1, v, h)
                    # Scalar-like broadcasts are read without an index.
                    val = isascalar(bc) ? bc[] : bc[idx]
                    dest[idx] = convert(eltype(dest), val)
                end
            end
        end
    end
    return nothing
end

# CPU implementation for `VF` layouts: each `dest => bc` pair is copied over
# the `v` dimension of the first destination. Always returns `nothing`.
#
# Scalar-like broadcasts (see `isascalar`) are read with `bc[]`, matching the
# `VIJFH`/`VIFH` CPU methods; indexing such a broadcast with a 5-dimensional
# `CartesianIndex` is only valid when bounds checks are elided.
function fused_copyto!(
    fmbc::FusedMultiBroadcast,
    dest1::VF{S1, A},
    ::ClimaComms.AbstractCPUDevice,
) where {S1, A}
    _, _, _, Nv, _ = size(dest1)
    for (dest, bc) in fmbc.pairs
        @inbounds for v in 1:Nv
            I = CartesianIndex(1, 1, 1, v, 1)
            bcI = isascalar(bc) ? bc[] : bc[I]
            dest[I] = convert(eltype(dest), bcI)
        end
    end
    return nothing
end

# we've already diagonalized dest, so we only need to make
# sure that all the broadcast axes are compatible.
# Logic here is similar to Base.Broadcast.instantiate
@inline function _check_fused_broadcast_axes(bc1, bc2)
    # Combine the axes of every argument of both broadcasts; when a combined
    # shape exists, each broadcast's arguments must be compatible with it.
    combined = Base.Broadcast.combine_axes(bc1.args..., bc2.args...)
    combined isa Nothing && return nothing
    Base.Broadcast.check_broadcast_axes(combined, bc1.args...)
    return Base.Broadcast.check_broadcast_axes(combined, bc2.args...)
end

# Check every broadcast in `fmbc` against the first pair's broadcast.
@inline function check_fused_broadcast_axes(fmbc::FusedMultiBroadcast)
    bcs = map(pair -> pair.second, fmbc.pairs)
    reference_bc = first(fmbc.pairs).second
    return check_fused_broadcast_axes(bcs, reference_bc)
end
# Single remaining broadcast: compare it against the reference `bc1`.
@inline function check_fused_broadcast_axes(bcs::Tuple{<:Any}, bc1)
    return _check_fused_broadcast_axes(first(bcs), bc1)
end
# No broadcasts left: nothing to check.
@inline function check_fused_broadcast_axes(bcs::Tuple{}, bc1)
    return nothing
end
# General case: check the head against `bc1`, then recurse on the tail.
@inline function check_fused_broadcast_axes(bcs::Tuple, bc1)
    head, rest = first(bcs), Base.tail(bcs)
    _check_fused_broadcast_axes(head, bc1)
    return check_fused_broadcast_axes(rest, bc1)
end
9 changes: 8 additions & 1 deletion src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@ module Fields
import ClimaComms
import MultiBroadcastFusion as MBF
import ..slab, ..slab_args, ..column, ..column_args, ..level
import ..DataLayouts: DataLayouts, AbstractData, DataStyle
import ..DataLayouts:
DataLayouts,
AbstractData,
DataStyle,
FusedMultiBroadcast,
@fused_direct,
isascalar,
check_fused_broadcast_axes
import ..Domains
import ..Topologies
import ..Quadratures
Expand Down
39 changes: 39 additions & 0 deletions src/Fields/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,45 @@ end
return dest
end

# Fused multi-broadcast entry point for Fields
# Entry point for fused multi-broadcasts whose destinations are `Field`s.
# Every `field => bc` pair is lowered to a `field_values => data broadcast`
# pair, the original pairs are validated, and the lowered broadcast is
# forwarded to the DataLayouts `copyto!`.
function Base.copyto!(
    fmbc::FusedMultiBroadcast{T},
) where {N, T <: NTuple{N, Pair{<:Field, <:Any}}}
    data_pairs = map(fmbc.pairs) do pair
        bc = Base.Broadcast.instantiate(todata(pair.second))
        # Scalar-like broadcasts are rebuilt with empty axes `()` and
        # instantiated again; all others are used as they are.
        data_bc =
            isascalar(bc) ?
            Base.Broadcast.instantiate(
                Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, ()),
            ) : bc
        Pair(field_values(pair.first), data_bc)
    end
    fmb_data = FusedMultiBroadcast(data_pairs)
    check_mismatched_spaces(fmbc)
    check_fused_broadcast_axes(fmbc)
    return Base.copyto!(fmb_data) # forward to DataLayouts
end

# Compare the axes of every destination in `fmbc` with those of the first
# destination.
@inline function check_mismatched_spaces(fmbc::FusedMultiBroadcast)
    all_axes = map(pair -> axes(pair.first), fmbc.pairs)
    reference_axes = axes(first(fmbc.pairs).first)
    return check_mismatched_spaces(all_axes, reference_axes)
end
# Single remaining entry: compare it against the reference `ax1`.
@inline function check_mismatched_spaces(axs::Tuple{<:Any}, ax1)
    return _check_mismatched_spaces(first(axs), ax1)
end
# Nothing left to compare.
@inline function check_mismatched_spaces(axs::Tuple{}, ax1)
    return nothing
end
# General case: compare the head against `ax1`, then recurse on the tail.
@inline function check_mismatched_spaces(axs::Tuple, ax1)
    head, rest = first(axs), Base.tail(axs)
    _check_mismatched_spaces(head, ax1)
    return check_mismatched_spaces(rest, ax1)
end

# Two spaces of the same concrete `AbstractSpace` subtype are accepted.
function _check_mismatched_spaces(::T, ::T) where {T <: AbstractSpace}
    return nothing
end
# Any other combination of arguments is rejected with an error.
function _check_mismatched_spaces(space1, space2)
    return error("FusedMultiBroadcast spaces are not the same.")
end

# Always throws; kept out of line (`@noinline`) so the error path is not
# inlined into callers. The arguments are not used in the message.
@noinline error_mismatched_spaces(space1::Type, space2::Type) =
    error("Broacasted spaces are not the same.")
Expand Down
5 changes: 5 additions & 0 deletions test/Fields/field.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#=
julia --check-bounds=yes --project=test
julia --project=test
using Revise; include(joinpath("test", "Fields", "field.jl"))
=#
Expand Down Expand Up @@ -915,3 +916,7 @@ end
end
nothing
end

# Also run the tests defined in `field_multi_broadcast_fusion.jl`.
include("field_multi_broadcast_fusion.jl")

# NOTE(review): the trailing `nothing` presumably keeps `include`-ing this
# file from returning/echoing the last test's value — confirm.
nothing
Loading
Loading