Skip to content

Commit

Permalink
Use PrecompileTools
Browse files Browse the repository at this point in the history
  • Loading branch information
lcw committed May 9, 2024
1 parent 9dececf commit 52e0cc1
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 17 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
OneDimensionalNodes = "c5182250-406c-41f5-b9da-836c94d3c2ab"
P4estTypes = "f636fe8e-398d-42a9-9d15-dd2c0670d30f"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand All @@ -36,6 +37,7 @@ MPI = "0.20"
OneDimensionalNodes = "1"
P4estTypes = "0.1.3"
RecipesBase = "1"
PrecompileTools = "1"
Requires = "1"
SparseArrays = "1"
StaticArrays = "1"
Expand Down
19 changes: 15 additions & 4 deletions ext/RavenCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
module RavenCUDAExt

import Raven
import Adapt
import MPI
import StaticArrays
import PrecompileTools

PrecompileTools.@recompile_invalidations begin
import Raven
import Adapt
import MPI
import StaticArrays
end

isdefined(Base, :get_extension) ? (using CUDA) : (using ..CUDA)
isdefined(Base, :get_extension) ? (using CUDA.CUDAKernels) : (using ..CUDA.CUDAKernels)
Expand Down Expand Up @@ -66,4 +70,11 @@ CUDA.@device_override function Base.checkbounds(A::StaticArrays.MArray, I...)
nothing
end

PrecompileTools.@compile_workload let
for FT in (Float32, Float64)
AT = CuArray
Raven.precompile_workload(FT, AT)
end
end

end # module RavenCUDAExt
38 changes: 25 additions & 13 deletions src/Raven.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
module Raven

using Adapt
using Compat
using GPUArraysCore
using KernelAbstractions
using KernelAbstractions.Extras: @unroll
using LinearAlgebra
using MPI
using OneDimensionalNodes
import P4estTypes
using RecipesBase
using StaticArrays
using StaticArrays: tuple_prod, tuple_length, size_to_tuple
using SparseArrays
using PrecompileTools: @compile_workload, @recompile_invalidations

@recompile_invalidations begin
using Adapt
using Compat
using GPUArraysCore
using KernelAbstractions
using KernelAbstractions.Extras: @unroll
using LinearAlgebra
using MPI
using OneDimensionalNodes
import P4estTypes
using RecipesBase
using StaticArrays
using StaticArrays: tuple_prod, tuple_length, size_to_tuple
using SparseArrays
end

export LobattoCell, GaussCell

Expand Down Expand Up @@ -60,6 +64,7 @@ include("grids.jl")
include("gridmanager.jl")
include("gridarrays.jl")
include("kron.jl")
include("precompile.jl")

if !isdefined(Base, :get_extension)
using Requires
Expand All @@ -82,4 +87,11 @@ function __init__()
end
end

@compile_workload let
for FT in (Float32, Float64)
AT = Array
precompile_workload(FT, AT)
end
end

end # module Raven
40 changes: 40 additions & 0 deletions src/precompile.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
function precompile_workload(FT, AT)
lcell2d = LobattoCell{FT,AT}(2, 2)
vertices2d = [
SVector{2,FT}(0, 0), # 1
SVector{2,FT}(2, 0), # 2
SVector{2,FT}(0, 2), # 3
SVector{2,FT}(2, 2), # 4
SVector{2,FT}(4, 0), # 5
SVector{2,FT}(4, 2), # 6
]
cells2d = [(1, 2, 3, 4), (4, 2, 6, 5)]

generate(GridManager(lcell2d, brick(FT, 1, 1)))
generate(GridManager(lcell2d, coarsegrid(vertices2d, cells2d)))

lcell3d = LobattoCell{FT,AT}(2, 2, 2)
vertices3d = [
SVector{3,FT}(0, 0, 0), # 1
SVector{3,FT}(2, 0, 0), # 2
SVector{3,FT}(0, 2, 0), # 3
SVector{3,FT}(2, 2, 0), # 4
SVector{3,FT}(0, 0, 2), # 5
SVector{3,FT}(2, 0, 2), # 6
SVector{3,FT}(0, 2, 2), # 7
SVector{3,FT}(2, 2, 2), # 8
SVector{3,FT}(4, 0, 0), # 9
SVector{3,FT}(4, 2, 0), # 10
SVector{3,FT}(4, 0, 2), # 11
SVector{3,FT}(4, 2, 2), # 12
]
cells3d = [(1, 2, 3, 4, 5, 6, 7, 8), (4, 2, 10, 9, 8, 6, 12, 11)]

generate(GridManager(lcell3d, brick(FT, 1, 1, 1)))
generate(GridManager(lcell3d, coarsegrid(vertices3d, cells3d)))

GaussCell{FT,AT}(2, 2)
GaussCell{FT,AT}(2, 2, 2)

return
end

0 comments on commit 52e0cc1

Please sign in to comment.