Merge pull request #1198 from kohr-h/admm

Add ADMM

Holger Kohr authored Oct 21, 2017
2 parents c9c5acf + b3f3e81 commit 66f8e2a
Showing 6 changed files with 329 additions and 4 deletions.
85 changes: 85 additions & 0 deletions examples/solvers/admm_tomography.py
@@ -0,0 +1,85 @@
"""Total variation tomography using linearized ADMM.
In this example we solve the optimization problem
min_x ||A(x) - y||_2^2 + lam * ||grad(x)||_1
Where ``A`` is a parallel beam ray transform, ``grad`` the spatial
gradient and ``y`` given noisy data.
The problem is rewritten in decoupled form as
min_x g(L(x))
with a separable sum ``g`` of functionals and the stacked operator ``L``:
g(z) = ||z_1 - g||_2^2 + lam * ||z_2||_1,
( A(x) )
z = L(x) = ( grad(x) ).
See the documentation of the `admm_linearized` solver for further details.
"""

import numpy as np
import odl

# --- Set up the forward operator (ray transform) --- #

# Reconstruction space: functions on the rectangle [-20, 20]^2
# discretized with 300 samples per dimension
reco_space = odl.uniform_discr(
    min_pt=[-20, -20], max_pt=[20, 20], shape=[300, 300], dtype='float32')

# Make a parallel beam geometry with flat detector, using 180 angles
geometry = odl.tomo.parallel_beam_geometry(reco_space, num_angles=180)

# Create the forward operator
ray_trafo = odl.tomo.RayTransform(reco_space, geometry)

# --- Generate artificial data --- #

# Create phantom and noisy projection data
phantom = odl.phantom.shepp_logan(reco_space, modified=True)
data = ray_trafo(phantom)
data += odl.phantom.white_noise(ray_trafo.range) * np.mean(data) * 0.1

# --- Set up the inverse problem --- #

# Gradient operator for the TV part
grad = odl.Gradient(reco_space)

# Stacking of the two operators
L = odl.BroadcastOperator(ray_trafo, grad)

# Data matching and regularization functionals
data_fit = odl.solvers.L2NormSquared(ray_trafo.range).translated(data)
reg_func = 0.015 * odl.solvers.L1Norm(grad.range)
g = odl.solvers.SeparableSum(data_fit, reg_func)
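
# `SeparableSum` acts component-wise on the product space ``L.range``,
# i.e. g([v1, v2]) = data_fit(v1) + reg_func(v2)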

# We don't need the ``f`` functional here, so we set it to zero
f = odl.solvers.ZeroFunctional(L.domain)

# --- Select parameters and solve using ADMM --- #

# Estimated operator norm, add 10 percent for some safety margin
op_norm = 1.1 * odl.power_method_opnorm(L, maxiter=20)

niter = 200 # Number of iterations
sigma = 2.0 # Step size for g.proximal
tau = sigma / op_norm ** 2 # Step size for f.proximal
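
# Note (see the `admm_linearized` docstring): convergence requires
# tau < sigma / ||L||^2, which holds here since the operator norm
# estimate above includes a 10 percent safety margin.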

# Optionally pass a callback to the solver to display intermediate results
callback = (odl.solvers.CallbackPrintIteration(step=10) &
            odl.solvers.CallbackShow(step=10))

# Choose a starting point
x = L.domain.zero()

# Run the algorithm
odl.solvers.admm_linearized(x, f, g, L, tau, sigma, niter, callback=callback)

# Display images
phantom.show(title='Phantom')
data.show(title='Simulated data (Sinogram)')
x.show(title='TV reconstruction', force_show=True)
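
To monitor convergence, the callback could additionally print the objective value at each displayed iterate. A minimal sketch, assuming ODL's `CallbackPrint` accepts a functional and a `step` argument, and using ODL functional arithmetic to build the composite objective; the name `obj_fun` is introduced here for illustration only:

# Reuses f, g and L from the example above
obj_fun = f + g * L  # evaluates f(x) + g(L(x))
callback = (odl.solvers.CallbackPrintIteration(step=10) &
            odl.solvers.CallbackPrint(obj_fun, step=10) &
            odl.solvers.CallbackShow(step=10))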
4 changes: 2 additions & 2 deletions examples/solvers/lbfgs_tomography.py
@@ -56,8 +56,8 @@
hessinv_estimate = odl.ScalingOperator(reco_space, 1 / opnorm ** 2)

# Optionally pass callback to the solver to display intermediate results
-callback = (odl.solvers.CallbackPrintIteration() &
-            odl.solvers.CallbackShow())
+callback = (odl.solvers.CallbackPrintIteration(step=10) &
+            odl.solvers.CallbackShow(step=10))

# Pick parameters
maxiter = 20
4 changes: 2 additions & 2 deletions examples/solvers/pdhg_tomography.py
@@ -72,8 +72,8 @@
sigma = 1.0 / op_norm # Step size for the dual variable

# Optionally pass callback to the solver to display intermediate results
-callback = (odl.solvers.CallbackPrintIteration() &
-            odl.solvers.CallbackShow())
+callback = (odl.solvers.CallbackPrintIteration(step=10) &
+            odl.solvers.CallbackShow(step=10))

# Choose a starting point
x = op.domain.zero()
3 changes: 3 additions & 0 deletions odl/solvers/nonsmooth/__init__.py
@@ -15,6 +15,9 @@
from .proximal_operators import *
__all__ += proximal_operators.__all__

+from .admm import *
+__all__ += admm.__all__

from .primal_dual_hybrid_gradient import *
__all__ += primal_dual_hybrid_gradient.__all__

161 changes: 161 additions & 0 deletions odl/solvers/nonsmooth/admm.py
@@ -0,0 +1,161 @@
"""Alternating Direction method of Multipliers (ADMM) method variants."""

from __future__ import division
from odl.operator import Operator, OpDomainError


__all__ = ('admm_linearized',)


def admm_linearized(x, f, g, L, tau, sigma, niter, **kwargs):
"""Generic linearized ADMM method for convex problems.
ADMM stands for "Alternating Direction Method of Multipliers" and
is a popular convex optimization method. This variant solves problems
of the form ::
min_x [ f(x) + g(Lx) ]
with convex ``f`` and ``g``, and a linear operator ``L``. See Section
4.4 of `[PB2014] <http://web.stanford.edu/~boyd/papers/prox_algs.html>`_
and the Notes for more mathematical details.
Parameters
----------
x : ``L.domain`` element
Starting point of the iteration, updated in-place.
f, g : `Functional`
The functions ``f`` and ``g`` in the problem definition. They
need to implement the ``proximal`` method.
L : linear `Operator`
The linear operator that is composed with ``g`` in the problem
definition. It must fulfill ``L.domain == f.domain`` and
``L.range == g.domain``.
tau, sigma : positive float
Step size parameters for the update of the variables.
niter : non-negative int
Number of iterations.
Other Parameters
----------------
callback : callable, optional
Function called with the current iterate after each iteration.
Notes
-----
Given :math:`x^{(0)}` (the provided ``x``) and
:math:`u^{(0)} = z^{(0)} = 0`, linearized ADMM applies the following
iteration:
.. math::
x^{(k+1)} &= \mathrm{prox}_{\\tau f} \\left[
x^{(k)} - \sigma^{-1}\\tau L^*\\big(
L x^{(k)} - z^{(k)} + u^{(k)}
\\big)
\\right]
z^{(k+1)} &= \mathrm{prox}_{\sigma g}\\left(
L x^{(k+1)} + u^{(k)}
\\right)
u^{(k+1)} &= u^{(k)} + L x^{(k+1)} - z^{(k+1)}
The step size parameters :math:`\\tau` and :math:`\sigma` must satisfy
.. math::
0 < \\tau < \\frac{\sigma}{\|L\|^2}
to guarantee convergence.
The name "linearized ADMM" comes from the fact that in the
minimization subproblem for the :math:`x` variable, this variant
uses a linearization of a quadratic term in the augmented Lagrangian
of the generic ADMM, in order to make the step expressible with
the proximal operator of :math:`f`.
Another name for this algorithm is *split inexact Uzawa method*.
References
----------
[PB2014] Parikh, N and Boyd, S. *Proximal Algorithms*. Foundations and
Trends in Optimization, 1(3) (2014), pp 123-231.
"""
    if not isinstance(L, Operator):
        raise TypeError('`L` {!r} is not an `Operator` instance'
                        ''.format(L))

    if x not in L.domain:
        raise OpDomainError('`x` {!r} is not in the domain of `L`, {!r}'
                            ''.format(x, L.domain))

    tau, tau_in = float(tau), tau
    if tau <= 0:
        raise ValueError('`tau` must be positive, got {}'.format(tau_in))

    sigma, sigma_in = float(sigma), sigma
    if sigma <= 0:
        raise ValueError('`sigma` must be positive, got {}'.format(sigma_in))

    niter, niter_in = int(niter), niter
    if niter < 0 or niter != niter_in:
        raise ValueError('`niter` must be a non-negative integer, got {}'
                         ''.format(niter_in))

    # Callback object
    callback = kwargs.pop('callback', None)
    if callback is not None and not callable(callback):
        raise TypeError('`callback` {} is not callable'.format(callback))

    # Initialize range variables
    z = L.range.zero()
    u = L.range.zero()

    # Temporary for Lx + u [- z]
    tmp_ran = L(x)
    # Temporary for L^*(Lx + u - z)
    tmp_dom = L.domain.element()

    # Store proximals since their initialization may involve computation
    prox_tau_f = f.proximal(tau)
    prox_sigma_g = g.proximal(sigma)

    for _ in range(niter):
        # tmp_ran has value Lx^k here
        # tmp_dom <- L^*(Lx^k + u^k - z^k)
        tmp_ran += u
        tmp_ran -= z
        L.adjoint(tmp_ran, out=tmp_dom)

        # x <- x^k - (tau/sigma) L^*(Lx^k + u^k - z^k)
        x.lincomb(1, x, -tau / sigma, tmp_dom)
        # x^(k+1) <- prox[tau*f](x)
        prox_tau_f(x, out=x)

        # tmp_ran <- Lx^(k+1)
        L(x, out=tmp_ran)
        # z^(k+1) <- prox[sigma*g](Lx^(k+1) + u^k)
        prox_sigma_g(tmp_ran + u, out=z)  # 1 copy here

        # u^(k+1) <- u^k + Lx^(k+1) - z^(k+1)
        u += tmp_ran
        u -= z

        if callback is not None:
            callback(x)


def admm_linearized_simple(x, f, g, L, tau, sigma, niter, **kwargs):
    """Non-optimized version of ``admm_linearized``.

    This function is intended for debugging. It makes a lot of copies and
    performs no error checking.
    """
    callback = kwargs.pop('callback', None)
    z = L.range.zero()
    u = L.range.zero()
    for _ in range(niter):
        x[:] = f.proximal(tau)(x - tau / sigma * L.adjoint(L(x) + u - z))
        z = g.proximal(sigma)(L(x) + u)
        u = L(x) + u - z
        if callback is not None:
            callback(x)
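
As a quick sanity check of the interface, the solver can be run on a tiny least-squares problem. A minimal sketch, assuming the solver is exposed as `odl.solvers.admm_linearized` (which the `__init__.py` change above arranges); the names `ident` and `b` are illustrative only:

import odl

space = odl.rn(3)
ident = odl.IdentityOperator(space)
f = odl.solvers.ZeroFunctional(space)
b = space.element([1.0, 2.0, 3.0])
g = odl.solvers.L2NormSquared(space).translated(b)  # g(z) = ||z - b||_2^2

x = space.zero()
# ||ident|| = 1, so tau = 0.5 < sigma / ||ident||^2 = 1.0 satisfies the
# step size condition from the docstring above
odl.solvers.admm_linearized(x, f, g, ident, tau=0.5, sigma=1.0, niter=100)
print(x)  # should be close to b, the unique minimizer of g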
76 changes: 76 additions & 0 deletions odl/test/solvers/nonsmooth/admm_test.py
@@ -0,0 +1,76 @@
# Copyright 2014-2017 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

"""Unit tests for ADMM."""

from __future__ import division
import odl
from odl.solvers import admm_linearized, Callback

from odl.util.testutils import all_almost_equal, noise_element


def test_admm_lin_input_handling():
    """Test to see that input is handled correctly."""

    space = odl.uniform_discr(0, 1, 10)

    L = odl.ZeroOperator(space)
    f = g = odl.solvers.ZeroFunctional(space)

    # Check that the algorithm runs. With the above operators and functionals,
    # the algorithm should not modify the initial value.
    x0 = noise_element(space)
    x = x0.copy()
    niter = 3

    admm_linearized(x, f, g, L, tau=1.0, sigma=1.0, niter=niter)

    assert x == x0

    # Check that a provided callback is actually called
    class CallbackTest(Callback):
        was_called = False

        def __call__(self, *args, **kwargs):
            self.was_called = True

    callback = CallbackTest()
    assert not callback.was_called
    admm_linearized(x, f, g, L, tau=1.0, sigma=1.0, niter=niter,
                    callback=callback)
    assert callback.was_called


def test_admm_lin_l1():
    """Verify that the correct value is returned for L1 distance optimization.

    Solves the optimization problem

        min_x ||x - data_1||_1 + 0.5 ||x - data_2||_1,

    which has optimal solution data_1 since the first term dominates.
    """
    space = odl.rn(5)

    L = odl.IdentityOperator(space)

    data_1 = noise_element(space)
    data_2 = noise_element(space)

    f = odl.solvers.L1Norm(space).translated(data_1)
    g = 0.5 * odl.solvers.L1Norm(space).translated(data_2)

    x = space.zero()
    admm_linearized(x, f, g, L, tau=1.0, sigma=2.0, niter=10)

    assert all_almost_equal(x, data_1, places=2)


if __name__ == '__main__':
    odl.util.test_file(__file__)
