From e167465578e29b08ed70ab0bd436cbc9a456460c Mon Sep 17 00:00:00 2001 From: Marius Millea Date: Thu, 17 Oct 2024 06:54:33 +0000 Subject: [PATCH] ssp, working cuda bilinearlens --- demo/Manifest.toml | 18 ++++++- demo/Project.toml | 3 ++ demo/demo.ipynb | 123 ++++++++++++++++++++++++++++++++++++++------- 3 files changed, 124 insertions(+), 20 deletions(-) diff --git a/demo/Manifest.toml b/demo/Manifest.toml index 1785f86d..352eeb9c 100644 --- a/demo/Manifest.toml +++ b/demo/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.9.4" manifest_format = "2.0" -project_hash = "0f9fd719da9147b5520938c78261325bcacef3c0" +project_hash = "25d721e6415e0ce9072f7475d234587f3a98fd89" [[deps.AbstractFFTs]] deps = ["LinearAlgebra"] @@ -162,6 +162,12 @@ git-tree-sha1 = "3f7be532673fc4a22825e7884e9e0e876236b12a" uuid = "26cce99e-4866-4b6d-ab74-862489e035e0" version = "0.7.1" +[[deps.BenchmarkTools]] +deps = ["JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"] +git-tree-sha1 = "f1dff6729bc61f4d49e140da1af55dcd1ac97b2f" +uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +version = "1.5.0" + [[deps.Bijections]] git-tree-sha1 = "d8b0439d2be438a5f2cd68ec158fe08a7b2595b7" uuid = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04" @@ -430,6 +436,12 @@ git-tree-sha1 = "f9d7112bfff8a19a3a4ea4e03a8e6a91fe8456bf" uuid = "150eb455-5306-5404-9cee-2592286d6298" version = "0.6.3" +[[deps.CuNFFT]] +deps = ["AbstractFFTs", "AbstractNFFTs", "CUDA", "LinearAlgebra", "NFFT", "Reexport"] +git-tree-sha1 = "b8245d7a8f1e943f9740ac3786f99feef72cf74b" +uuid = "a9291f20-7f4c-4d50-b30d-4e07b13252e1" +version = "0.3.8" + [[deps.CustomUnitRanges]] git-tree-sha1 = "1a3f97f907e6dd8983b744d2642651bb162a3f7a" uuid = "dc8bdbbb-1ca9-579f-8c36-e416f6a65cce" @@ -1640,6 +1652,10 @@ version = "0.2.0" deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" +[[deps.Profile]] +deps = ["Printf"] +uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" + [[deps.ProgressMeter]] deps = ["Distributed", "Printf"] git-tree-sha1 = "8f6bc219586aef8baf0ff9a5fe16ee9c70cb65e4" diff --git a/demo/Project.toml b/demo/Project.toml index 1bfa5cef..29badf3f 100644 --- a/demo/Project.toml +++ b/demo/Project.toml @@ -1,8 +1,11 @@ [deps] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CMBLensing = "b60c06c0-7e54-11e8-3788-4bd722d65317" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +CuNFFT = "a9291f20-7f4c-4d50-b30d-4e07b13252e1" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +NFFT = "efe261a4-0d2b-5849-be55-fc731d526b0d" PlotlyJS = "f0f68f2c-4968-5e81-91da-67840de0976a" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" diff --git a/demo/demo.ipynb b/demo/demo.ipynb index 49ced042..c06c0863 100644 --- a/demo/demo.ipynb +++ b/demo/demo.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -189,7 +189,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 242, "metadata": {}, "outputs": [], "source": [ @@ -205,6 +205,20 @@ ");" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "σ²κ = 1e-7\n", + "Mborder = ds.M[2][:Q]\n", + "T = real(eltype(ds.Cf))\n", + "ds.logprior = function(;ϕ, _...)\n", + " -(sum(Mborder * (∇² * ϕ)) / sum(diag(Mborder)))^2 / T(2*σ²κ)\n", + "end" + ] + }, { "cell_type": "code", "execution_count": null, @@ -216,7 +230,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 245, "metadata": {}, "outputs": [], "source": [ @@ -230,7 +244,16 @@ "metadata": {}, "outputs": [], "source": [ - "jMAP = @time MAP_joint(ds, nsteps=30, progress=false);" + "JMAP_SSP = @time MAP_joint(ds, nsteps=30, progress=false, prior_deprojection_factor=0);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "JMAP = @time MAP_joint(@set(ds.logprior=(;_...)->0), nsteps=30, progress=false, prior_deprojection_factor=0);" ] }, { @@ -239,7 +262,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot(first.(jMAP.history))" + "FJMAP_SSP = @time MAP_marg(ds, nsteps_with_meanfield_update=0, Nsims=0, nsteps=30, α=0.5, progress=false);" ] }, { @@ -248,7 +271,7 @@ "metadata": {}, "outputs": [], "source": [ - "mMAP = @time MAP_marg(ds, nsteps_with_meanfield_update=30, Nsims=30, nsteps=30, α=0.5, progress=false);" + "FJMAP_SSP_BL = @time MAP_marg(@set(ds.L=BilinearLens), nsteps_with_meanfield_update=0, Nsims=0, nsteps=30, α=0.5, progress=false);" ] }, { @@ -257,7 +280,7 @@ "metadata": {}, "outputs": [], "source": [ - "mffmMAP = @time MAP_marg(ds, nsteps_with_meanfield_update=0, Nsims=0, nsteps=30, α=0.5, progress=false);" + "MMAP = @time MAP_marg(ds, nsteps_with_meanfield_update=30, Nsims=30, nsteps=30, α=0.5, progress=false);" ] }, { @@ -266,10 +289,12 @@ "metadata": {}, "outputs": [], "source": [ - "plot(get_Cℓ(∇²*ϕ), label=\"true\")\n", - "plot!(get_Cℓ(∇²*jMAP.ϕ), label=\"jMAP\")\n", - "plot!(get_Cℓ(∇²*mMAP.ϕ), label=\"mMAP\")\n", - "plot!(get_Cℓ(∇²*mffmMAP.ϕ), label=\"mffmMAP\")\n", + "plot(get_Cℓ(∇²*ϕ), label=\"True\")\n", + "plot!(get_Cℓ(∇²*JMAP.ϕ), label=\"JMAP\")\n", + "plot!(get_Cℓ(∇²*MMAP.ϕ), label=\"MMAP\")\n", + "plot!(get_Cℓ(∇²*JMAP_SSP.ϕ), label=\"JMAP + SSP\")\n", + "plot!(get_Cℓ(∇²*FJMAP_SSP.ϕ), label=\"FJMAP + SSP\")\n", + "plot!(get_Cℓ(∇²*FJMAP_SSP_BL.ϕ), label=\"FJMAP + SSP + BL\")\n", "plot!(yscale=:log10, ylim=(1e-10,1e-5), xlim=(0,2000))" ] }, @@ -279,10 +304,12 @@ "metadata": {}, "outputs": [], "source": [ - "plot(get_ρℓ(∇²*ϕ,∇²*jMAP.ϕ,), label=\"true x jMAP\")\n", - "plot!(get_ρℓ(∇²*ϕ,∇²*mMAP.ϕ), label=\"true x mMAP\")\n", - "plot!(get_ρℓ(∇²*jMAP.ϕ,∇²*mMAP.ϕ), label=\"mMAP x jMAP\")\n", - "plot!(get_ρℓ(∇²*jMAP.ϕ,∇²*mffmMAP.ϕ), label=\"mffmMAP x jMAP\")\n", + "plot()\n", + "# plot!(get_ρℓ(∇²*ϕ,∇²*JMAP.ϕ,), label=\"true x JMAP\")\n", + "# plot!(get_ρℓ(∇²*ϕ,∇²*MMAP.ϕ), label=\"true x MMAP\")\n", + "# plot!(get_ρℓ(∇²*JMAP.ϕ,∇²*JMAP_SSP.ϕ), label=\"MMAP x JMAP\")\n", + "plot!(get_ρℓ(∇²*FJMAP_SSP.ϕ,∇²*JMAP_SSP.ϕ), label=\"FJMAP+SSP x JMAP+SSP\")\n", + "plot!(get_ρℓ(∇²*FJMAP_SSP_BL.ϕ,∇²*JMAP_SSP.ϕ), label=\"FJMAP+SSP+BL x JMAP+SSP\")\n", "plot!(ticks=:native)" ] }, @@ -293,10 +320,11 @@ "outputs": [], "source": [ "plot(\n", - " plot(∇² * jMAP.ϕ, title = \"jMAP\"),\n", - " plot(∇² * mffmMAP.ϕ, title = \"mffmMAP\"),\n", - " plot(∇² * mMAP.ϕ, title = \"mMAP\"),\n", - " layout = (1, 3), size = (900,300), cbar = false\n", + " plot(∇² * JMAP.ϕ, title = \"JMAP\"),\n", + " plot(∇² * JMAP_SSP.ϕ, title = \"JMAP + SSP\"),\n", + " plot(∇² * FJMAP_SSP.ϕ, title = \"FJMAP + SSP\"),\n", + " plot(∇² * MMAP.ϕ, title = \"MMAP\"),\n", + " layout = (1, 4), size = (1100,300), cbar = false\n", ")" ] }, @@ -311,6 +339,63 @@ "cell_type": "markdown", "metadata": {}, "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Benchmark" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f = Map(f);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# map size (spin-0)\n", + "size(f.arr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# LenseFlow precomputation\n", + "@btime CUDA.@sync precompute!!(LenseFlow(ϕ,10),f);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# applying the precomputed lensing operator\n", + "L = precompute!!(LenseFlow(ϕ,10),f)\n", + "@btime CUDA.@sync L * f;" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# gradient of lensing\n", + "@btime CUDA.@sync gradient(ϕ -> norm(L(ϕ) * f), ϕ);" + ] } ], "metadata": {