This repository has been archived by the owner on Oct 19, 2024. It is now read-only.
Background

Currently, the auto-sharding pass for intra-op parallelism (also called shard parallel) does not support any control flow instructions (e.g., `while` and `if`). For example, calling `compute_alpa()` below triggers assertion errors.
```python
import jax
import jax.numpy as jnp
import numpy as np
import alpa

N = 1024
n_iter = 5
x = np.ones((N, N), dtype=np.float32)
w = np.ones((N, N), dtype=np.float32)


def compute_numpy():
    # Reference result: x multiplied by w, n_iter times.
    y = x
    for _ in range(n_iter):
        y = y @ w
    return y


def func(a, b):
    # Loop state: (iteration counter, running product, weight matrix).
    # Build it from the arguments (not the globals x, w) so they are traced inputs.
    init_state = (0, a, b)
    cond_func = lambda state: state[0] < n_iter
    body_func = lambda state: (state[0] + 1, state[1] @ state[2], state[2])
    final_state = jax.lax.while_loop(cond_func, body_func, init_state)
    return final_state[1]


def compute_jax_jit():
    return jax.jit(func)(x, w)


def compute_alpa():
    return alpa.parallelize(func)(x, w)


# Check correctness
expected = compute_numpy()
actual = compute_jax_jit()
np.testing.assert_allclose(expected, actual)

# Inspect the HLO IR
hlo_text = jax.jit(func).lower(x, w).compile().compiler_ir()[0].to_string()
print(hlo_text)

# Currently, alpa does not support while loop. The following function
# causes assertion errors. We want to support it.
# actual = compute_alpa()
# np.testing.assert_allclose(expected, actual)
```
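For reference, the JAX documentation describes `jax.lax.while_loop` with the pure-Python semantics below. The sketch (the name `while_loop_reference` is ours, not a JAX API) runs the same counter-based loop as `func`, with scalars standing in for the matrices, to show that the body executes exactly `n_iter` times. In the compiled HLO, the body is represented once, as a separate computation referenced by the `while` instruction, so the auto-sharding pass cannot read the repetition count off the graph structure alone.

```python
def while_loop_reference(cond_fun, body_fun, init_val):
    """Pure-Python semantics of jax.lax.while_loop, as given in the JAX docs."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


n_iter = 5
# Same state layout as `func`: (counter, running product, weight),
# with scalars standing in for the N x N matrices.
cond_func = lambda state: state[0] < n_iter
body_func = lambda state: (state[0] + 1, state[1] * state[2], state[2])

final_state = while_loop_reference(cond_func, body_func, (0, 1, 2))
print(final_state)  # (5, 32, 2): the body ran n_iter = 5 times
```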
To support them, we need to correctly handle `HloOpcode::kWhile` and `HloOpcode::kConditional` in the auto-sharding pass.
Todo

- Learn the auto-sharding pass
  - Read the reference materials
  - Understand the test cases in `tests/test_auto_sharding_basic.py` and `tests/test_auto_sharding_mlp.py`
- Support while loop
  - We assume the loop length is known at compile time. As a first step, we can simply hard-code it as a fixed constant (e.g., 5); later, this can be improved by inferring it from the program. Once the loop length is known, we can effectively unroll the while loop when building the cost graph, for example by multiplying the costs of all nodes and edges in the while body by the loop length.
  - Implement the above idea in `auto_sharding.cc` and fix all other errors. We should then be able to run `compute_alpa()`.
  - Add unit test cases to `tests/test_auto_sharding_control_flow.py`.
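The cost-scaling idea from the Todo list can be sketched in Python. This is illustrative only: the real pass is C++ in `auto_sharding.cc`, and the `Node`/`Edge` types, `infer_trip_count`, `scale_while_body_costs`, and `DEFAULT_TRIP_COUNT` are hypothetical names, not Alpa APIs. We assume each node carries one cost per candidate sharding strategy and each edge a resharding-cost matrix, and that only edges with both endpoints inside the body repeat per iteration (edges feeding loop inputs are counted once); the issue does not specify how boundary edges should be treated.

```python
from __future__ import annotations

from dataclasses import dataclass


# Hypothetical, simplified cost-graph types.
@dataclass
class Node:
    name: str
    strategy_costs: list[float]  # one cost per candidate sharding strategy
    in_while_body: bool = False


@dataclass
class Edge:
    src: str
    dst: str
    resharding_costs: list[list[float]]  # [src_strategy][dst_strategy]


DEFAULT_TRIP_COUNT = 5  # first step: a fixed constant, as proposed in the issue


def infer_trip_count(init: int, bound: int, step: int = 1) -> int:
    """Trip count of the counter loop `i = init; while i < bound: i += step`."""
    if step <= 0:
        raise ValueError("step must be positive")
    # Integer ceil((bound - init) / step), clamped at 0 for loops that never run.
    return max(0, -(-(bound - init) // step))


def scale_while_body_costs(nodes: list[Node], edges: list[Edge], trip_count: int) -> None:
    """Approximate unrolling: multiply costs inside the while body by trip_count."""
    body = {n.name for n in nodes if n.in_while_body}
    for n in nodes:
        if n.in_while_body:
            n.strategy_costs = [c * trip_count for c in n.strategy_costs]
    for e in edges:
        # Assumption: only edges with both endpoints in the body repeat per iteration.
        if e.src in body and e.dst in body:
            e.resharding_costs = [[c * trip_count for c in row] for row in e.resharding_costs]


nodes = [
    Node("param_x", [0.0, 0.0]),
    Node("dot", [4.0, 2.0], in_while_body=True),
    Node("add", [1.0, 1.0], in_while_body=True),
]
edges = [
    Edge("param_x", "dot", [[0.0, 3.0], [3.0, 0.0]]),  # enters the loop once
    Edge("dot", "add", [[0.0, 1.0], [1.0, 0.0]]),      # repeats every iteration
]
trip = infer_trip_count(0, 5)  # same loop shape as n_iter = 5 in the example
scale_while_body_costs(nodes, edges, trip)
print(nodes[1].strategy_costs)    # [20.0, 10.0]
print(edges[0].resharding_costs)  # [[0.0, 3.0], [3.0, 0.0]] (unchanged)
print(edges[1].resharding_costs)  # [[0.0, 5.0], [5.0, 0.0]]
```

Scaling costs rather than physically duplicating the body keeps the ILP the same size, but it forces every iteration to use the same sharding strategy, which matches how a single HLO body computation is compiled.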
Reference