Skip to content

Commit

Permalink
Orad/fix emulator bug repeat call (#281)
Browse files Browse the repository at this point in the history
* bugfix repeated calls, by skipping build

* typo in test

plots, repeats, more tweaks

format

varying dimension is easier

ens propto dim

adds recompute_cov and inflation/localization for scalar RF

different plots and configs with localization etc.

exp to fit paper results

add statsbase

add plotting when no repeats

add coeffs, hardcode some restart options in scalar RF

add coeffs, hardcode some restart options in scalar RF

kron add regularization matrix

save states and replotting from file

prior plots for l63

added save+plot for GFunc

add prior case

add save+plot to ishigami

updated plotting for EDMF

remove coeffs from loss again

JLD2 in projects

config update

format
  • Loading branch information
odunbar committed Jul 11, 2024
1 parent b6a80c9 commit 39e02d7
Show file tree
Hide file tree
Showing 17 changed files with 1,082 additions and 100 deletions.
124 changes: 113 additions & 11 deletions examples/EDMF_data/plot_posterior.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,112 @@ using Dates
# CES
using CalibrateEmulateSample.ParameterDistributions

#####
# Creates 3 plots: one for a specific case, one comparing 2 cases, and one with all cases (the final case being the prior).


# date = Date(year,month,day)

# 2-parameter calibration exp
exp_name = "ent-det-calibration"
date_of_run = Date(2023, 10, 5)
#exp_name = "ent-det-calibration"
#date_of_run = Date(2023, 10, 5)

# 5-parameter calibration exp
#exp_name = "ent-det-tked-tkee-stab-calibration"
#date_of_run = Date(2023,10,4)
exp_name = "ent-det-tked-tkee-stab-calibration"
date_of_run = Date(2024, 06, 14)

# Output figure read/write directory
figure_save_directory = joinpath(@__DIR__, "output", exp_name, string(date_of_run))
data_save_directory = joinpath(@__DIR__, "output", exp_name, string(date_of_run))

#case:
cases = [
"GP", # diagonalize, train scalar GP, assume diag inputs
"RF-prior",
"RF-vector-svd-nonsep",
]
case_rf = cases[3]

# load
posterior_filepath = joinpath(data_save_directory, "$(case_rf)_posterior.jld2")
if !isfile(posterior_filepath)
throw(ArgumentError(posterior_filepath * " not found. Please check experiment name and date"))
else
println("Loading posterior distribution from: " * posterior_filepath)
posterior = load(posterior_filepath)["posterior"]
end
# get samples explicitly (may be easier to work with)
posterior_samples = vcat([get_distribution(posterior)[name] for name in get_name(posterior)]...) #samples are columns
transformed_posterior_samples =
mapslices(x -> transform_unconstrained_to_constrained(posterior, x), posterior_samples, dims = 1)

# histograms
nparam_plots = sum(get_dimensions(posterior)) - 1
density_filepath = joinpath(figure_save_directory, "$(case_rf)_posterior_dist_comp.png")
transformed_density_filepath = joinpath(figure_save_directory, "$(case_rf)_posterior_dist_phys.png")
labels = get_name(posterior)

burnin = 50_000

data_rf = (; [(Symbol(labels[i]), posterior_samples[i, burnin:end]) for i in 1:length(labels)]...)
transformed_data_rf =
(; [(Symbol(labels[i]), transformed_posterior_samples[i, burnin:end]) for i in 1:length(labels)]...)

p = pairplot(data_rf => (PairPlots.Contourf(sigmas = 1:1:3),))
trans_p = pairplot(transformed_data_rf => (PairPlots.Contourf(sigmas = 1:1:3),))

save(density_filepath, p)
save(transformed_density_filepath, trans_p)

#
#
#

case_gp = cases[1]
# load
posterior_filepath = joinpath(data_save_directory, "$(case_gp)_posterior.jld2")
if !isfile(posterior_filepath)
throw(ArgumentError(posterior_filepath * " not found. Please check experiment name and date"))
else
println("Loading posterior distribution from: " * posterior_filepath)
posterior = load(posterior_filepath)["posterior"]
end
# get samples explicitly (may be easier to work with)
posterior_samples = vcat([get_distribution(posterior)[name] for name in get_name(posterior)]...) #samples are columns
transformed_posterior_samples =
mapslices(x -> transform_unconstrained_to_constrained(posterior, x), posterior_samples, dims = 1)

# histograms
nparam_plots = sum(get_dimensions(posterior)) - 1
density_filepath = joinpath(figure_save_directory, "$(case_rf)_$(case_gp)_posterior_dist_comp.png")
transformed_density_filepath = joinpath(figure_save_directory, "$(case_rf)_$(case_gp)_posterior_dist_phys.png")
labels = get_name(posterior)
data_gp = (; [(Symbol(labels[i]), posterior_samples[i, burnin:end]) for i in 1:length(labels)]...)
transformed_data_gp =
(; [(Symbol(labels[i]), transformed_posterior_samples[i, burnin:end]) for i in 1:length(labels)]...)
#
#
#
gp_smoothing = 1 # >1 = smoothing KDE in plotting

p = pairplot(
data_rf => (PairPlots.Contourf(sigmas = 1:1:3),),
data_gp => (PairPlots.Contourf(sigmas = 1:1:3, bandwidth = gp_smoothing),),
)
trans_p = pairplot(
transformed_data_rf => (PairPlots.Contourf(sigmas = 1:1:3),),
transformed_data_gp => (PairPlots.Contourf(sigmas = 1:1:3, bandwidth = gp_smoothing),),
)

save(density_filepath, p)
save(transformed_density_filepath, trans_p)



# Finally include the prior too
case_prior = cases[2]
# load
posterior_filepath = joinpath(data_save_directory, "posterior.jld2")
posterior_filepath = joinpath(data_save_directory, "$(case_prior)_posterior.jld2")
if !isfile(posterior_filepath)
throw(ArgumentError(posterior_filepath * " not found. Please check experiment name and date"))
else
Expand All @@ -39,14 +128,27 @@ transformed_posterior_samples =

# histograms
nparam_plots = sum(get_dimensions(posterior)) - 1
density_filepath = joinpath(figure_save_directory, "posterior_dist_comp.png")
transformed_density_filepath = joinpath(figure_save_directory, "posterior_dist_phys.png")
density_filepath = joinpath(figure_save_directory, "all_posterior_dist_comp.png")
transformed_density_filepath = joinpath(figure_save_directory, "all_posterior_dist_phys.png")
labels = get_name(posterior)

data = (; [(Symbol(labels[i]), posterior_samples[i, :]) for i in 1:length(labels)]...)
transformed_data = (; [(Symbol(labels[i]), transformed_posterior_samples[i, :]) for i in 1:length(labels)]...)
data_prior = (; [(Symbol(labels[i]), posterior_samples[i, burnin:end]) for i in 1:length(labels)]...)
transformed_data_prior =
(; [(Symbol(labels[i]), transformed_posterior_samples[i, burnin:end]) for i in 1:length(labels)]...)
#
#
#

p = pairplot(
data_rf => (PairPlots.Contourf(sigmas = 1:1:3),),
data_gp => (PairPlots.Contourf(sigmas = 1:1:3, bandwidth = gp_smoothing),),
data_prior => (PairPlots.Scatter(),),
)
trans_p = pairplot(
transformed_data_rf => (PairPlots.Contourf(sigmas = 1:1:3),),
transformed_data_gp => (PairPlots.Contourf(sigmas = 1:1:3, bandwidth = gp_smoothing),),
transformed_data_prior => (PairPlots.Scatter(),),
)

p = pairplot(data => (PairPlots.Scatter(),))
trans_p = pairplot(transformed_data => (PairPlots.Scatter(),))
save(density_filepath, p)
save(transformed_density_filepath, trans_p)
45 changes: 25 additions & 20 deletions examples/EDMF_data/uq_for_edmf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,17 @@ Random.seed!(rng_seed)
function main()

# 2-parameter calibration exp
exp_name = "ent-det-calibration"
#exp_name = "ent-det-calibration"

# 5-parameter calibration exp
#exp_name = "ent-det-tked-tkee-stab-calibration"
exp_name = "ent-det-tked-tkee-stab-calibration"

cases = [
"GP", # diagonalize, train scalar GP, assume diag inputs
"RF-prior",
"RF-vector-svd-nonsep",
]
case = cases[1]

# Output figure save directory
figure_save_directory = joinpath(@__DIR__, "output", exp_name, string(Dates.today()))
Expand Down Expand Up @@ -119,7 +125,7 @@ function main()
println("plotting ensembles...")
for plot_i in 1:size(outputs, 1)
p = scatter(inputs_constrained[1, :], inputs_constrained[2, :], zcolor = outputs[plot_i, :])
savefig(p, joinpath(figure_save_directory, "output_" * string(plot_i) * ".png"))
savefig(p, joinpath(figure_save_directory, "$(case)_output_" * string(plot_i) * ".png"))
end
println("finished plotting ensembles.")
end
Expand Down Expand Up @@ -198,20 +204,19 @@ function main()
println("Begin Emulation stage")
# Create GP object

cases = [
"GP", # diagonalize, train scalar GP, assume diag inputs
"RF-vector-svd-nonsep",
]
case = cases[2]

overrides = Dict(
"verbose" => true,
"train_fraction" => 0.95,
"train_fraction" => 0.85,
"scheduler" => DataMisfitController(terminate_at = 100),
"cov_sample_multiplier" => 0.5,
"n_iteration" => 5,
"cov_sample_multiplier" => 1.0,
"n_iteration" => 15,
"n_features_opt" => 200,
"localization" => SEC(0.05),
)
nugget = 0.01
if case == "RF-prior"
overrides = Dict("verbose" => true, "cov_sample_multiplier" => 0.01, "n_iteration" => 0)
end
nugget = 1e-6
rng_seed = 99330
rng = Random.MersenneTwister(rng_seed)
input_dim = size(get_inputs(input_output_pairs), 1)
Expand All @@ -226,8 +231,8 @@ function main()
prediction_type = pred_type,
noise_learn = false,
)
elseif case ∈ ["RF-vector-svd-nonsep"]
kernel_structure = NonseparableKernel(LowRankFactor(3, nugget))
elseif case ∈ ["RF-vector-svd-nonsep", "RF-prior"]
kernel_structure = NonseparableKernel(LowRankFactor(1, nugget))
n_features = 500

mlt = VectorRandomFeatureInterface(
Expand All @@ -248,7 +253,7 @@ function main()
# Optimize the GP hyperparameters for better fit
optimize_hyperparameters!(emulator)

emulator_filepath = joinpath(data_save_directory, "emulator.jld2")
emulator_filepath = joinpath(data_save_directory, "$(case)_emulator.jld2")
save(emulator_filepath, "emulator", emulator)

println("Finished Emulation stage")
Expand All @@ -264,17 +269,17 @@ function main()
# determine a good step size
yt_sample = y_truth
mcmc = MCMCWrapper(RWMHSampling(), yt_sample, prior, emulator; init_params = u0)
new_step = optimize_stepsize(mcmc; init_stepsize = 0.1, N = 2000, discard_initial = 0)
new_step = optimize_stepsize(mcmc; init_stepsize = 0.1, N = 5000, discard_initial = 0)

# Now begin the actual MCMC
println("Begin MCMC - with step size ", new_step)
chain = MarkovChainMonteCarlo.sample(mcmc, 100_000; stepsize = new_step, discard_initial = 2_000)
chain = MarkovChainMonteCarlo.sample(mcmc, 300_000; stepsize = new_step, discard_initial = 2_000)
posterior = MarkovChainMonteCarlo.get_posterior(mcmc, chain)

mcmc_filepath = joinpath(data_save_directory, "mcmc_and_chain.jld2")
mcmc_filepath = joinpath(data_save_directory, "$(case)_mcmc_and_chain.jld2")
save(mcmc_filepath, "mcmc", mcmc, "chain", chain)

posterior_filepath = joinpath(data_save_directory, "posterior.jld2")
posterior_filepath = joinpath(data_save_directory, "$(case)_posterior.jld2")
save(posterior_filepath, "posterior", posterior)

println("Finished Sampling stage")
Expand Down
13 changes: 13 additions & 0 deletions examples/Emulator/G-function/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3"
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d"
GlobalSensitivityAnalysis = "1b10255b-6da3-57ce-9089-d24e8517b87e"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomFeatures = "36c3bae2-c0c3-419d-b3b4-eebadd35c5e5"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Loading

0 comments on commit 39e02d7

Please sign in to comment.