Skip to content

Commit

Permalink
Replace deprecated jax.tree_* functions with jax.tree.*
Browse files Browse the repository at this point in the history
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 634069694
  • Loading branch information
Jake VanderPlas authored and JAX-CFD authors committed May 16, 2024
1 parent 674d815 commit f133faa
Show file tree
Hide file tree
Showing 15 changed files with 43 additions and 43 deletions.
12 changes: 6 additions & 6 deletions jax_cfd/base/array_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def slice_along_axis(
Returns:
Slice of `inputs` defined by `idx` along axis `axis`.
"""
arrays, tree_def = jax.tree_flatten(inputs)
arrays, tree_def = jax.tree.flatten(inputs)
ndims = set(a.ndim for a in arrays)
if expect_same_dims and len(ndims) != 1:
raise ValueError('arrays in `inputs` expected to have same ndims, but have '
Expand All @@ -68,7 +68,7 @@ def slice_along_axis(
slc = tuple(idx if j == _normalize_axis(axis, ndim) else slice(None)
for j in range(ndim))
sliced.append(array[slc])
return jax.tree_unflatten(tree_def, sliced)
return jax.tree.unflatten(tree_def, sliced)


def split_along_axis(
Expand Down Expand Up @@ -115,22 +115,22 @@ def split_axis(
Raises:
ValueError: if arrays in `inputs` don't have unique size along `axis`.
"""
arrays, tree_def = jax.tree_flatten(inputs)
arrays, tree_def = jax.tree.flatten(inputs)
axis_shapes = set(a.shape[axis] for a in arrays)
if len(axis_shapes) != 1:
raise ValueError(f'Arrays must have equal sized axis but got {axis_shapes}')
axis_shape, = axis_shapes
splits = [jnp.split(a, axis_shape, axis=axis) for a in arrays]
if not keep_dims:
splits = jax.tree_map(lambda a: jnp.squeeze(a, axis), splits)
splits = jax.tree.map(lambda a: jnp.squeeze(a, axis), splits)
splits = zip(*splits)
return tuple(jax.tree_unflatten(tree_def, leaves) for leaves in splits)
return tuple(jax.tree.unflatten(tree_def, leaves) for leaves in splits)


def concat_along_axis(pytrees, axis):
"""Concatenates `pytrees` along `axis`."""
concat_leaves_fn = lambda *args: jnp.concatenate(args, axis)
return jax.tree_map(concat_leaves_fn, *pytrees)
return jax.tree.map(concat_leaves_fn, *pytrees)


def block_reduce(
Expand Down
22 changes: 11 additions & 11 deletions jax_cfd/base/array_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,16 +134,16 @@ def test_split_and_concat(self, pytree, idx, axis):
"""Tests that split_along_axis, concat_along_axis return expected shapes."""
split_a, split_b = array_utils.split_along_axis(pytree, idx, axis, False)
with self.subTest('split_shape'):
self.assertEqual(jax.tree_leaves(split_a)[0].shape[axis], idx)
self.assertEqual(jax.tree.leaves(split_a)[0].shape[axis], idx)

reconstruction = array_utils.concat_along_axis([split_a, split_b], axis)
with self.subTest('split_concat_roundtrip_structure'):
actual_tree_def = jax.tree_structure(reconstruction)
expected_tree_def = jax.tree_structure(pytree)
actual_tree_def = jax.tree.structure(reconstruction)
expected_tree_def = jax.tree.structure(pytree)
self.assertSameStructure(actual_tree_def, expected_tree_def)

actual_values = jax.tree_leaves(reconstruction)
expected_values = jax.tree_leaves(pytree)
actual_values = jax.tree.leaves(reconstruction)
expected_values = jax.tree.leaves(pytree)
with self.subTest('split_concat_roundtrip_values'):
for actual, expected in zip(actual_values, expected_values):
self.assertAllClose(actual, expected)
Expand All @@ -157,8 +157,8 @@ def test_split_and_concat(self, pytree, idx, axis):
with self.subTest('multiple_concat_shape'):
arrays = [split_a, split_a, split_b, split_b]
double_concat = array_utils.concat_along_axis(arrays, axis)
actual_shape = jax.tree_leaves(double_concat)[0].shape[axis]
expected_shape = jax.tree_leaves(pytree)[0].shape[axis] * 2
actual_shape = jax.tree.leaves(double_concat)[0].shape[axis]
expected_shape = jax.tree.leaves(pytree)[0].shape[axis] * 2
self.assertEqual(actual_shape, expected_shape)

@parameterized.parameters(
Expand All @@ -171,17 +171,17 @@ def test_split_along_axis_shapes(self, pytree, axis):
with self.subTest('with_keep_dims'):
splits = array_utils.split_axis(pytree, axis, keep_dims=True)
get_expected_shape = lambda x: x.shape[:axis] + (1,) + x.shape[axis + 1:]
expected_shapes = jax.tree_map(get_expected_shape, pytree)
expected_shapes = jax.tree.map(get_expected_shape, pytree)
for split in splits:
actual = jax.tree_map(lambda x: x.shape, split)
actual = jax.tree.map(lambda x: x.shape, split)
self.assertEqual(expected_shapes, actual)

with self.subTest('without_keep_dims'):
splits = array_utils.split_axis(pytree, axis, keep_dims=False)
get_expected_shape = lambda x: x.shape[:axis] + x.shape[axis + 1:]
expected_shapes = jax.tree_map(get_expected_shape, pytree)
expected_shapes = jax.tree.map(get_expected_shape, pytree)
for split in splits:
actual = jax.tree_map(lambda x: x.shape, split)
actual = jax.tree.map(lambda x: x.shape, split)
self.assertEqual(expected_shapes, actual)


Expand Down
2 changes: 1 addition & 1 deletion jax_cfd/base/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@


def sum_fields(*args):
return jax.tree_map(lambda *a: sum(a), *args)
return jax.tree.map(lambda *a: sum(a), *args)


def stable_time_step(
Expand Down
2 changes: 1 addition & 1 deletion jax_cfd/base/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ def stagger(self, v: Tuple[Array, ...]) -> Tuple[GridArray, ...]:
def center(self, v: PyTree) -> PyTree:
"""Places all arrays in the pytree `v` at the `Grid`'s cell center."""
offset = self.cell_center
return jax.tree_map(lambda u: GridArray(u, offset, self), v)
return jax.tree.map(lambda u: GridArray(u, offset, self), v)

def axes(self, offset: Optional[Sequence[float]] = None) -> Tuple[Array, ...]:
"""Returns a tuple of arrays containing the grid points along each axis.
Expand Down
4 changes: 2 additions & 2 deletions jax_cfd/base/grids_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class GridArrayTest(test_util.TestCase):

def test_tree_util(self):
array = grids.GridArray(jnp.arange(3), offset=(0,), grid=grids.Grid((3,)))
flat, treedef = jax.tree_flatten(array)
roundtripped = jax.tree_unflatten(treedef, flat)
flat, treedef = jax.tree.flatten(array)
roundtripped = jax.tree.unflatten(treedef, flat)
self.assertArrayEqual(array, roundtripped)

def test_consistent_offset(self):
Expand Down
4 changes: 2 additions & 2 deletions jax_cfd/base/subgrid_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def wrapped_interp_fn(c, offset, v, dt):
offset: wrapped_interp_fn(viscosity, offset, v, dt).data
for offset in unique_offsets}
viscosities = [viscosities_dict[offset] for offset in s_ij_offsets]
return jax.tree_unflatten(jax.tree_util.tree_structure(s_ij), viscosities)
return jax.tree.unflatten(jax.tree_util.tree_structure(s_ij), viscosities)


def evm_model(
Expand Down Expand Up @@ -127,7 +127,7 @@ def evm_model(
for j in range(grid.ndim)]
for i in range(grid.ndim)])
viscosity = viscosity_fn(s_ij, v)
tau = jax.tree_map(lambda x, y: -2. * x * y, viscosity, s_ij)
tau = jax.tree.map(lambda x, y: -2. * x * y, viscosity, s_ij)
return tuple(-finite_differences.divergence( # pylint: disable=g-complex-comprehension
tuple(grids.GridVariable(t, bc) # use velocity bc to compute diverence
for t in tau[i, :]))
Expand Down
2 changes: 1 addition & 1 deletion jax_cfd/collocated/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@


def sum_fields(*args):
return jax.tree_map(lambda *a: sum(a), *args)
return jax.tree.map(lambda *a: sum(a), *args)


def semi_implicit_navier_stokes(
Expand Down
6 changes: 3 additions & 3 deletions jax_cfd/ml/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def slice_last_n_state_encoder(
del grid, dt, physics_specs # unused.
def encode_fn(inputs):
init_slice = array_utils.slice_along_axis(inputs, 0, slice(-n, None))
return jax.tree_map(lambda x: jnp.moveaxis(x, time_axis, -1), init_slice)
return jax.tree.map(lambda x: jnp.moveaxis(x, time_axis, -1), init_slice)
return encode_fn


Expand All @@ -118,8 +118,8 @@ def stack_last_n_state_encoder(
del grid, dt, physics_specs # unused.
def encode_fn(inputs):
inputs = array_utils.slice_along_axis(inputs, 0, slice(-n, None))
inputs = jax.tree_map(lambda x: jnp.moveaxis(x, time_axis, -1), inputs)
return array_utils.concat_along_axis(jax.tree_leaves(inputs), axis=-1)
inputs = jax.tree.map(lambda x: jnp.moveaxis(x, time_axis, -1), inputs)
return array_utils.concat_along_axis(jax.tree.leaves(inputs), axis=-1)

return encode_fn

Expand Down
6 changes: 3 additions & 3 deletions jax_cfd/ml/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def learned_corrector(
def step_fn(state):
next_state = base_solver(state)
corrections = corrector(next_state)
return jax.tree_map(lambda x, y: x + y, next_state, corrections)
return jax.tree.map(lambda x, y: x + y, next_state, corrections)

return hk.to_module(step_fn)()

Expand All @@ -242,7 +242,7 @@ def learned_corrector_v2(
def step_fn(state):
next_state = base_solver(state)
corrections = corrector(state)
return jax.tree_map(lambda x, y: x + dt * y, next_state, corrections)
return jax.tree.map(lambda x, y: x + dt * y, next_state, corrections)

return hk.to_module(step_fn)()

Expand All @@ -262,6 +262,6 @@ def learned_corrector_v3(
def step_fn(state):
next_state = base_solver(state)
corrections = corrector(tuple(state) + tuple(next_state))
return jax.tree_map(lambda x, y: x + dt * y, next_state, corrections)
return jax.tree.map(lambda x, y: x + dt * y, next_state, corrections)

return hk.to_module(step_fn)()
2 changes: 1 addition & 1 deletion jax_cfd/ml/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _expand_var(self, var):
def __call__(self, inputs):
input_data = tuple(self._expand_var(var) for var in inputs)
input_data = array_utils.concat_along_axis(
jax.tree_leaves(input_data), axis=-1)
jax.tree.leaves(input_data), axis=-1)
outputs = self._conv_module(input_data)
outputs = array_utils.split_axis(outputs, -1)
outputs = tuple(
Expand Down
4 changes: 2 additions & 2 deletions jax_cfd/ml/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def stack_aligned_field_with_neighbors(

def process(inputs):
inputs = tuple(jnp.expand_dims(x.data, axis=-1) for x in inputs)
array = array_utils.concat_along_axis(jax.tree_leaves(inputs), axis=-1)
array = array_utils.concat_along_axis(jax.tree.leaves(inputs), axis=-1)
arrays = tuple(
jnp.roll(array, *shift_and_axis) for shift_and_axis in shifts_and_axis)
return array_utils.concat_along_axis(arrays, axis=-1)
Expand All @@ -159,7 +159,7 @@ def stack_aligned_field(

def process(inputs):
inputs = tuple(jnp.expand_dims(x.data, axis=-1) for x in inputs)
return array_utils.concat_along_axis(jax.tree_leaves(inputs), axis=-1)
return array_utils.concat_along_axis(jax.tree.leaves(inputs), axis=-1)

return hk.to_module(process)()

Expand Down
2 changes: 1 addition & 1 deletion jax_cfd/ml/time_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def euler_integrator(
"""
def _single_step(state, _):
deriv = derivative_module(state)
next_state = jax.tree_map(lambda x, dxdt: x + dt * dxdt, state, deriv)
next_state = jax.tree.map(lambda x, dxdt: x + dt * dxdt, state, deriv)
return next_state, next_state

return hk.scan(_single_step, initial_state, None, num_steps)
10 changes: 5 additions & 5 deletions jax_cfd/ml/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,14 @@ def add_noise_to_input_frame(
"""
del kwargs # unused.
time_zero_slice = array_utils.slice_along_axis(batch, 1, 0)
shapes = jax.tree_map(lambda x: x.shape, time_zero_slice)
rngs = jax.random.split(rng, len(jax.tree_leaves(time_zero_slice)))
shapes = jax.tree.map(lambda x: x.shape, time_zero_slice)
rngs = jax.random.split(rng, len(jax.tree.leaves(time_zero_slice)))
# TODO(dkochkov) add `split_like` method to `array_utils.py`.
rngs = jax.tree_unflatten(jax.tree_structure(time_zero_slice), rngs)
rngs = jax.tree.unflatten(jax.tree.structure(time_zero_slice), rngs)
noise_fn = lambda key, s: scale * jax.random.truncated_normal(key, -2., 2., s)
noise = jax.tree_map(noise_fn, rngs, shapes)
noise = jax.tree.map(noise_fn, rngs, shapes)
add_noise_fn = lambda x, n: x.at[:, 0, ...].add(n)
return jax.tree_map(add_noise_fn, batch, noise)
return jax.tree.map(add_noise_fn, batch, noise)


def preprocess(
Expand Down
6 changes: 3 additions & 3 deletions jax_cfd/ml/viscosities.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def viscosity_fn(
for offset in unique_offsets}
viscosities = [interpolated_viscosities[offset] for offset in s_ij_offsets]
tree_def = jax.tree_util.tree_structure(s_ij)
return jax.tree_unflatten(tree_def, [x.data for x in viscosities])
return jax.tree.unflatten(tree_def, [x.data for x in viscosities])

return hk.to_module(viscosity_fn)()

Expand Down Expand Up @@ -159,7 +159,7 @@ def viscosity_fn(
for offset in unique_offsets}
viscosities = [interpolated_viscosities[offset] for offset in s_ij_offsets]
tree_def = jax.tree_util.tree_structure(s_ij)
return jax.tree_unflatten(tree_def, [x.data for x in viscosities])
return jax.tree.unflatten(tree_def, [x.data for x in viscosities])

return hk.to_module(viscosity_fn)()

Expand Down Expand Up @@ -205,7 +205,7 @@ def viscosity_fn(
for offset, visc in zip(unique_offsets, viscosities)}
viscosities = [viscosities_dict[offset] for offset in s_ij_offsets]
tree_def = jax.tree_util.tree_structure(s_ij)
return jax.tree_unflatten(tree_def, viscosities)
return jax.tree.unflatten(tree_def, viscosities)

return hk.to_module(viscosity_fn)()

Expand Down
2 changes: 1 addition & 1 deletion notebooks/ml_model_inference_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@
"source": [
"#@title Helper functions\n",
"\n",
"shape_structure = lambda tree: jax.tree_map(lambda x: x.shape, tree)\n",
"shape_structure = lambda tree: jax.tree.map(lambda x: x.shape, tree)\n",
"\n",
"\n",
"def xarray_open(path):\n",
Expand Down

0 comments on commit f133faa

Please sign in to comment.