This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

[FEATURE] Control flow support in shard_parallel #400

Open
7 tasks
merrymercy opened this issue Apr 23, 2022 · 5 comments
Labels
enhancement New feature

Comments

@merrymercy
Member

merrymercy commented Apr 23, 2022

Background

Currently, the auto-sharding pass for intra-op parallelism (or shard parallel) does not support any control flow instructions (e.g., while and if). For example, the function compute_alpa() below will cause assertion errors.

import jax
import jax.numpy as jnp
import numpy as np

import alpa

N = 1024
n_iter = 5

x = np.ones((N, N), dtype=np.float32)
w = np.ones((N, N), dtype=np.float32)


def compute_numpy():
    y = x
    for i in range(n_iter):
        y = y @ w
    return y


def func(a, b):
    init_state = (0, a, b)
    cond_func = lambda state: state[0] < n_iter
    body_func = lambda state: (state[0] + 1, state[1] @ state[2], state[2])

    final_state = jax.lax.while_loop(cond_func, body_func, init_state)
    return final_state[1]


def compute_jax_jit():
    return jax.jit(func)(x, w)


def compute_alpa():
    return alpa.parallelize(func)(x, w)


# Check correctness
expected = compute_numpy()
actual = compute_jax_jit()
np.testing.assert_allclose(expected, actual)

# Inspect the HLO IR
hlo_text = jax.jit(func).lower(x, w).compile().compiler_ir()[0].to_string()
print(hlo_text)

# Currently, alpa does not support while loop. The following function
# causes assertion errors. We want to support it.
# actual = compute_alpa()
# np.testing.assert_allclose(expected, actual)

To support them, we need to correctly handle HloOpcode::kWhile and HloOpcode::kConditional in the auto-sharding pass.

Todo

  • Learn the auto-sharding pass
  • Support while loop
    We assume the loop length is known at compile time. As a first step, we can simply set the length to a fixed constant (e.g., 5); this can later be improved by inferring it from the code. Once we know the loop length, we can effectively unroll the while loop when we build the cost graph, for example by multiplying the costs of all nodes/edges in the while body by the loop length.
    • Implement the above idea in auto_sharding.cc and fix all other errors. We should be able to run compute_alpa().
    • Add unit test cases to tests/test_auto_sharding_control_flow.py
  • Support if
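
The cost-scaling idea in the while-loop item can be sketched in plain Python. This is only an illustration under stated assumptions: `scale_loop_costs`, `node_costs`, `edge_costs`, and `body_nodes` are hypothetical names and data layouts, not the structures used in auto_sharding.cc, and the loop length is hard-coded as the todo suggests. Leaving edges that cross the loop boundary unscaled is also an assumption.

```python
# Hypothetical sketch: names and data layout are assumptions for illustration,
# not the actual data structures in auto_sharding.cc.

def scale_loop_costs(node_costs, edge_costs, body_nodes, trip_count):
    """Multiply the costs of nodes/edges inside a while body by the loop length.

    node_costs: dict node -> list of costs, one entry per sharding strategy.
    edge_costs: dict (src, dst) -> resharding cost matrix as nested lists.
    body_nodes: set of node names that belong to the while body.
    """
    scaled_nodes = {
        node: [c * trip_count for c in costs] if node in body_nodes else list(costs)
        for node, costs in node_costs.items()
    }
    scaled_edges = {}
    for (src, dst), matrix in edge_costs.items():
        # Assumption: an edge is paid once per iteration only when both
        # endpoints live inside the loop body.
        factor = trip_count if src in body_nodes and dst in body_nodes else 1
        scaled_edges[(src, dst)] = [[c * factor for c in row] for row in matrix]
    return scaled_nodes, scaled_edges


node_costs = {"param": [0.0, 0.0], "dot": [4.0, 2.0], "add": [1.0, 1.0]}
edge_costs = {
    ("param", "dot"): [[0.0, 3.0], [3.0, 0.0]],
    ("dot", "add"): [[0.0, 1.0], [1.0, 0.0]],
}
nodes, edges = scale_loop_costs(node_costs, edge_costs, {"dot", "add"}, trip_count=5)
print(nodes["dot"])            # [20.0, 10.0]
print(edges[("dot", "add")])   # [[0.0, 5.0], [5.0, 0.0]]
```

With the scaled costs, the rest of the strategy selection could in principle run unchanged, since the loop body simply looks five times as expensive as a single pass.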

Reference

@merrymercy
Member Author

cc @HeydrichBeillschmidt

@merrymercy merrymercy changed the title While loop support in shard_parallel Control flow support in shard_parallel Apr 23, 2022
@merrymercy merrymercy changed the title Control flow support in shard_parallel [FEATURE] Control flow support in shard_parallel Apr 23, 2022
@yf225
Contributor

yf225 commented Jun 14, 2022

I suspect if is a more commonly used control flow construct than while, so we could consider implementing if first.

@zhisbug suggested we can just map it to https://www.tensorflow.org/xla/operation_semantics#conditional for now and ignore the effect on the training plan; we can fix any imbalance issue later.
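
For reference, a minimal test input for the if case might look like the sketch below, mirroring the while-loop example in the issue. It uses jax.lax.cond, which JAX is expected to lower to a conditional in the HLO; `func_cond`, `flag`, and the small matrix size are assumptions made for illustration. Only the jax.jit path is exercised here, not alpa.parallelize.

```python
import jax
import numpy as np

N = 8  # small size so the example runs quickly

x = np.ones((N, N), dtype=np.float32)
w = np.full((N, N), 2.0, dtype=np.float32)


def func_cond(a, b, flag):
    # Both branches are traced; one of them is selected at run time by `flag`.
    return jax.lax.cond(flag, lambda a, b: a @ b, lambda a, b: a + b, a, b)


taken = jax.jit(func_cond)(x, w, True)
not_taken = jax.jit(func_cond)(x, w, False)

# Check correctness against NumPy
np.testing.assert_allclose(taken, x @ w)
np.testing.assert_allclose(not_taken, x + w)
```

The two branches deliberately have different costs (a matmul versus an elementwise add), which is the kind of imbalance the comment above proposes to ignore at first.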

@merrymercy
Member Author

@HeydrichBeillschmidt Could you share a little bit about your progress?

@merrymercy merrymercy self-assigned this Aug 6, 2022
@mmorinag127

mmorinag127 commented Aug 30, 2022

Is there any news on this topic?
I really want to use it with Alpa :)

@merrymercy
Member Author
