Skip to content

Commit

Permalink
ssp, working cuda bilinearlens
Browse files Browse the repository at this point in the history
  • Loading branch information
marius311 committed Oct 17, 2024
1 parent 94f9709 commit e167465
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 20 deletions.
18 changes: 17 additions & 1 deletion demo/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.9.4"
manifest_format = "2.0"
project_hash = "0f9fd719da9147b5520938c78261325bcacef3c0"
project_hash = "25d721e6415e0ce9072f7475d234587f3a98fd89"

[[deps.AbstractFFTs]]
deps = ["LinearAlgebra"]
Expand Down Expand Up @@ -162,6 +162,12 @@ git-tree-sha1 = "3f7be532673fc4a22825e7884e9e0e876236b12a"
uuid = "26cce99e-4866-4b6d-ab74-862489e035e0"
version = "0.7.1"

[[deps.BenchmarkTools]]
deps = ["JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"]
git-tree-sha1 = "f1dff6729bc61f4d49e140da1af55dcd1ac97b2f"
uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
version = "1.5.0"

[[deps.Bijections]]
git-tree-sha1 = "d8b0439d2be438a5f2cd68ec158fe08a7b2595b7"
uuid = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04"
Expand Down Expand Up @@ -430,6 +436,12 @@ git-tree-sha1 = "f9d7112bfff8a19a3a4ea4e03a8e6a91fe8456bf"
uuid = "150eb455-5306-5404-9cee-2592286d6298"
version = "0.6.3"

[[deps.CuNFFT]]
deps = ["AbstractFFTs", "AbstractNFFTs", "CUDA", "LinearAlgebra", "NFFT", "Reexport"]
git-tree-sha1 = "b8245d7a8f1e943f9740ac3786f99feef72cf74b"
uuid = "a9291f20-7f4c-4d50-b30d-4e07b13252e1"
version = "0.3.8"

[[deps.CustomUnitRanges]]
git-tree-sha1 = "1a3f97f907e6dd8983b744d2642651bb162a3f7a"
uuid = "dc8bdbbb-1ca9-579f-8c36-e416f6a65cce"
Expand Down Expand Up @@ -1640,6 +1652,10 @@ version = "0.2.0"
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[deps.Profile]]
deps = ["Printf"]
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"

[[deps.ProgressMeter]]
deps = ["Distributed", "Printf"]
git-tree-sha1 = "8f6bc219586aef8baf0ff9a5fe16ee9c70cb65e4"
Expand Down
3 changes: 3 additions & 0 deletions demo/Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CMBLensing = "b60c06c0-7e54-11e8-3788-4bd722d65317"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CuNFFT = "a9291f20-7f4c-4d50-b30d-4e07b13252e1"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NFFT = "efe261a4-0d2b-5849-be55-fc731d526b0d"
PlotlyJS = "f0f68f2c-4968-5e81-91da-67840de0976a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Expand Down
123 changes: 104 additions & 19 deletions demo/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -189,7 +189,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 242,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -205,6 +205,20 @@
");"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"σ²κ = 1e-7\n",
"Mborder = ds.M[2][:Q]\n",
"T = real(eltype(ds.Cf))\n",
"ds.logprior = function(;ϕ, _...)\n",
" -(sum(Mborder * (∇² * ϕ)) / sum(diag(Mborder)))^2 / T(2*σ²κ)\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -216,7 +230,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 245,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -230,7 +244,16 @@
"metadata": {},
"outputs": [],
"source": [
"jMAP = @time MAP_joint(ds, nsteps=30, progress=false);"
"JMAP_SSP = @time MAP_joint(ds, nsteps=30, progress=false, prior_deprojection_factor=0);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"JMAP = @time MAP_joint(@set(ds.logprior=(;_...)->0), nsteps=30, progress=false, prior_deprojection_factor=0);"
]
},
{
Expand All @@ -239,7 +262,7 @@
"metadata": {},
"outputs": [],
"source": [
"plot(first.(jMAP.history))"
"FJMAP_SSP = @time MAP_marg(ds, nsteps_with_meanfield_update=0, Nsims=0, nsteps=30, α=0.5, progress=false);"
]
},
{
Expand All @@ -248,7 +271,7 @@
"metadata": {},
"outputs": [],
"source": [
"mMAP = @time MAP_marg(ds, nsteps_with_meanfield_update=30, Nsims=30, nsteps=30, α=0.5, progress=false);"
"FJMAP_SSP_BL = @time MAP_marg(@set(ds.L=BilinearLens), nsteps_with_meanfield_update=0, Nsims=0, nsteps=30, α=0.5, progress=false);"
]
},
{
Expand All @@ -257,7 +280,7 @@
"metadata": {},
"outputs": [],
"source": [
"mffmMAP = @time MAP_marg(ds, nsteps_with_meanfield_update=0, Nsims=0, nsteps=30, α=0.5, progress=false);"
"MMAP = @time MAP_marg(ds, nsteps_with_meanfield_update=30, Nsims=30, nsteps=30, α=0.5, progress=false);"
]
},
{
Expand All @@ -266,10 +289,12 @@
"metadata": {},
"outputs": [],
"source": [
"plot(get_Cℓ(∇²*ϕ), label=\"true\")\n",
"plot!(get_Cℓ(∇²*jMAP.ϕ), label=\"jMAP\")\n",
"plot!(get_Cℓ(∇²*mMAP.ϕ), label=\"mMAP\")\n",
"plot!(get_Cℓ(∇²*mffmMAP.ϕ), label=\"mffmMAP\")\n",
"plot(get_Cℓ(∇²*ϕ), label=\"True\")\n",
"plot!(get_Cℓ(∇²*JMAP.ϕ), label=\"JMAP\")\n",
"plot!(get_Cℓ(∇²*MMAP.ϕ), label=\"MMAP\")\n",
"plot!(get_Cℓ(∇²*JMAP_SSP.ϕ), label=\"JMAP + SSP\")\n",
"plot!(get_Cℓ(∇²*FJMAP_SSP.ϕ), label=\"FJMAP + SSP\")\n",
"plot!(get_Cℓ(∇²*FJMAP_SSP_BL.ϕ), label=\"FJMAP + SSP + BL\")\n",
"plot!(yscale=:log10, ylim=(1e-10,1e-5), xlim=(0,2000))"
]
},
Expand All @@ -279,10 +304,12 @@
"metadata": {},
"outputs": [],
"source": [
"plot(get_ρℓ(∇²*ϕ,∇²*jMAP.ϕ,), label=\"true x jMAP\")\n",
"plot!(get_ρℓ(∇²*ϕ,∇²*mMAP.ϕ), label=\"true x mMAP\")\n",
"plot!(get_ρℓ(∇²*jMAP.ϕ,∇²*mMAP.ϕ), label=\"mMAP x jMAP\")\n",
"plot!(get_ρℓ(∇²*jMAP.ϕ,∇²*mffmMAP.ϕ), label=\"mffmMAP x jMAP\")\n",
"plot()\n",
"# plot!(get_ρℓ(∇²*ϕ,∇²*JMAP.ϕ,), label=\"true x JMAP\")\n",
"# plot!(get_ρℓ(∇²*ϕ,∇²*MMAP.ϕ), label=\"true x MMAP\")\n",
"# plot!(get_ρℓ(∇²*JMAP.ϕ,∇²*JMAP_SSP.ϕ), label=\"MMAP x JMAP\")\n",
"plot!(get_ρℓ(∇²*FJMAP_SSP.ϕ,∇²*JMAP_SSP.ϕ), label=\"FJMAP+SSP x JMAP+SSP\")\n",
"plot!(get_ρℓ(∇²*FJMAP_SSP_BL.ϕ,∇²*JMAP_SSP.ϕ), label=\"FJMAP+SSP+BL x JMAP+SSP\")\n",
"plot!(ticks=:native)"
]
},
Expand All @@ -293,10 +320,11 @@
"outputs": [],
"source": [
"plot(\n",
" plot(∇² * jMAP.ϕ, title = \"jMAP\"),\n",
" plot(∇² * mffmMAP.ϕ, title = \"mffmMAP\"),\n",
" plot(∇² * mMAP.ϕ, title = \"mMAP\"),\n",
" layout = (1, 3), size = (900,300), cbar = false\n",
" plot(∇² * JMAP.ϕ, title = \"JMAP\"),\n",
" plot(∇² * JMAP_SSP.ϕ, title = \"JMAP + SSP\"),\n",
" plot(∇² * FJMAP_SSP.ϕ, title = \"FJMAP + SSP\"),\n",
" plot(∇² * MMAP.ϕ, title = \"MMAP\"),\n",
" layout = (1, 4), size = (1100,300), cbar = false\n",
")"
]
},
Expand All @@ -311,6 +339,63 @@
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Benchmark"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"f = Map(f);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# map size (spin-0)\n",
"size(f.arr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# LenseFlow precomputation\n",
"@btime CUDA.@sync precompute!!(LenseFlow(ϕ,10),f);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# applying the precomputed lensing operator\n",
"L = precompute!!(LenseFlow(ϕ,10),f)\n",
"@btime CUDA.@sync L * f;"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# gradient of lensing\n",
"@btime CUDA.@sync gradient(ϕ -> norm(L(ϕ) * f), ϕ);"
]
}
],
"metadata": {
Expand Down

0 comments on commit e167465

Please sign in to comment.