@@ -6,18 +6,18 @@ import ClimaCore.Fields: Field
6
6
import ClimaCore. Fields
7
7
import ClimaCore. Spaces
8
8
import ClimaCore. Topologies
9
+ import ClimaCore. MatrixFields
9
10
import ClimaCore. MatrixFields: single_field_solve!
10
11
import ClimaCore. MatrixFields: _single_field_solve!
11
12
import ClimaCore. MatrixFields: band_matrix_solve!, unzip_tuple_field_values
12
13
import ClimaCore. RecursiveApply: ⊠ , ⊞ , ⊟ , rmap, rzero, rdiv
13
14
14
15
function single_field_solve! (device:: ClimaComms.CUDADevice , cache, x, A, b)
15
- Ni, Nj, _, _, Nh = size (Fields. field_values (A))
16
- Ni, Nj, _, _, Nh = size (Fields. field_values (A))
16
+ Ni, Nj, _, Nv, Nh = size (Fields. field_values (A))
17
17
nitems = Ni * Nj * Nh
18
18
nthreads = min (256 , nitems)
19
19
nblocks = cld (nitems, nthreads)
20
- args = (device, cache, x, A, b)
20
+ args = (device, cache, x, A, b, Val (Nv) )
21
21
auto_launch! (
22
22
single_field_solve_kernel!,
23
23
args,
@@ -27,17 +27,26 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
27
27
)
28
28
end
29
29
30
- function single_field_solve_kernel! (device, cache, x, A, b)
30
+ function single_field_solve_kernel! (
31
+ device,
32
+ cache,
33
+ x,
34
+ A,
35
+ b,
36
+ :: Val{Nv} ,
37
+ ) where {Nv}
31
38
idx = CUDA. threadIdx (). x + (CUDA. blockIdx (). x - 1 ) * CUDA. blockDim (). x
32
39
Ni, Nj, _, _, Nh = size (Fields. field_values (A))
33
40
if idx <= Ni * Nj * Nh
34
- i, j, h = Topologies. _get_idx ((Ni, Nj, Nh), idx)
41
+ (i, j, h) = CartesianIndices ((1 : Ni, 1 : Nj, 1 : Nh))[idx]. I
42
+
35
43
_single_field_solve! (
36
44
device,
37
45
Spaces. column (cache, i, j, h),
38
46
Spaces. column (x, i, j, h),
39
47
Spaces. column (A, i, j, h),
40
48
Spaces. column (b, i, j, h),
49
+ Val (Nv),
41
50
)
42
51
end
43
52
return nothing
@@ -49,13 +58,15 @@ function _single_field_solve!(
49
58
x:: Fields.ColumnField ,
50
59
A:: Fields.ColumnField ,
51
60
b:: Fields.ColumnField ,
52
- )
53
- band_matrix_solve! (
61
+ :: Val{Nv} ,
62
+ ) where {Nv}
63
+ band_matrix_solve_local_mem! (
54
64
eltype (A),
55
65
unzip_tuple_field_values (Fields. field_values (cache)),
56
66
Fields. field_values (x),
57
67
unzip_tuple_field_values (Fields. field_values (A. entries)),
58
68
Fields. field_values (b),
69
+ Val (Nv),
59
70
)
60
71
end
61
72
@@ -65,12 +76,12 @@ function _single_field_solve!(
65
76
x:: Fields.ColumnField ,
66
77
A:: UniformScaling ,
67
78
b:: Fields.ColumnField ,
68
- )
79
+ :: Val{Nv} ,
80
+ ) where {Nv}
69
81
x_data = Fields. field_values (x)
70
82
b_data = Fields. field_values (b)
71
- n = length (x_data)
72
- @inbounds for i in 1 : n
73
- x_data[i] = inv (A. λ) ⊠ b_data[i]
83
+ @inbounds for v in 1 : Nv
84
+ x_data[v] = inv (A. λ) ⊠ b_data[v]
74
85
end
75
86
end
76
87
@@ -80,11 +91,95 @@ function _single_field_solve!(
80
91
x:: Fields.PointDataField ,
81
92
A:: UniformScaling ,
82
93
b:: Fields.PointDataField ,
83
- )
94
+ :: Val{Nv} ,
95
+ ) where {Nv}
84
96
x_data = Fields. field_values (x)
85
97
b_data = Fields. field_values (b)
86
- n = length (x_data)
87
- @inbounds begin
88
- x_data[] = inv (A. λ) ⊠ b_data[]
98
+ x_data[] = inv (A. λ) ⊠ b_data[]
99
+ end
100
+
101
+ using StaticArrays: MArray
102
"""
    band_matrix_solve_local_mem!(
        t::Type{<:MatrixFields.TridiagonalMatrixRow}, cache, x, Aⱼs, b, ::Val{Nv},
    )

Solve a tridiagonal system over `Nv` entries by staging the inputs in
fixed-size `MArray` buffers, calling `band_matrix_solve!` on those buffers,
and copying the solution back into `x`.

# Arguments
- `t`: the matrix row type, forwarded unchanged to `band_matrix_solve!`.
- `cache`: a 2-element collection `(Ux, U₊₁)`; only the element types are
  read here, to size the scratch buffers. `cache` itself is never written.
- `x`: destination; entries `1:Nv` are overwritten with the solution.
- `Aⱼs`: a 3-element collection `(A₋₁, A₀, A₊₁)` of matrix bands.
- `b`: the right-hand side; entries `1:Nv` are read.
- `Val{Nv}`: the number of entries, lifted to the type domain so that the
  buffers have a statically known length.

Returns `nothing`.
"""
function band_matrix_solve_local_mem!(
    t::Type{<:MatrixFields.TridiagonalMatrixRow},
    cache,
    x,
    Aⱼs,
    b,
    ::Val{Nv},
) where {Nv}
    Ux, U₊₁ = cache
    A₋₁, A₀, A₊₁ = Aⱼs

    # Scratch and solution buffers start uninitialized.
    # NOTE(review): this relies on `band_matrix_solve!` writing every entry of
    # the cache and of `x_local` before reading it; confirm against its
    # definition.
    Ux_local = MArray{Tuple{Nv}, eltype(Ux)}(undef)
    U₊₁_local = MArray{Tuple{Nv}, eltype(U₊₁)}(undef)
    x_local = MArray{Tuple{Nv}, eltype(x)}(undef)

    # Input buffers, filled from the caller's data below.
    A₋₁_local = MArray{Tuple{Nv}, eltype(A₋₁)}(undef)
    A₀_local = MArray{Tuple{Nv}, eltype(A₀)}(undef)
    A₊₁_local = MArray{Tuple{Nv}, eltype(A₊₁)}(undef)
    b_local = MArray{Tuple{Nv}, eltype(b)}(undef)
    @inbounds for v in 1:Nv
        A₋₁_local[v] = A₋₁[v]
        A₀_local[v] = A₀[v]
        A₊₁_local[v] = A₊₁[v]
        b_local[v] = b[v]
    end

    cache_local = (Ux_local, U₊₁_local)
    # Hand the solver the staged copies of the bands. Previously the original
    # `A₋₁, A₀, A₊₁` were passed here, which left the copies above unused.
    Aⱼs_local = (A₋₁_local, A₀_local, A₊₁_local)
    band_matrix_solve!(t, cache_local, x_local, Aⱼs_local, b_local)

    # Write the solution back to the caller's storage.
    @inbounds for v in 1:Nv
        x[v] = x_local[v]
    end
    return nothing
end
134
+
135
"""
    band_matrix_solve_local_mem!(
        t::Type{<:MatrixFields.PentadiagonalMatrixRow}, cache, x, Aⱼs, b, ::Val{Nv},
    )

Solve a pentadiagonal system over `Nv` entries by staging the inputs in
fixed-size `MArray` buffers, calling `band_matrix_solve!` on those buffers,
and copying the solution back into `x`.

# Arguments
- `t`: the matrix row type, forwarded unchanged to `band_matrix_solve!`.
- `cache`: a 3-element collection `(Ux, U₊₁, U₊₂)`; only the element types
  are read here, to size the scratch buffers. `cache` itself is never written.
- `x`: destination; entries `1:Nv` are overwritten with the solution.
- `Aⱼs`: a 5-element collection `(A₋₂, A₋₁, A₀, A₊₁, A₊₂)` of matrix bands.
- `b`: the right-hand side; entries `1:Nv` are read.
- `Val{Nv}`: the number of entries, lifted to the type domain so that the
  buffers have a statically known length.

Returns `nothing`.
"""
function band_matrix_solve_local_mem!(
    t::Type{<:MatrixFields.PentadiagonalMatrixRow},
    cache,
    x,
    Aⱼs,
    b,
    ::Val{Nv},
) where {Nv}
    Ux, U₊₁, U₊₂ = cache
    A₋₂, A₋₁, A₀, A₊₁, A₊₂ = Aⱼs

    # Scratch and solution buffers start uninitialized.
    # NOTE(review): this relies on `band_matrix_solve!` writing every entry of
    # the cache and of `x_local` before reading it; confirm against its
    # definition.
    Ux_local = MArray{Tuple{Nv}, eltype(Ux)}(undef)
    U₊₁_local = MArray{Tuple{Nv}, eltype(U₊₁)}(undef)
    U₊₂_local = MArray{Tuple{Nv}, eltype(U₊₂)}(undef)
    x_local = MArray{Tuple{Nv}, eltype(x)}(undef)

    # Input buffers, filled from the caller's data below.
    A₋₂_local = MArray{Tuple{Nv}, eltype(A₋₂)}(undef)
    A₋₁_local = MArray{Tuple{Nv}, eltype(A₋₁)}(undef)
    A₀_local = MArray{Tuple{Nv}, eltype(A₀)}(undef)
    A₊₁_local = MArray{Tuple{Nv}, eltype(A₊₁)}(undef)
    A₊₂_local = MArray{Tuple{Nv}, eltype(A₊₂)}(undef)
    b_local = MArray{Tuple{Nv}, eltype(b)}(undef)
    @inbounds for v in 1:Nv
        A₋₂_local[v] = A₋₂[v]
        A₋₁_local[v] = A₋₁[v]
        A₀_local[v] = A₀[v]
        A₊₁_local[v] = A₊₁[v]
        A₊₂_local[v] = A₊₂[v]
        b_local[v] = b[v]
    end

    cache_local = (Ux_local, U₊₁_local, U₊₂_local)
    # Hand the solver the staged copies of the bands. Previously the original
    # `A₋₂ … A₊₂` were passed here, which left the copies above unused.
    Aⱼs_local = (A₋₂_local, A₋₁_local, A₀_local, A₊₁_local, A₊₂_local)
    band_matrix_solve!(t, cache_local, x_local, Aⱼs_local, b_local)

    # Write the solution back to the caller's storage.
    @inbounds for v in 1:Nv
        x[v] = x_local[v]
    end
    return nothing
end
171
+
172
"""
    band_matrix_solve_local_mem!(
        t::Type{<:MatrixFields.DiagonalMatrixRow}, cache, x, Aⱼs, b, ::Val{Nv},
    )

Diagonal case: for each `v` in `1:Nv`, store `inv(A₀[v]) ⊠ b[v]` in `x[v]`,
where `A₀` is the first element of `Aⱼs`. No intermediate buffers are used;
`t` and `cache` are accepted only to match the signature of the other
`band_matrix_solve_local_mem!` methods and are not read.

Returns `nothing`.
"""
function band_matrix_solve_local_mem!(
    t::Type{<:MatrixFields.DiagonalMatrixRow},
    cache,
    x,
    Aⱼs,
    b,
    ::Val{Nv},
) where {Nv}
    diagonal = first(Aⱼs)
    @inbounds for level in Base.OneTo(Nv)
        # Each entry is independent: scale the right-hand side by the inverse
        # of the matching diagonal entry.
        x[level] = inv(diagonal[level]) ⊠ b[level]
    end
    return nothing
end
0 commit comments