-
Notifications
You must be signed in to change notification settings - Fork 105
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1198 from kohr-h/admm
Add ADMM
- Loading branch information
Showing
6 changed files
with
329 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
"""Total variation tomography using linearized ADMM.

In this example we solve the optimization problem

    min_x  ||A(x) - y||_2^2 + lam * ||grad(x)||_1

where ``A`` is a parallel beam ray transform, ``grad`` the spatial
gradient and ``y`` given noisy data.

The problem is rewritten in decoupled form as

    min_x g(L(x))

with a separable sum ``g`` of functionals and the stacked operator ``L``:

    g(z) = ||z_1 - y||_2^2 + lam * ||z_2||_1,

               ( A(x)    )
    z = L(x) = ( grad(x) ).

See the documentation of the `admm_linearized` solver for further details.
"""

import numpy as np
import odl

# --- Set up the forward operator (ray transform) --- #

# Reconstruction space: functions on the rectangle [-20, 20]^2
# discretized with 300 samples per dimension
reco_space = odl.uniform_discr(
    min_pt=[-20, -20], max_pt=[20, 20], shape=[300, 300], dtype='float32')

# Make a parallel beam geometry with flat detector, using 180 angles
geometry = odl.tomo.parallel_beam_geometry(reco_space, num_angles=180)

# Create the forward operator
ray_trafo = odl.tomo.RayTransform(reco_space, geometry)

# --- Generate artificial data --- #

# Create phantom and noisy projection data; the noise is scaled to
# 10 percent of the mean value of the noise-free data
phantom = odl.phantom.shepp_logan(reco_space, modified=True)
data = ray_trafo(phantom)
data += odl.phantom.white_noise(ray_trafo.range) * np.mean(data) * 0.1

# --- Set up the inverse problem --- #

# Gradient operator for the TV part
grad = odl.Gradient(reco_space)

# Stacking of the two operators
L = odl.BroadcastOperator(ray_trafo, grad)

# Data matching and regularization functionals; the factor 0.015 plays
# the role of ``lam`` in the problem formulation above
data_fit = odl.solvers.L2NormSquared(ray_trafo.range).translated(data)
reg_func = 0.015 * odl.solvers.L1Norm(grad.range)
g = odl.solvers.SeparableSum(data_fit, reg_func)

# We don't use the f functional, setting it to zero
f = odl.solvers.ZeroFunctional(L.domain)

# --- Select parameters and solve using ADMM --- #

# Estimated operator norm, add 10 percent for some safety margin
op_norm = 1.1 * odl.power_method_opnorm(L, maxiter=20)

niter = 200  # Number of iterations
sigma = 2.0  # Step size for g.proximal
tau = sigma / op_norm ** 2  # Step size for f.proximal

# Optionally pass a callback to the solver to display intermediate results
callback = (odl.solvers.CallbackPrintIteration(step=10) &
            odl.solvers.CallbackShow(step=10))

# Choose a starting point
x = L.domain.zero()

# Run the algorithm; ``x`` is updated in-place
odl.solvers.admm_linearized(x, f, g, L, tau, sigma, niter, callback=callback)

# Display images
phantom.show(title='Phantom')
data.show(title='Simulated data (Sinogram)')
x.show(title='TV reconstruction', force_show=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
"""Alternating Direction Method of Multipliers (ADMM) variants.""" | ||
|
||
from __future__ import division | ||
from odl.operator import Operator, OpDomainError | ||
|
||
|
||
__all__ = ('admm_linearized',) | ||
|
||
|
||
def admm_linearized(x, f, g, L, tau, sigma, niter, **kwargs):
    r"""Generic linearized ADMM method for convex problems.

    ADMM stands for "Alternating Direction Method of Multipliers" and
    is a popular convex optimization method. This variant solves problems
    of the form ::

        min_x [ f(x) + g(Lx) ]

    with convex ``f`` and ``g``, and a linear operator ``L``. See Section
    4.4 of `[PB2014] <http://web.stanford.edu/~boyd/papers/prox_algs.html>`_
    and the Notes for more mathematical details.

    Parameters
    ----------
    x : ``L.domain`` element
        Starting point of the iteration, updated in-place.
    f, g : `Functional`
        The functions ``f`` and ``g`` in the problem definition. They
        need to implement the ``proximal`` method.
    L : linear `Operator`
        The linear operator that is composed with ``g`` in the problem
        definition. It must fulfill ``L.domain == f.domain`` and
        ``L.range == g.domain``.
    tau, sigma : positive float
        Step size parameters for the update of the variables.
    niter : non-negative int
        Number of iterations.

    Other Parameters
    ----------------
    callback : callable, optional
        Function called with the current iterate after each iteration.

    Raises
    ------
    TypeError
        If ``L`` is not an `Operator` instance, or if ``callback`` is
        given but not callable.
    OpDomainError
        If ``x`` is not an element of ``L.domain``.
    ValueError
        If ``tau`` or ``sigma`` is not positive, or if ``niter`` is not a
        non-negative integer.

    Notes
    -----
    Given :math:`x^{(0)}` (the provided ``x``) and
    :math:`u^{(0)} = z^{(0)} = 0`, linearized ADMM applies the following
    iteration:

    .. math::
        x^{(k+1)} &= \mathrm{prox}_{\tau f} \left[
            x^{(k)} - \sigma^{-1}\tau L^*\big(
                L x^{(k)} - z^{(k)} + u^{(k)}
            \big)
        \right]

        z^{(k+1)} &= \mathrm{prox}_{\sigma g}\left(
            L x^{(k+1)} + u^{(k)}
        \right)

        u^{(k+1)} &= u^{(k)} + L x^{(k+1)} - z^{(k+1)}

    The step size parameters :math:`\tau` and :math:`\sigma` must satisfy

    .. math::
        0 < \tau < \frac{\sigma}{\|L\|^2}

    to guarantee convergence.

    The name "linearized ADMM" comes from the fact that in the
    minimization subproblem for the :math:`x` variable, this variant
    uses a linearization of a quadratic term in the augmented Lagrangian
    of the generic ADMM, in order to make the step expressible with
    the proximal operator of :math:`f`.

    Another name for this algorithm is *split inexact Uzawa method*.

    References
    ----------
    [PB2014] Parikh, N and Boyd, S. *Proximal Algorithms*. Foundations and
    Trends in Optimization, 1(3) (2014), pp 123-231.
    """
    if not isinstance(L, Operator):
        raise TypeError('`L` {!r} is not an `Operator` instance'
                        ''.format(L))

    if x not in L.domain:
        raise OpDomainError('`x` {!r} is not in the domain {!r} of `L`'
                            ''.format(x, L.domain))

    # Keep the raw inputs around so that error messages show what the
    # caller actually passed, not the converted value
    tau, tau_in = float(tau), tau
    if tau <= 0:
        raise ValueError('`tau` must be positive, got {}'.format(tau_in))

    sigma, sigma_in = float(sigma), sigma
    if sigma <= 0:
        raise ValueError('`sigma` must be positive, got {}'.format(sigma_in))

    # The comparison with the raw input rejects non-integral values
    # such as 2.5 that `int` would silently truncate
    niter, niter_in = int(niter), niter
    if niter < 0 or niter != niter_in:
        raise ValueError('`niter` must be a non-negative integer, got {}'
                         ''.format(niter_in))

    # Callback object
    callback = kwargs.pop('callback', None)
    if callback is not None and not callable(callback):
        raise TypeError('`callback` {} is not callable'.format(callback))

    # Initialize range variables
    z = L.range.zero()
    u = L.range.zero()

    # Temporary for Lx + u [- z]
    tmp_ran = L(x)
    # Temporary for L^*(Lx + u - z)
    tmp_dom = L.domain.element()

    # Store proximals since their initialization may involve computation
    prox_tau_f = f.proximal(tau)
    prox_sigma_g = g.proximal(sigma)

    # Loop-invariant weight of the gradient-like term in the x update
    neg_step_ratio = -tau / sigma

    for _ in range(niter):
        # tmp_ran has value Lx^k here
        # tmp_dom <- L^*(Lx^k + u^k - z^k)
        tmp_ran += u
        tmp_ran -= z
        L.adjoint(tmp_ran, out=tmp_dom)

        # x <- x^k - (tau/sigma) L^*(Lx^k + u^k - z^k)
        x.lincomb(1, x, neg_step_ratio, tmp_dom)
        # x^(k+1) <- prox[tau*f](x)
        prox_tau_f(x, out=x)

        # tmp_ran <- Lx^(k+1)
        L(x, out=tmp_ran)

        # u <- u^k + Lx^(k+1). Accumulating into u first lets the prox
        # read its argument from u directly instead of from a freshly
        # allocated ``tmp_ran + u`` in every iteration.
        u += tmp_ran
        # z^(k+1) <- prox[sigma*g](Lx^(k+1) + u^k)
        prox_sigma_g(u, out=z)
        # u^(k+1) = u^k + Lx^(k+1) - z^(k+1)
        u -= z

        if callback is not None:
            callback(x)
|
||
|
||
def admm_linearized_simple(x, f, g, L, tau, sigma, niter, **kwargs):
    """Straightforward reference implementation of ``admm_linearized``.

    Meant as a debugging aid: every step allocates new elements, the
    proximal operators are re-created in each iteration, and none of the
    arguments are validated. The iterate ``x`` is overwritten in-place
    and, if a ``callback`` keyword argument is given, it is called with
    ``x`` after each iteration.
    """
    callback = kwargs.pop('callback', None)

    # Split variable and scaled dual variable, both starting at zero
    z, u = L.range.zero(), L.range.zero()

    for _ in range(niter):
        # x update: proximal step on f after a step along the adjoint
        prox_f = f.proximal(tau)
        residual = L(x) + u - z
        x[:] = prox_f(x - tau / sigma * L.adjoint(residual))

        # z update: proximal step on g at the shifted image of x
        prox_g = g.proximal(sigma)
        z = prox_g(L(x) + u)

        # u update: accumulate the constraint violation Lx - z
        u = L(x) + u - z

        if callback is None:
            continue
        callback(x)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# Copyright 2014-2017 The ODL contributors | ||
# | ||
# This file is part of ODL. | ||
# | ||
# This Source Code Form is subject to the terms of the Mozilla Public License, | ||
# v. 2.0. If a copy of the MPL was not distributed with this file, You can | ||
# obtain one at https://mozilla.org/MPL/2.0/. | ||
|
||
"""Unit tests for ADMM.""" | ||
|
||
from __future__ import division | ||
import odl | ||
from odl.solvers import admm_linearized, Callback | ||
|
||
from odl.util.testutils import all_almost_equal, noise_element | ||
|
||
|
||
def test_admm_lin_input_handling():
    """Check that the solver runs on trivial input and uses the callback."""
    num_iter = 3
    step_params = {'tau': 1.0, 'sigma': 1.0, 'niter': num_iter}

    space = odl.uniform_discr(0, 1, 10)
    L = odl.ZeroOperator(space)
    zero_func = odl.solvers.ZeroFunctional(space)
    f = g = zero_func

    # With a zero operator and zero functionals the iteration is expected
    # to leave the starting point untouched.
    x0 = noise_element(space)
    x = x0.copy()
    admm_linearized(x, f, g, L, **step_params)
    assert x == x0

    class RecordingCallback(Callback):
        """Callback that only remembers whether it has been invoked."""

        was_called = False

        def __call__(self, *args, **kwargs):
            self.was_called = True

    # A callback handed to the solver must actually be invoked
    callback = RecordingCallback()
    assert not callback.was_called
    admm_linearized(x, f, g, L, callback=callback, **step_params)
    assert callback.was_called
|
||
|
||
def test_admm_lin_l1():
    """Check the solution of a weighted sum of two l1 distances.

    The problem solved is

        min_x ||x - data_1||_1 + 0.5 ||x - data_2||_1

    Its minimizer is ``data_1`` because the first term carries the
    larger weight.
    """
    space = odl.rn(5)
    identity = odl.IdentityOperator(space)

    # Two random data points; the solver should end up at the first one
    data_1 = odl.util.testutils.noise_element(space)
    data_2 = odl.util.testutils.noise_element(space)

    l1_norm = odl.solvers.L1Norm
    f = l1_norm(space).translated(data_1)
    g = 0.5 * l1_norm(space).translated(data_2)

    x = space.zero()
    admm_linearized(x, f, g, identity, tau=1.0, sigma=2.0, niter=10)

    assert all_almost_equal(x, data_1, places=2)
|
||
|
||
# When executed directly as a script, hand this file's path to
# ``odl.util.test_file`` (presumably runs the tests defined above).
if __name__ == '__main__':
    odl.util.test_file(__file__)