diff --git a/Project.toml b/Project.toml index 1bd335f1..5937930d 100644 --- a/Project.toml +++ b/Project.toml @@ -34,6 +34,7 @@ NLSolversBase = "7" NLopt = "0.6, 1" Optim = "1" PrettyTables = "2" +ProximalAlgorithms = "0.7" StatsBase = "0.33, 0.34" Symbolics = "4, 5, 6" SymbolicUtils = "1.4 - 1.5, 1.7, 2, 3" @@ -47,9 +48,7 @@ test = ["Test"] [weakdeps] NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd" ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9" -ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" -ProximalOperators = "f3b72e0c-5f3e-4b3e-8f3e-3f4f3e3e3e3e" [extensions] SEMNLOptExt = "NLopt" -SEMProximalOptExt = ["ProximalCore", "ProximalAlgorithms", "ProximalOperators"] +SEMProximalOptExt = "ProximalAlgorithms" diff --git a/ext/SEMProximalOptExt/ProximalAlgorithms.jl b/ext/SEMProximalOptExt/ProximalAlgorithms.jl index 13debf79..f82c2b00 100644 --- a/ext/SEMProximalOptExt/ProximalAlgorithms.jl +++ b/ext/SEMProximalOptExt/ProximalAlgorithms.jl @@ -54,9 +54,14 @@ function Base.show(io::IO, struct_inst::SemOptimizerProximal) print_field_types(io, struct_inst) end -## connect do ProximalAlgorithms.jl as backend -ProximalCore.gradient!(grad, model::AbstractSem, parameters) = - objective_gradient!(grad, model::AbstractSem, parameters) +## connect to ProximalAlgorithms.jl +function ProximalAlgorithms.value_and_gradient(model::AbstractSem, params) + grad = similar(params) + obj = SEM.evaluate!(zero(eltype(params)), grad, nothing, model, params) + return obj, grad +end + +#ProximalCore.prox!(y, f, x, gamma) = ProximalOperators.prox!(y, f, x, gamma) mutable struct ProximalResult result::Any diff --git a/ext/SEMProximalOptExt/SEMProximalOptExt.jl b/ext/SEMProximalOptExt/SEMProximalOptExt.jl index 8f91e03b..15631136 100644 --- a/ext/SEMProximalOptExt/SEMProximalOptExt.jl +++ b/ext/SEMProximalOptExt/SEMProximalOptExt.jl @@ -1,14 +1,12 @@ module SEMProximalOptExt using StructuralEquationModels -using ProximalCore, ProximalAlgorithms, ProximalOperators +using ProximalAlgorithms export SemOptimizerProximal SEM = StructuralEquationModels -#ProximalCore.prox!(y, f, x, gamma) = ProximalOperators.prox!(y, f, x, gamma) - include("ProximalAlgorithms.jl") end diff --git a/test/Project.toml b/test/Project.toml index 14bd0bec..59db0b15 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -10,7 +10,6 @@ NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9" -ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" diff --git a/test/examples/proximal/l0.jl b/test/examples/proximal/l0.jl index e8874fd5..da20f390 100644 --- a/test/examples/proximal/l0.jl +++ b/test/examples/proximal/l0.jl @@ -1,4 +1,4 @@ -using StructuralEquationModels, Test, ProximalCore, ProximalAlgorithms, ProximalOperators +using StructuralEquationModels, Test, ProximalAlgorithms, ProximalOperators # load data dat = example_data("political_democracy") diff --git a/test/examples/proximal/lasso.jl b/test/examples/proximal/lasso.jl index 31a4073f..314453df 100644 --- a/test/examples/proximal/lasso.jl +++ b/test/examples/proximal/lasso.jl @@ -1,4 +1,4 @@ -using StructuralEquationModels, Test, ProximalCore, ProximalAlgorithms, ProximalOperators +using StructuralEquationModels, Test, ProximalAlgorithms, ProximalOperators # load data dat = example_data("political_democracy") diff --git a/test/examples/proximal/ridge.jl b/test/examples/proximal/ridge.jl index 12091023..16a318a1 100644 --- a/test/examples/proximal/ridge.jl +++ b/test/examples/proximal/ridge.jl @@ -1,4 +1,4 @@ -using StructuralEquationModels, Test, ProximalCore, ProximalAlgorithms, ProximalOperators +using StructuralEquationModels, Test, ProximalAlgorithms, ProximalOperators # load data dat = example_data("political_democracy")