Skip to content

Commit

Permalink
Merge pull request #56 from meom-group/lax-op
Browse files Browse the repository at this point in the history
use higher-order primitives
  • Loading branch information
vadmbertr authored Mar 27, 2024
2 parents b50b89d + d9de54c commit 2f1d0a9
Show file tree
Hide file tree
Showing 39 changed files with 1,696 additions and 497 deletions.
2 changes: 2 additions & 0 deletions docs/duacs_alboran.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.. include:: ../notebooks/duacs_alboran/duacs_alboran.md
:parser: myst_parser.sphinx_
1 change: 1 addition & 0 deletions docs/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ Examples

gaussian_eddy
alboran_sea
duacs_alboran
218 changes: 148 additions & 70 deletions jaxparrow/cyclogeostrophy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numbers
from typing import Literal, Union

from jax import value_and_grad, jit
from jax import jit, lax, value_and_grad
import jax.numpy as jnp
import jax.scipy as jsp
from jaxtyping import Array, Float, Scalar
Expand Down Expand Up @@ -142,8 +142,7 @@ def cyclogeostrophy(
mask = sanitize.init_mask(ssh_t, mask)

# Compute geostrophic SSC velocity field
u_geos_u, v_geos_v, lat_u, lon_u, lat_v, lon_v = geostrophy(ssh_t, lat_t, lon_t, mask,
return_grids=True)
u_geos_u, v_geos_v, lat_u, lon_u, lat_v, lon_v = geostrophy(ssh_t, lat_t, lon_t, mask, return_grids=True)

# Compute spatial steps and Coriolis factors
dx_u, dy_u = geometry.compute_spatial_step(lat_u, lon_u)
Expand All @@ -160,7 +159,7 @@ def cyclogeostrophy(
coriolis_factor_v = sanitize.sanitize_data(coriolis_factor_v, jnp.nan, mask)

if method == "variational":
res = _variational(u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v,
res = _variational(u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, mask,
n_it, optim, optim_kwargs, return_losses)
elif method == "iterative":
res = _iterative(u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, mask,
Expand Down Expand Up @@ -188,6 +187,68 @@ def cyclogeostrophy(
# Iterative method
# =============================================================================

def _it_step(
u_geos_u: Float[Array, "lat lon"],
v_geos_v: Float[Array, "lat lon"],
dx_u: Float[Array, "lat lon"],
dx_v: Float[Array, "lat lon"],
dy_u: Float[Array, "lat lon"],
dy_v: Float[Array, "lat lon"],
coriolis_factor_u: Float[Array, "lat lon"],
coriolis_factor_v: Float[Array, "lat lon"],
mask: Float[Array, "lat lon"],
res_eps: float,
res_filter: Float[Array, "lat lon"],
res_weights: Float[Array, "lat lon"],
use_res_filter: bool,
return_losses: bool,
u_cyclo: Float[Array, "lat lon"],
v_cyclo: Float[Array, "lat lon"],
mask_it: Float[Array, "lat lon"],
res_n: Float[Array, "lat lon"],
losses: Float[Array, "n_it"],
i: int
) -> [Float[Array, "lat lon"], Float[Array, "lat lon"], Float[Array, "lat lon"], Float[Array, "lat lon"],
Float[Array, "n_it"], int]:
# next it
u_adv_v, v_adv_u = kinematics.advection(u_cyclo, v_cyclo, dx_u, dy_u, dx_v, dy_v, mask)
u_np1 = u_geos_u - v_adv_u / coriolis_factor_u
v_np1 = v_geos_v + u_adv_v / coriolis_factor_v

# compute dist to u_cyclo and v_cyclo
res_np1 = jnp.abs(u_np1 - u_cyclo) + jnp.abs(v_np1 - v_cyclo)
res_np1 = lax.cond(
use_res_filter, # apply filter
lambda operands: jsp.signal.convolve(operands[0], operands[1], mode="same", method="fft") / operands[2],
lambda operands: operands[0],
(res_np1, res_filter, res_weights)
)
# compute intermediate masks
mask_jnp1 = jnp.where(res_np1 >= res_eps, 0, 1) # nan comp. equiv. to jnp.where(res_np1 < res_eps, 1, 0)
mask_n = jnp.where(res_np1 <= res_n, 0, 1) # nan comp. equiv. to jnp.where(res_np1 > res_n, 1, 0)

# compute loss
losses = lax.cond(
return_losses,
lambda operands: operands[0].at[operands[1]].set(_cyclogeostrophic_diff(*operands[2:])),
lambda operands: operands[0],
(losses, i, u_geos_u, v_geos_v, u_cyclo, v_cyclo, u_adv_v, v_adv_u, coriolis_factor_u, coriolis_factor_v)
)

# update cyclogeostrophic velocities
u_cyclo = mask_it * u_cyclo + (1 - mask_it) * (mask_n * u_cyclo + (1 - mask_n) * u_np1)
v_cyclo = mask_it * v_cyclo + (1 - mask_it) * (mask_n * v_cyclo + (1 - mask_n) * v_np1)

# update mask and residuals
mask_it = jnp.maximum(mask_it, jnp.maximum(mask_jnp1, mask_n))
res_n = res_np1

i += 1

return u_cyclo, v_cyclo, mask_it, res_n, losses, i


@partial(jit, static_argnames=("n_it", "res_init", "res_filter_size"))
def _iterative(
u_geos_u: Float[Array, "lat lon"],
v_geos_v: Float[Array, "lat lon"],
Expand All @@ -198,15 +259,13 @@ def _iterative(
coriolis_factor_u: Float[Array, "lat lon"],
coriolis_factor_v: Float[Array, "lat lon"],
mask: Float[Array, "lat lon"],
n_it: int,
n_it: Union[int, None],
res_eps: float,
res_init: Union[float, str],
use_res_filter: bool,
res_filter_size: int,
return_losses: bool
) -> [Float[Array, "lat lon"], ...]:
if mask is None:
mask = jnp.zeros_like(u_geos_u)
if n_it is None:
n_it = N_IT_IT
if res_init == "same":
Expand All @@ -216,40 +275,27 @@ def _iterative(
else:
raise ValueError("res_init should be equal to \"same\" or be a number.")

u_cyclo, v_cyclo = jnp.copy(u_geos_u), jnp.copy(v_geos_v)
# used if applying a filter when computing stopping criteria
res_filter = jnp.ones((res_filter_size, res_filter_size))
res_weights = jsp.signal.correlate(jnp.ones_like(u_geos_u), res_filter, mode="same")
losses = jnp.ones(n_it) * jnp.nan
for i in jnp.arange(n_it):
# next it
u_adv_v, v_adv_u = kinematics.advection(u_cyclo, v_cyclo, dx_u, dy_u, dx_v, dy_v)
u_np1 = u_geos_u - v_adv_u / coriolis_factor_u
v_np1 = v_geos_v + u_adv_v / coriolis_factor_v

# compute dist to u_cyclo and v_cyclo
res_np1 = jnp.abs(u_np1 - u_cyclo) + jnp.abs(v_np1 - v_cyclo)
if use_res_filter:
res_np1 = jsp.signal.correlate(res_np1, res_filter, mode="same") / res_weights # apply filter
# compute intermediate masks
mask_jnp1 = jnp.where(res_np1 >= res_eps, 0, 1) # nan comp. equiv. to jnp.where(res_np1 < res_eps, 1, 0)
mask_n = jnp.where(res_np1 <= res_n, 0, 1) # nan comp. equiv. to jnp.where(res_np1 > res_n, 1, 0)

# compute loss
if return_losses:
loss = _cyclogeostrophic_diff(u_geos_u, v_geos_v, u_cyclo, v_cyclo, u_adv_v, v_adv_u,
coriolis_factor_u, coriolis_factor_v)
losses = losses.at[i].set(loss)

# update cyclogeostrophic velocities
u_cyclo = mask * u_cyclo + (1 - mask) * (mask_n * u_cyclo + (1 - mask_n) * u_np1)
v_cyclo = mask * v_cyclo + (1 - mask) * (mask_n * v_cyclo + (1 - mask_n) * v_np1)

# update mask and residuals
mask = jnp.maximum(mask, jnp.maximum(mask_jnp1, mask_n))
res_n = res_np1

if jnp.all(mask == 1):
break
res_weights = jsp.signal.convolve(jnp.ones_like(u_geos_u), res_filter, mode="same", method="fft")

# define step partial: freeze constant over iterations
def step_fn(pytree):
return _it_step(
u_geos_u, v_geos_v,
dx_u, dx_v, dy_u, dy_v,
coriolis_factor_u, coriolis_factor_v, mask,
res_eps, res_filter, res_weights,
use_res_filter, return_losses,
*pytree
)

# apply updates
u_cyclo, v_cyclo, _, _, losses, _ = lax.while_loop( # noqa
lambda args: (args[-1] < n_it) | jnp.any(args[2] != 1),
step_fn,
(u_geos_u, v_geos_v, mask.astype(int), res_n, jnp.ones(n_it) * jnp.nan, 0)
)

return u_cyclo, v_cyclo, losses

Expand All @@ -258,47 +304,80 @@ def _iterative(
# Variational method
# =============================================================================

@partial(jit, static_argnums=(0, 3))
def _step(
f: Callable[[[Float[Array, "lat lon"], Float[Array, "lat lon"]]], Float[Scalar, ""]],
def _var_loss_fn(
u_geos_u: Float[Array, "lat lon"],
v_geos_v: Float[Array, "lat lon"],
dx_u: Float[Array, "lat lon"],
dx_v: Float[Array, "lat lon"],
dy_u: Float[Array, "lat lon"],
dy_v: Float[Array, "lat lon"],
coriolis_factor_u: Float[Array, "lat lon"],
coriolis_factor_v: Float[Array, "lat lon"],
mask: Float[Array, "lat lon"],
uv_cyclo: [Float[Array, "lat lon"], Float[Array, "lat lon"]]
) -> Float[Scalar, ""]:
u_cyclo_u, v_cyclo_v = uv_cyclo
u_adv_v, v_adv_u = kinematics.advection(u_cyclo_u, v_cyclo_v, dx_u, dy_u, dx_v, dy_v, mask)
return _cyclogeostrophic_diff(u_geos_u, v_geos_v, u_cyclo_u, v_cyclo_v, u_adv_v, v_adv_u,
coriolis_factor_u, coriolis_factor_v)


def _var_step(
mask: Float[Array, "lat lon"],
loss_fn: Callable[[[Float[Array, "lat lon"], Float[Array, "lat lon"]]], Float[Scalar, ""]],
optim: optax.GradientTransformation,
return_losses: bool,
u_cyclo_u: Float[Array, "lat lon"],
v_cyclo_v: Float[Array, "lat lon"],
optim: optax.GradientTransformation,
opt_state: optax.OptState
opt_state: optax.OptState,
losses: Float[Array, "n_it"],
i: int
) -> [Float[Array, "lat lon"], ...]:
params = (u_cyclo_u, v_cyclo_v)
# evaluate the cost function and compute its gradient
loss, grads = value_and_grad(f)(params)
loss, grads = value_and_grad(loss_fn)(params)
# make sure to remove nan values
grads = (sanitize.sanitize_data(grads[0], 0), sanitize.sanitize_data(grads[1], 0))
grads = (sanitize.sanitize_data(grads[0], 0., mask), sanitize.sanitize_data(grads[1], 0., mask))
# update the optimizer
updates, opt_state = optim.update(grads, opt_state, params)
# apply updates to the parameters
u_n, v_n = optax.apply_updates(params, updates)
return u_n, v_n, opt_state, loss

# store loss
losses = lax.cond(
return_losses,
lambda operands: operands[0].at[operands[1]].set(operands[2]), lambda operands: operands[0],
(losses, i, loss)
)

def _gradient_descent(
i += 1

return u_n, v_n, opt_state, losses, i


def _solve(
u_geos_u: Float[Array, "lat lon"],
v_geos_v: Float[Array, "lat lon"],
f: Callable[[[Float[Array, "lat lon"], Float[Array, "lat lon"]]], Float[Scalar, ""]],
mask: Float[Array, "lat lon"],
loss_fn: Callable[[[Float[Array, "lat lon"], Float[Array, "lat lon"]]], Float[Scalar, ""]],
n_it: int,
optim: optax.GradientTransformation,
return_losses: bool
) -> [Float[Array, "lat lon"], ...]:
# initial guess is the geostrophy
u_cyclo_u, v_cyclo_v = jnp.copy(u_geos_u), jnp.copy(v_geos_v)
opt_state = optim.init((u_cyclo_u, v_cyclo_v))
losses = jnp.ones(n_it) * jnp.nan
for i in jnp.arange(n_it):
# update the solution
u_cyclo_u, v_cyclo_v, opt_state, loss = _step(f, u_cyclo_u, v_cyclo_v, optim, opt_state)
if return_losses:
losses = losses.at[i].set(loss)
# define step partial: freeze constant over iterations
def step_fn(pytree):
return _var_step(mask, loss_fn, optim, return_losses, *pytree)

u_cyclo_u, v_cyclo_v, opt_state, losses, i = lax.while_loop( # noqa
lambda args: args[-1] < n_it,
step_fn,
(u_geos_u, v_geos_v, optim.init((u_geos_u, v_geos_v)), jnp.ones(n_it) * jnp.nan, 0)
)

return u_cyclo_u, v_cyclo_v, losses


@partial(jit, static_argnames=("n_it", "optim", "optim_kwargs"))
def _variational(
u_geos_u: Float[Array, "lat lon"],
v_geos_v: Float[Array, "lat lon"],
Expand All @@ -308,14 +387,14 @@ def _variational(
dy_v: Float[Array, "lat lon"],
coriolis_factor_u: Float[Array, "lat lon"],
coriolis_factor_v: Float[Array, "lat lon"],
n_it: int,
mask: Float[Array, "lat lon"],
n_it: Union[int, None],
optim: Union[optax.GradientTransformation, str],
optim_kwargs: dict,
optim_kwargs: Union[dict, None],
return_losses: bool
) -> [Float[Array, "lat lon"], ...]:
if n_it is None:
n_it = N_IT_VAR

if isinstance(optim, str):
if optim_kwargs is None:
optim_kwargs = {"learning_rate": LR_VAR}
Expand All @@ -324,14 +403,13 @@ def _variational(
raise TypeError("optim should be an optax.GradientTransformation optimizer, or a string referring to such an "
"optimizer.")

# define the cost function
def f(uv: [Float[Array, "lat lon"], Float[Array, "lat lon"]]) -> Float[Array, "lat lon"]:
u, v = uv
u_adv_v, v_adv_u = kinematics.advection(u, v, dx_u, dy_u, dx_v, dy_v)
return _cyclogeostrophic_diff(u_geos_u, v_geos_v, u, v, u_adv_v, v_adv_u,
coriolis_factor_u, coriolis_factor_v)
# define loss partial: freeze constant over iterations
loss_fn = partial(
_var_loss_fn,
u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, mask
)

return _gradient_descent(u_geos_u, v_geos_v, f, n_it, optim, return_losses)
return _solve(u_geos_u, v_geos_v, mask, loss_fn, n_it, optim, return_losses)


def _cyclogeostrophic_diff(
Expand All @@ -343,7 +421,7 @@ def _cyclogeostrophic_diff(
v_adv_u: Float[Array, "lat lon"],
coriolis_factor_u: Float[Array, "lat lon"],
coriolis_factor_v: Float[Array, "lat lon"]
) -> Float[Array, "lat lon"]:
) -> Float[Scalar, ""]:
J_u = jnp.nansum((u_cyclo_u + v_adv_u / coriolis_factor_u - u_geos_u) ** 2)
J_v = jnp.nansum((v_cyclo_v - u_adv_v / coriolis_factor_v - v_geos_v) ** 2)
return J_u + J_v
2 changes: 2 additions & 0 deletions jaxparrow/geostrophy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

Expand Down Expand Up @@ -81,6 +82,7 @@ def geostrophy(
return res


@jax.jit
def _geostrophy(
ssh_t: Float[Array, "lat lon"],
dx_t: Float[Array, "lat lon"],
Expand Down
Loading

0 comments on commit 2f1d0a9

Please sign in to comment.