Skip to content

Commit cfa694a

Browse files
sourabh2k15The precondition Authors
authored andcommitted
upgrading init2winit from pmap to jit
PiperOrigin-RevId: 673695362
1 parent 6f23374 commit cfa694a

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

precondition/distributed_shampoo.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2832,10 +2832,7 @@ def _pmap_compute_preconditioners(states, step, statistics,
28322832
Returns:
28332833
New optimizer states after computing the preconditioner.
28342834
"""
2835-
if batch_axis_name:
2836-
num_devices = lax.psum(1, batch_axis_name)
2837-
else:
2838-
num_devices = 1
2835+
num_devices = jax.device_count()
28392836
num_statistics = len(statistics)
28402837
# Pad statistics and exponents to next multiple of num_devices.
28412838
packed_statistics = [
@@ -3033,7 +3030,7 @@ def _pmap_quantized_compute_preconditioners(states, step, statistics,
30333030
Returns:
30343031
New optimizer states after computing the preconditioner.
30353032
"""
3036-
num_devices = lax.psum(1, batch_axis_name)
3033+
num_devices = jax.device_count()
30373034
num_statistics = len(statistics)
30383035
quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
30393036
# Complexity here is around: shapes needing be statically shaped,

0 commit comments

Comments
 (0)