Skip to content

Commit 9ac8a16

Browse files
Use local memory in band matrix solve
1 parent db8780f commit 9ac8a16

File tree

5 files changed

+167
-35
lines changed

5 files changed

+167
-35
lines changed

ext/cuda/matrix_fields_multiple_field_solve.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ NVTX.@annotate function multiple_field_solve!(
2222
b,
2323
x1,
2424
)
25-
Ni, Nj, _, _, Nh = size(Fields.field_values(x1))
25+
Ni, Nj, _, Nv, Nh = size(Fields.field_values(x1))
2626
names = MatrixFields.matrix_row_keys(keys(A))
2727
Nnames = length(names)
2828
nthreads, nblocks = _configure_threadblock(Ni * Nj * Nh * Nnames)
@@ -38,7 +38,7 @@ NVTX.@annotate function multiple_field_solve!(
3838

3939
device = ClimaComms.device(x[first(names)])
4040

41-
args = (device, caches, xs, As, bs, x1, Val(Nnames))
41+
args = (device, caches, xs, As, bs, x1, Val(Nv), Val(Nnames))
4242

4343
auto_launch!(
4444
multiple_field_solve_kernel!,
@@ -62,9 +62,10 @@ Base.@propagate_inbounds column_A(A, i, j, h) = Spaces.column(A, i, j, h)
6262
i,
6363
j,
6464
h,
65+
::Val{Nv},
6566
iname,
6667
::Val{Nnames},
67-
) where {Nnames}
68+
) where {Nnames, Nv}
6869
return quote
6970
Base.Cartesian.@nif $Nnames ξ -> (iname == ξ) ξ -> begin
7071
_single_field_solve!(
@@ -73,6 +74,7 @@ Base.@propagate_inbounds column_A(A, i, j, h) = Spaces.column(A, i, j, h)
7374
column_A(xs[ξ], i, j, h),
7475
column_A(As[ξ], i, j, h),
7576
column_A(bs[ξ], i, j, h),
77+
Val(Nv),
7678
)
7779
end
7880
end
@@ -85,8 +87,9 @@ function multiple_field_solve_kernel!(
8587
As,
8688
bs,
8789
x1,
90+
::Val{Nv},
8891
::Val{Nnames},
89-
) where {Nnames}
92+
) where {Nnames, Nv}
9093
@inbounds begin
9194
Ni, Nj, _, _, Nh = size(Fields.field_values(x1))
9295
tidx = (CUDA.blockIdx().x - 1) * CUDA.blockDim().x + CUDA.threadIdx().x
@@ -102,6 +105,7 @@ function multiple_field_solve_kernel!(
102105
i,
103106
j,
104107
h,
108+
Val(Nv),
105109
iname,
106110
Val(Nnames),
107111
)

ext/cuda/matrix_fields_single_field_solve.jl

Lines changed: 110 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,18 @@ import ClimaCore.Fields: Field
66
import ClimaCore.Fields
77
import ClimaCore.Spaces
88
import ClimaCore.Topologies
9+
import ClimaCore.MatrixFields
910
import ClimaCore.MatrixFields: single_field_solve!
1011
import ClimaCore.MatrixFields: _single_field_solve!
1112
import ClimaCore.MatrixFields: band_matrix_solve!, unzip_tuple_field_values
1213
import ClimaCore.RecursiveApply: , , , rmap, rzero, rdiv
1314

1415
function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
15-
Ni, Nj, _, _, Nh = size(Fields.field_values(A))
16-
Ni, Nj, _, _, Nh = size(Fields.field_values(A))
16+
Ni, Nj, _, Nv, Nh = size(Fields.field_values(A))
1717
nitems = Ni * Nj * Nh
1818
nthreads = min(256, nitems)
1919
nblocks = cld(nitems, nthreads)
20-
args = (device, cache, x, A, b)
20+
args = (device, cache, x, A, b, Val(Nv))
2121
auto_launch!(
2222
single_field_solve_kernel!,
2323
args,
@@ -27,17 +27,26 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
2727
)
2828
end
2929

30-
function single_field_solve_kernel!(device, cache, x, A, b)
30+
function single_field_solve_kernel!(
31+
device,
32+
cache,
33+
x,
34+
A,
35+
b,
36+
::Val{Nv},
37+
) where {Nv}
3138
idx = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x
3239
Ni, Nj, _, _, Nh = size(Fields.field_values(A))
3340
if idx <= Ni * Nj * Nh
34-
i, j, h = Topologies._get_idx((Ni, Nj, Nh), idx)
41+
(i, j, h) = CartesianIndices((1:Ni, 1:Nj, 1:Nh))[idx].I
42+
3543
_single_field_solve!(
3644
device,
3745
Spaces.column(cache, i, j, h),
3846
Spaces.column(x, i, j, h),
3947
Spaces.column(A, i, j, h),
4048
Spaces.column(b, i, j, h),
49+
Val(Nv),
4150
)
4251
end
4352
return nothing
@@ -49,13 +58,15 @@ function _single_field_solve!(
4958
x::Fields.ColumnField,
5059
A::Fields.ColumnField,
5160
b::Fields.ColumnField,
52-
)
53-
band_matrix_solve!(
61+
::Val{Nv},
62+
) where {Nv}
63+
band_matrix_solve_local_mem!(
5464
eltype(A),
5565
unzip_tuple_field_values(Fields.field_values(cache)),
5666
Fields.field_values(x),
5767
unzip_tuple_field_values(Fields.field_values(A.entries)),
5868
Fields.field_values(b),
69+
Val(Nv),
5970
)
6071
end
6172

@@ -65,12 +76,12 @@ function _single_field_solve!(
6576
x::Fields.ColumnField,
6677
A::UniformScaling,
6778
b::Fields.ColumnField,
68-
)
79+
::Val{Nv},
80+
) where {Nv}
6981
x_data = Fields.field_values(x)
7082
b_data = Fields.field_values(b)
71-
n = length(x_data)
72-
@inbounds for i in 1:n
73-
x_data[i] = inv(A.λ) b_data[i]
83+
@inbounds for v in 1:Nv
84+
x_data[v] = inv(A.λ) b_data[v]
7485
end
7586
end
7687

@@ -80,11 +91,95 @@ function _single_field_solve!(
8091
x::Fields.PointDataField,
8192
A::UniformScaling,
8293
b::Fields.PointDataField,
83-
)
94+
::Val{Nv},
95+
) where {Nv}
8496
x_data = Fields.field_values(x)
8597
b_data = Fields.field_values(b)
86-
n = length(x_data)
87-
@inbounds begin
88-
x_data[] = inv(A.λ) b_data[]
98+
x_data[] = inv(A.λ) b_data[]
99+
end
100+
101+
using StaticArrays: MArray
102+
function band_matrix_solve_local_mem!(
103+
t::Type{<:MatrixFields.TridiagonalMatrixRow},
104+
cache,
105+
x,
106+
Aⱼs,
107+
b,
108+
::Val{Nv},
109+
) where {Nv}
110+
Ux, U₊₁ = cache
111+
A₋₁, A₀, A₊₁ = Aⱼs
112+
113+
Ux_local = MArray{Tuple{Nv}, eltype(Ux)}(undef)
114+
U₊₁_local = MArray{Tuple{Nv}, eltype(U₊₁)}(undef)
115+
x_local = MArray{Tuple{Nv}, eltype(x)}(undef)
116+
A₋₁_local = MArray{Tuple{Nv}, eltype(A₋₁)}(undef)
117+
A₀_local = MArray{Tuple{Nv}, eltype(A₀)}(undef)
118+
A₊₁_local = MArray{Tuple{Nv}, eltype(A₊₁)}(undef)
119+
b_local = MArray{Tuple{Nv}, eltype(b)}(undef)
120+
@inbounds for v in 1:Nv
121+
A₋₁_local[v] = A₋₁[v]
122+
A₀_local[v] = A₀[v]
123+
A₊₁_local[v] = A₊₁[v]
124+
b_local[v] = b[v]
125+
end
126+
cache_local = (Ux_local, U₊₁_local)
127+
Aⱼs_local = (A₋₁, A₀, A₊₁)
128+
band_matrix_solve!(t, cache_local, x_local, Aⱼs_local, b_local)
129+
@inbounds for v in 1:Nv
130+
x[v] = x_local[v]
89131
end
132+
return nothing
133+
end
134+
135+
function band_matrix_solve_local_mem!(
136+
t::Type{<:MatrixFields.PentadiagonalMatrixRow},
137+
cache,
138+
x,
139+
Aⱼs,
140+
b,
141+
::Val{Nv},
142+
) where {Nv}
143+
Ux, U₊₁, U₊₂ = cache
144+
A₋₂, A₋₁, A₀, A₊₁, A₊₂ = Aⱼs
145+
Ux_local = MArray{Tuple{Nv}, eltype(Ux)}(undef)
146+
U₊₁_local = MArray{Tuple{Nv}, eltype(U₊₁)}(undef)
147+
U₊₂_local = MArray{Tuple{Nv}, eltype(U₊₂)}(undef)
148+
x_local = MArray{Tuple{Nv}, eltype(x)}(undef)
149+
A₋₂_local = MArray{Tuple{Nv}, eltype(A₋₂)}(undef)
150+
A₋₁_local = MArray{Tuple{Nv}, eltype(A₋₁)}(undef)
151+
A₀_local = MArray{Tuple{Nv}, eltype(A₀)}(undef)
152+
A₊₁_local = MArray{Tuple{Nv}, eltype(A₊₁)}(undef)
153+
A₊₂_local = MArray{Tuple{Nv}, eltype(A₊₂)}(undef)
154+
b_local = MArray{Tuple{Nv}, eltype(b)}(undef)
155+
@inbounds for v in 1:Nv
156+
A₋₂_local[v] = A₋₂[v]
157+
A₋₁_local[v] = A₋₁[v]
158+
A₀_local[v] = A₀[v]
159+
A₊₁_local[v] = A₊₁[v]
160+
A₊₂_local[v] = A₊₂[v]
161+
b_local[v] = b[v]
162+
end
163+
cache_local = (Ux_local, U₊₁_local, U₊₂_local)
164+
Aⱼs_local = (A₋₂, A₋₁, A₀, A₊₁, A₊₂)
165+
band_matrix_solve!(t, cache_local, x_local, Aⱼs_local, b_local)
166+
@inbounds for v in 1:Nv
167+
x[v] = x_local[v]
168+
end
169+
return nothing
170+
end
171+
172+
function band_matrix_solve_local_mem!(
173+
t::Type{<:MatrixFields.DiagonalMatrixRow},
174+
cache,
175+
x,
176+
Aⱼs,
177+
b,
178+
::Val{Nv},
179+
) where {Nv}
180+
(A₀,) = Aⱼs
181+
@inbounds for v in 1:Nv
182+
x[v] = inv(A₀[v]) b[v]
183+
end
184+
return nothing
90185
end

src/MatrixFields/field_matrix_solver.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,17 @@ function check_field_matrix_solver(::BlockDiagonalSolve, _, A, b)
247247
end
248248
end
249249

250+
# TODO: we can remove the uniform_vertical_levels
251+
# limitation while still using static shared memory
252+
# once Nv is in the type space.
253+
function uniform_vertical_levels(x, names)
254+
_, _, _, Nv1, _ = size(Fields.field_values(x[first(names)]))
255+
return all(Base.tail(names)) do name
256+
_, _, _, Nv, _ = size(Fields.field_values(x[name]))
257+
Nv == Nv1
258+
end
259+
end
260+
250261
NVTX.@annotate function run_field_matrix_solver!(
251262
::BlockDiagonalSolve,
252263
cache,
@@ -256,7 +267,8 @@ NVTX.@annotate function run_field_matrix_solver!(
256267
)
257268
names = matrix_row_keys(keys(A))
258269
if length(names) == 1 ||
259-
all(name -> A[name, name] isa UniformScaling, names.values)
270+
all(name -> A[name, name] isa UniformScaling, names.values) ||
271+
!uniform_vertical_levels(x, names.values)
260272
foreach(names) do name
261273
single_field_solve!(cache[name], x[name], A[name, name], b[name])
262274
end

src/MatrixFields/single_field_solver.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,26 @@ function single_field_solver_cache(A::ColumnwiseBandMatrixField, b)
4545
return similar(b, cache_eltype)
4646
end
4747

48+
function single_field_solve_diag_matrix_row!(
49+
cache,
50+
x,
51+
A::ColumnwiseBandMatrixField,
52+
b,
53+
)
54+
Aⱼs = unzip_tuple_field_values(Fields.field_values(A.entries))
55+
b_vals = Fields.field_values(b)
56+
x_vals = Fields.field_values(x)
57+
(A₀,) = Aⱼs
58+
@. x_vals = inv(A₀) b_vals
59+
end
4860
single_field_solve!(_, x, A::UniformScaling, b) = x .= inv(A.λ) .* b
49-
single_field_solve!(cache, x, A::ColumnwiseBandMatrixField, b) =
50-
single_field_solve!(ClimaComms.device(axes(A)), cache, x, A, b)
61+
function single_field_solve!(cache, x, A::ColumnwiseBandMatrixField, b)
62+
if eltype(A) <: MatrixFields.DiagonalMatrixRow
63+
single_field_solve_diag_matrix_row!(cache, x, A, b)
64+
else
65+
single_field_solve!(ClimaComms.device(axes(A)), cache, x, A, b)
66+
end
67+
end
5168

5269
single_field_solve!(::ClimaComms.AbstractCPUDevice, cache, x, A, b) =
5370
_single_field_solve!(ClimaComms.device(axes(A)), cache, x, A, b)
@@ -86,14 +103,6 @@ function _single_field_solve_col!(
86103
end
87104
end
88105

89-
_single_field_solve!(
90-
cache::Fields.Field,
91-
x::Fields.Field,
92-
A::Union{Fields.Field, UniformScaling},
93-
b::Fields.Field,
94-
dev::ClimaComms.AbstractCPUDevice,
95-
) = _single_field_solve_col!(dev, cache, x, A, b)
96-
97106
unzip_tuple_field_values(data) =
98107
ntuple(i -> data.:($i), Val(length(propertynames(data))))
99108

test/MatrixFields/field_matrix_solvers.jl

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Revise; include(joinpath("test", "MatrixFields", "field_matrix_solvers.jl"
55
import Logging
66
import Logging: Debug
77
import LinearAlgebra: I, norm
8+
import ClimaComms
89
import ClimaCore.Utilities: half
910
import ClimaCore.RecursiveApply:
1011
import ClimaCore.MatrixFields: @name
@@ -21,8 +22,16 @@ function test_field_matrix_solver(; test_name, alg, A, b, use_rel_error = false)
2122
solver = FieldMatrixSolver(alg, A, b)
2223
args = (solver, x, A, b)
2324

24-
solve_time = @benchmark field_matrix_solve!(args...)
25-
mul_time = @benchmark field_matrix_mul!(b_test, A, x)
25+
solve_time =
26+
@benchmark ClimaComms.@cuda_sync comms_device field_matrix_solve!(
27+
args...,
28+
)
29+
mul_time =
30+
@benchmark ClimaComms.@cuda_sync comms_device field_matrix_mul!(
31+
b_test,
32+
A,
33+
x,
34+
)
2635

2736
solve_time_rounded = round(solve_time; sigdigits = 2)
2837
mul_time_rounded = round(mul_time; sigdigits = 2)
@@ -58,11 +67,14 @@ function test_field_matrix_solver(; test_name, alg, A, b, use_rel_error = false)
5867
AnyFrameModule(MatrixFields.KrylovKit),
5968
AnyFrameModule(Base.CoreLogging),
6069
)
61-
@test_opt ignored_modules = ignored FieldMatrixSolver(alg, A, b)
62-
@test_opt ignored_modules = ignored field_matrix_solve!(args...)
70+
using_cuda ||
71+
@test_opt ignored_modules = ignored FieldMatrixSolver(alg, A, b)
72+
using_cuda ||
73+
@test_opt ignored_modules = ignored field_matrix_solve!(args...)
6374
@test_opt ignored_modules = ignored field_matrix_mul!(b, A, x)
6475

65-
using_cuda || @test @allocated(field_matrix_solve!(args...)) == 0
76+
# TODO: fix broken test when Nv is added to the type space
77+
using_cuda || @test @allocated(field_matrix_solve!(args...)) 1536
6678
using_cuda || @test @allocated(field_matrix_mul!(b, A, x)) == 0
6779
end
6880
end

0 commit comments

Comments
 (0)