Skip to content

Commit b839635

Browse files
Merge pull request #228 from alyst/optim_extensions
Support optimization backends via package extensions
2 parents 5b9327e + c5b48c7 commit b839635

36 files changed

+984
-669
lines changed

Project.toml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
1212
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
1313
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1414
NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
15-
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
1615
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1716
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1817
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
@@ -44,3 +43,13 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4443

4544
[targets]
4645
test = ["Test"]
46+
47+
[weakdeps]
48+
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
49+
ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9"
50+
ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"
51+
ProximalOperators = "f3b72e0c-5f3e-4b3e-8f3e-3f4f3e3e3e3e"
52+
53+
[extensions]
54+
SEMNLOptExt = "NLopt"
55+
SEMProximalOptExt = ["ProximalCore", "ProximalAlgorithms", "ProximalOperators"]

ext/SEMNLOptExt/NLopt.jl

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
############################################################################################
2+
### Types
3+
############################################################################################
4+
"""
5+
Connects to `NLopt.jl` as the optimization backend.
6+
7+
# Constructor
8+
9+
SemOptimizerNLopt(;
10+
algorithm = :LD_LBFGS,
11+
options = Dict{Symbol, Any}(),
12+
local_algorithm = nothing,
13+
local_options = Dict{Symbol, Any}(),
14+
equality_constraints = Vector{NLoptConstraint}(),
15+
inequality_constraints = Vector{NLoptConstraint}(),
16+
kwargs...)
17+
18+
# Arguments
19+
- `algorithm`: optimization algorithm.
20+
- `options::Dict{Symbol, Any}`: options for the optimization algorithm
21+
- `local_algorithm`: local optimization algorithm
22+
- `local_options::Dict{Symbol, Any}`: options for the local optimization algorithm
23+
- `equality_constraints::Vector{NLoptConstraint}`: vector of equality constraints
24+
- `inequality_constraints::Vector{NLoptConstraint}`: vector of inequality constraints
25+
26+
# Example
27+
```julia
28+
my_optimizer = SemOptimizerNLopt()
29+
30+
# constrained optimization with augmented lagrangian
31+
my_constrained_optimizer = SemOptimizerNLopt(;
32+
algorithm = :AUGLAG,
33+
local_algorithm = :LD_LBFGS,
34+
local_options = Dict(:ftol_rel => 1e-6),
35+
inequality_constraints = NLoptConstraint(;f = my_constraint, tol = 0.0),
36+
)
37+
```
38+
39+
# Usage
40+
All algorithms and options from the NLopt library are available, for more information see
41+
the NLopt.jl package and the NLopt online documentation.
42+
For information on how to use inequality and equality constraints,
43+
see [Constrained optimization](@ref) in our online documentation.
44+
45+
# Extended help
46+
47+
## Interfaces
48+
- `algorithm(::SemOptimizerNLopt)`
49+
- `local_algorithm(::SemOptimizerNLopt)`
50+
- `options(::SemOptimizerNLopt)`
51+
- `local_options(::SemOptimizerNLopt)`
52+
- `equality_constraints(::SemOptimizerNLopt)`
53+
- `inequality_constraints(::SemOptimizerNLopt)`
54+
55+
## Implementation
56+
57+
Subtype of `SemOptimizer`.
58+
"""
59+
struct SemOptimizerNLopt{A, A2, B, B2, C} <: SemOptimizer{:NLopt}
60+
algorithm::A
61+
local_algorithm::A2
62+
options::B
63+
local_options::B2
64+
equality_constraints::C
65+
inequality_constraints::C
66+
end
67+
68+
Base.@kwdef struct NLoptConstraint
69+
f::Any
70+
tol = 0.0
71+
end
72+
73+
Base.convert(
74+
::Type{NLoptConstraint},
75+
tuple::NamedTuple{(:f, :tol), Tuple{F, T}},
76+
) where {F, T} = NLoptConstraint(tuple.f, tuple.tol)
77+
78+
############################################################################################
79+
### Constructor
80+
############################################################################################
81+
82+
function SemOptimizerNLopt(;
83+
algorithm = :LD_LBFGS,
84+
local_algorithm = nothing,
85+
options = Dict{Symbol, Any}(),
86+
local_options = Dict{Symbol, Any}(),
87+
equality_constraints = Vector{NLoptConstraint}(),
88+
inequality_constraints = Vector{NLoptConstraint}(),
89+
kwargs...,
90+
)
91+
applicable(iterate, equality_constraints) && !isa(equality_constraints, NamedTuple) ||
92+
(equality_constraints = [equality_constraints])
93+
applicable(iterate, inequality_constraints) &&
94+
!isa(inequality_constraints, NamedTuple) ||
95+
(inequality_constraints = [inequality_constraints])
96+
return SemOptimizerNLopt(
97+
algorithm,
98+
local_algorithm,
99+
options,
100+
local_options,
101+
convert.(NLoptConstraint, equality_constraints),
102+
convert.(NLoptConstraint, inequality_constraints),
103+
)
104+
end
105+
106+
SEM.SemOptimizer{:NLopt}(args...; kwargs...) = SemOptimizerNLopt(args...; kwargs...)
107+
108+
############################################################################################
109+
### Recommended methods
110+
############################################################################################
111+
112+
SEM.update_observed(optimizer::SemOptimizerNLopt, observed::SemObserved; kwargs...) =
113+
optimizer
114+
115+
############################################################################################
116+
### additional methods
117+
############################################################################################
118+
119+
SEM.algorithm(optimizer::SemOptimizerNLopt) = optimizer.algorithm
120+
local_algorithm(optimizer::SemOptimizerNLopt) = optimizer.local_algorithm
121+
SEM.options(optimizer::SemOptimizerNLopt) = optimizer.options
122+
local_options(optimizer::SemOptimizerNLopt) = optimizer.local_options
123+
equality_constraints(optimizer::SemOptimizerNLopt) = optimizer.equality_constraints
124+
inequality_constraints(optimizer::SemOptimizerNLopt) = optimizer.inequality_constraints
125+
126+
mutable struct NLoptResult
127+
result::Any
128+
problem::Any
129+
end
130+
131+
SEM.optimizer(res::NLoptResult) = res.problem.algorithm
132+
SEM.n_iterations(res::NLoptResult) = res.problem.numevals
133+
SEM.convergence(res::NLoptResult) = res.result[3]
134+
135+
# construct SemFit from fitted NLopt object
136+
function SemFit_NLopt(optimization_result, model::AbstractSem, start_val, opt)
137+
return SemFit(
138+
optimization_result[1],
139+
optimization_result[2],
140+
start_val,
141+
model,
142+
NLoptResult(optimization_result, opt),
143+
)
144+
end
145+
146+
# sem_fit method
147+
function SEM.sem_fit(
148+
optim::SemOptimizerNLopt,
149+
model::AbstractSem,
150+
start_params::AbstractVector;
151+
kwargs...,
152+
)
153+
154+
# construct the NLopt problem
155+
opt = construct_NLopt_problem(optim.algorithm, optim.options, length(start_params))
156+
set_NLopt_constraints!(opt, optim)
157+
opt.min_objective =
158+
(par, G) -> SEM.evaluate!(
159+
zero(eltype(par)),
160+
!isnothing(G) && !isempty(G) ? G : nothing,
161+
nothing,
162+
model,
163+
par,
164+
)
165+
166+
if !isnothing(optim.local_algorithm)
167+
opt_local = construct_NLopt_problem(
168+
optim.local_algorithm,
169+
optim.local_options,
170+
length(start_params),
171+
)
172+
opt.local_optimizer = opt_local
173+
end
174+
175+
# fit
176+
result = NLopt.optimize(opt, start_params)
177+
178+
return SemFit_NLopt(result, model, start_params, opt)
179+
end
180+
181+
############################################################################################
182+
### additional functions
183+
############################################################################################
184+
185+
function construct_NLopt_problem(algorithm, options, npar)
186+
opt = Opt(algorithm, npar)
187+
188+
for (key, val) in pairs(options)
189+
setproperty!(opt, key, val)
190+
end
191+
192+
return opt
193+
end
194+
195+
function set_NLopt_constraints!(opt::Opt, optimizer::SemOptimizerNLopt)
196+
for con in optimizer.inequality_constraints
197+
inequality_constraint!(opt, con.f, con.tol)
198+
end
199+
for con in optimizer.equality_constraints
200+
equality_constraint!(opt, con.f, con.tol)
201+
end
202+
end
203+
204+
############################################################################################
205+
# pretty printing
206+
############################################################################################
207+
208+
function Base.show(io::IO, result::NLoptResult)
209+
print(io, "Optimizer status: $(result.result[3]) \n")
210+
print(io, "Minimum: $(round(result.result[1]; digits = 2)) \n")
211+
print(io, "Algorithm: $(result.problem.algorithm) \n")
212+
print(io, "No. evaluations: $(result.problem.numevals) \n")
213+
end

ext/SEMNLOptExt/SEMNLOptExt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module SEMNLOptExt
2+
3+
using StructuralEquationModels, NLopt
4+
5+
SEM = StructuralEquationModels
6+
7+
export SemOptimizerNLopt, NLoptConstraint
8+
9+
include("NLopt.jl")
10+
11+
end
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
############################################################################################
2+
### Types
3+
############################################################################################
4+
"""
5+
Connects to `ProximalAlgorithms.jl` as the optimization backend.
6+
7+
# Constructor
8+
9+
SemOptimizerProximal(;
10+
algorithm = ProximalAlgorithms.PANOC(),
11+
operator_g,
12+
operator_h = nothing,
13+
kwargs...,
14+
15+
# Arguments
16+
- `algorithm`: optimization algorithm.
17+
- `operator_g`: gradient of the objective function
18+
- `operator_h`: optional hessian of the objective function
19+
"""
20+
mutable struct SemOptimizerProximal{A, B, C} <: SemOptimizer{:Proximal}
21+
algorithm::A
22+
operator_g::B
23+
operator_h::C
24+
end
25+
26+
SEM.SemOptimizer{:Proximal}(args...; kwargs...) = SemOptimizerProximal(args...; kwargs...)
27+
28+
SemOptimizerProximal(;
29+
algorithm = ProximalAlgorithms.PANOC(),
30+
operator_g,
31+
operator_h = nothing,
32+
kwargs...,
33+
) = SemOptimizerProximal(algorithm, operator_g, operator_h)
34+
35+
############################################################################################
36+
### Recommended methods
37+
############################################################################################
38+
39+
SEM.update_observed(optimizer::SemOptimizerProximal, observed::SemObserved; kwargs...) =
40+
optimizer
41+
42+
############################################################################################
43+
### additional methods
44+
############################################################################################
45+
46+
SEM.algorithm(optimizer::SemOptimizerProximal) = optimizer.algorithm
47+
48+
############################################################################
49+
### Pretty Printing
50+
############################################################################
51+
52+
function Base.show(io::IO, struct_inst::SemOptimizerProximal)
53+
print_type_name(io, struct_inst)
54+
print_field_types(io, struct_inst)
55+
end
56+
57+
## connect do ProximalAlgorithms.jl as backend
58+
ProximalCore.gradient!(grad, model::AbstractSem, parameters) =
59+
objective_gradient!(grad, model::AbstractSem, parameters)
60+
61+
mutable struct ProximalResult
62+
result::Any
63+
end
64+
65+
function SEM.sem_fit(
66+
optim::SemOptimizerProximal,
67+
model::AbstractSem,
68+
start_params::AbstractVector;
69+
kwargs...,
70+
)
71+
if isnothing(optim.operator_h)
72+
solution, iterations =
73+
optim.algorithm(x0 = start_params, f = model, g = optim.operator_g)
74+
else
75+
solution, iterations = optim.algorithm(
76+
x0 = start_params,
77+
f = model,
78+
g = optim.operator_g,
79+
h = optim.operator_h,
80+
)
81+
end
82+
83+
minimum = objective!(model, solution)
84+
85+
optimization_result = Dict(
86+
:minimum => minimum,
87+
:iterations => iterations,
88+
:algorithm => optim.algorithm,
89+
:operator_g => optim.operator_g,
90+
)
91+
92+
isnothing(optim.operator_h) ||
93+
push!(optimization_result, :operator_h => optim.operator_h)
94+
95+
return SemFit(
96+
minimum,
97+
solution,
98+
start_params,
99+
model,
100+
ProximalResult(optimization_result),
101+
)
102+
end
103+
104+
############################################################################################
105+
# pretty printing
106+
############################################################################################
107+
108+
function Base.show(io::IO, result::ProximalResult)
109+
print(io, "Minimum: $(round(result.result[:minimum]; digits = 2)) \n")
110+
print(io, "No. evaluations: $(result.result[:iterations]) \n")
111+
print(io, "Operator: $(nameof(typeof(result.result[:operator_g]))) \n")
112+
if haskey(result.result, :operator_h)
113+
print(io, "Second Operator: $(nameof(typeof(result.result[:operator_h]))) \n")
114+
end
115+
end
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
module SEMProximalOptExt
2+
3+
using StructuralEquationModels
4+
using ProximalCore, ProximalAlgorithms, ProximalOperators
5+
6+
export SemOptimizerProximal
7+
8+
SEM = StructuralEquationModels
9+
10+
#ProximalCore.prox!(y, f, x, gamma) = ProximalOperators.prox!(y, f, x, gamma)
11+
12+
include("ProximalAlgorithms.jl")
13+
14+
end

0 commit comments

Comments
 (0)