
Commit 4a73f3a
[FEATURE] Serialize Parallel Plan (#587)
Parent: 57437d4
15 files changed, +357 -55 lines

alpa/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@
 from alpa.parallel_method import (ShardParallel, PipeshardParallel,
                                   DataParallel, Zero2Parallel, Zero3Parallel,
                                   CreateStateParallel)
+from alpa.parallel_plan import plan_to_method
 from alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary
 from alpa.pipeline_parallel.layer_construction import (manual_remat,
                                                        automatic_remat,

alpa/create_state_parallel.py

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ def __init__(self,
         super().__init__(mesh_group=mesh_group,
                          pipeshard_config=pipeshard_config,
                          num_batch=1,
+                         layer_option=None,
                          in_tree=in_tree,
                          out_tree=out_tree,
                          static_argnums=static_argnums)

alpa/mesh_executable.py

Lines changed: 26 additions & 3 deletions
@@ -19,15 +19,16 @@
 from jax._src.lib import xla_bridge as xb, xla_client as xc, xla_extension as xe
 from jax.core import ShapedArray
 from jax.interpreters import pxla
-from jax.tree_util import tree_flatten, tree_unflatten, PyTreeDef
+from jax.tree_util import tree_flatten, tree_unflatten, tree_leaves, PyTreeDef
 import numpy as np
 import ray

 from alpa.device_mesh import (LocalPhysicalDeviceMesh,
                               DistributedPhysicalDeviceMesh, RemoteArrayRef,
                               next_array_uuids)
 from alpa.global_env import global_config
-from alpa.parallel_plan import PlacementSpec, StagePlan
+from alpa.parallel_plan import (PlacementSpec, StagePlan, ClusterInfo,
+                                ParallelPlan)
 from alpa.shard_parallel.auto_sharding import (AutoShardingOption,
                                                get_input_output_sharding_specs,
                                                make_replicated_spec, HloStatus,

@@ -76,6 +77,10 @@ def get_output_placement_specs(self):
         """
         raise NotImplementedError()

+    def get_parallel_plan(self):
+        """Get the overall parallel plan."""
+        raise NotImplementedError()
+
     def preshard_dynamic_args(self, *args):
         """Pre-shard the input arguments."""
         raise NotImplementedError()

@@ -205,6 +210,7 @@ def __init__(self,
         self.out_tree = out_tree
         self.flop_count = flop_count
         self.stage_plan = stage_plan
+        self.auto_sharding_option = stage_plan.auto_sharding_option
         self.auto_sharding_objective = stage_plan.auto_sharding_objective

         # Read sharding specs

@@ -324,6 +330,13 @@ def get_output_placement_specs(self):
                                           self.output_sharding_specs,
                                           self.out_tree)

+    def get_parallel_plan(self):
+        """Get the overall parallel plan."""
+        cluster_info = ClusterInfo(self.physical_mesh.num_hosts,
+                                   self.physical_mesh.num_devices_per_host)
+        return ParallelPlan(cluster_info, None, self.auto_sharding_option, None,
+                            tree_leaves(self.get_input_placement_specs()))
+
     def preshard_dynamic_args(self, *args):
         """Pre-shard the input arguments."""
         input_bufs = self.physical_mesh.shard_args_to_bufs(

@@ -517,6 +530,7 @@ def __init__(self,
         self.out_tree = out_tree
         self.flop_count = flop_count
         self.stage_plan = stage_plan
+        self.auto_sharding_option = stage_plan.auto_sharding_option
         self.auto_sharding_objective = stage_plan.auto_sharding_objective

         # Read sharding specs

@@ -753,6 +767,14 @@ def get_output_placement_specs(self):
                                           self.output_sharding_specs,
                                           self.out_tree)

+    def get_parallel_plan(self):
+        """Get the overall parallel plan."""
+        cluster_info = ClusterInfo(self.physical_mesh.num_hosts,
+                                   self.physical_mesh.num_devices_per_host)
+        return ParallelPlan(cluster_info, self.num_micro_batches,
+                            self.auto_sharding_option, None,
+                            tree_leaves(self.get_input_placement_specs()))
+
     def get_total_allocation_size(self):
         """Get the total allocated memory size of this executable."""
         if isinstance(self.physical_mesh, DistributedPhysicalDeviceMesh):

@@ -1185,7 +1207,8 @@ def get_index_select_mesh_executable(avals, sharding_specs, index, dim,
     as_option = AutoShardingOption()
     strategy_config = StagePlan(global_config.compile_random_seed,
                                 device_mesh.shape, 1 << 60,
-                                as_option.all_reduce_threshold, None, -1)
+                                as_option.all_reduce_threshold,
+                                AutoShardingOption(), None, -1)
     out_tree = tree_flatten(avals)[1]
     executable = NormalMeshDriverExecutable(device_mesh,
                                             hlo_module,
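The two `get_parallel_plan()` implementations above pack the cluster shape, micro-batch count, auto-sharding option, and flattened input placement specs into a single `ParallelPlan`. A minimal sketch of inspecting the result, assuming `train_step` was parallelized with Alpa and `get_executable(state, batch)` is available as in Alpa's usual usage (all variable names are placeholders):

executable = train_step.get_executable(state, batch)  # placeholder call pattern
plan = executable.get_parallel_plan()

# Fields defined by the ParallelPlan dataclass in alpa/parallel_plan.py (see below).
print(plan.cluster_info.num_hosts, plan.cluster_info.num_devices_per_host)
print(plan.num_micro_batches)           # None on the non-gradient-accumulation path
print(len(plan.input_placement_specs))  # one PlacementSpec per flattened input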

alpa/monkey_patch.py

Lines changed: 61 additions & 2 deletions
@@ -10,7 +10,7 @@
 from jax._src.lib.mlir.dialects import mhlo
 from jax._src.lib.xla_bridge import get_backend as default_get_backend
 from jax.core import Primitive
-from jax.interpreters import partial_eval as pe
+from jax.interpreters import partial_eval as pe, pxla
 from jax.interpreters import xla, mlir
 from jax.interpreters.xla import (xops, jaxpr_subcomp, extend_name_stack,
                                   register_translation, wrap_name,

@@ -109,7 +109,6 @@ def _rng_normal_lowering(ctx, mu, sigma, *, shape):
 mlir.register_lowering(rng_normal_p, _rng_normal_lowering)


-# Monkey patch random generator to use the stateful random generator.
 def fast_normal(key, shape=(), dtype=dtypes.float_, mu=0.0, sigma=1.0):
     shape = core.as_named_shape(shape)
     mu = jnp.asarray(mu, dtype)

@@ -126,6 +125,7 @@ def remove_fold_in(key, data):
     return key


+# Monkey patch random generator to use the stateful random generator.
 jax._src.random.uniform = fast_uniform
 jax.random.uniform = fast_uniform
 jax._src.random.normal = fast_normal

@@ -136,6 +136,7 @@ def remove_fold_in(key, data):
 jax.random.fold_in = remove_fold_in


+# Monkey patch remat to use identity instead of while loop
 def _zeros(c, xla_shape):
     if xla_shape.is_array():
         shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()

@@ -228,6 +229,64 @@ def _remat_translation_rule(ctx,
     del dict_val[pe.remat_call_p]
 register_translation(pe.remat_call_p, _remat_translation_rule)

+
+# Support pickle ShardingSpec
+def sharding_spec_getstate(self):
+    sharding = []
+    for x in self.sharding:
+        if isinstance(x, pxla.NoSharding):
+            sharding.append((0,))
+        elif isinstance(x, pxla.Chunked):
+            sharding.append((1, x.chunks))
+        elif isinstance(x, pxla.Unstacked):
+            sharding.append((2, x.size))
+        else:
+            raise ValueError(f"Invalid sharding: {x}")
+    mesh_mapping = []
+    for x in self.mesh_mapping:
+        if isinstance(x, pxla.ShardedAxis):
+            mesh_mapping.append((0, x.axis))
+        elif isinstance(x, pxla.Replicated):
+            mesh_mapping.append((1, x.replicas))
+        else:
+            raise ValueError(f"Invalid sharding: {x}")
+    return (sharding, mesh_mapping)
+
+
+def sharding_spec_setstate(self, state_tuple):
+    sharding_encoding, mesh_mapping_encoding = state_tuple
+
+    sharding = []
+    for x in sharding_encoding:
+        if x[0] == 0:
+            sharding.append(pxla.NoSharding())
+        elif x[0] == 1:
+            sharding.append(pxla.Chunked(x[1]))
+        elif x[0] == 2:
+            sharding.append(pxla.Unstacked(x[1]))
+        else:
+            raise ValueError(f"Invalid sharding: {x}")
+
+    mesh_mapping = []
+    for x in mesh_mapping_encoding:
+        if x[0] == 0:
+            mesh_mapping.append(pxla.ShardedAxis(x[1]))
+        elif x[0] == 1:
+            mesh_mapping.append(pxla.Replicated(x[1]))
+        else:
+            raise ValueError(f"Invalid sharding: {x}")
+
+    # pylint: disable=unnecessary-dunder-call
+    self.__init__(
+        sharding=sharding,
+        mesh_mapping=mesh_mapping,
+    )
+
+
+setattr(pxla.ShardingSpec, "__getstate__", sharding_spec_getstate)
+setattr(pxla.ShardingSpec, "__setstate__", sharding_spec_setstate)
+
+# Monkey patch tree map to disable some warnings
 jax._src.tree_util.tree_multimap = jax._src.tree_util.tree_map
 jax.tree_multimap = jax._src.tree_util.tree_map
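With `__getstate__`/`__setstate__` installed on `pxla.ShardingSpec`, sharding specs (and therefore the placement specs stored in a `ParallelPlan`) can be pickled. A minimal round-trip sketch, assuming the patch has been applied (importing `alpa` is enough) and using the same keyword arguments that `sharding_spec_setstate` passes to `__init__` above:

import pickle

from jax.interpreters import pxla

import alpa  # noqa: F401  importing alpa applies the monkey patches

spec = pxla.ShardingSpec(
    sharding=(pxla.Chunked([2]), pxla.NoSharding()),
    mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(2)))
restored = pickle.loads(pickle.dumps(spec))
assert restored == spec  # ShardingSpec compares sharding and mesh_mapping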

alpa/parallel_method.py

Lines changed: 3 additions & 2 deletions
@@ -164,7 +164,7 @@ class PipeshardParallel(ParallelMethod):
           Possible choices are {"manual", alpa.AutoLayerOption,
           alpa.ManualLayerOption}
         stage_option: Options of grouping layers into pipeline stages.
-          Possible choices are {"uniform", "auto", alpa.AutoStageOption,,
+          Possible choices are {"uniform", "auto", alpa.AutoStageOption,
           alpa.ManualStageOption}
     """

@@ -178,7 +178,8 @@ def __init__(
             stage_option: Optional[Union[StageOption, str]] = None):
         self.devices = devices
         self.num_micro_batches = num_micro_batches
-        self.as_option = default_auto_sharding_option or AutoShardingOption()
+        self.as_option = (default_auto_sharding_option or
+                          AutoShardingOption(prefer_reduce_scatter=True))
         self.pipeline_schedule = pipeline_schedule
         if layer_option == "manual":
             layer_option = ManualLayerOption()
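Note that `PipeshardParallel` now defaults to `AutoShardingOption(prefer_reduce_scatter=True)`. If the previous behavior is needed, the option can be passed explicitly; a hedged sketch using the module paths shown in this diff (the micro-batch count is only an example value):

from alpa.parallel_method import PipeshardParallel
from alpa.shard_parallel.auto_sharding import AutoShardingOption

# Restore the old default by disabling the reduce-scatter preference explicitly.
method = PipeshardParallel(
    num_micro_batches=16,  # example value
    default_auto_sharding_option=AutoShardingOption(prefer_reduce_scatter=False))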

alpa/parallel_plan.py

Lines changed: 31 additions & 8 deletions
@@ -8,7 +8,6 @@
 import numpy as np
 from jax.core import ShapedArray
 from jax.interpreters import pxla
-from jax.tree_util import PyTreeDef


 @dataclass

@@ -26,23 +25,47 @@ class StagePlan:
     logical_mesh_shape: Tuple[int]
     all_gather_threshold: int
     all_reduce_threshold: int
+    auto_sharding_option: "AutoShardingOption"
     auto_sharding_solution_vector: np.ndarray
     auto_sharding_objective: int


 @dataclass
 class PipelinePlan:
     """The parallel plan for a pipeline."""
-    forward_stage_layer_ids: Sequence[Sequence[int]]
-    submesh_physical_shapes: Sequence[Sequence[int]]
-    submesh_logical_shapes: Sequence[Sequence[int]]
-    submesh_autosharding_option_dicts: Sequence[dict]
+    pipeline_schedule: str
+    layer_option: "LayerOption"
+    manual_stage_option: "ManualStageOption"
+
+
+@dataclass
+class ClusterInfo:
+    num_hosts: int
+    num_devices_per_host: int


 @dataclass
 class ParallelPlan:
     """The global parallel plan."""
+    cluster_info: ClusterInfo
+    num_micro_batches: int
+    auto_sharding_option: "AutoShardingOption"
     pipeline_plan: PipelinePlan
-    stage_plans: Sequence[StagePlan]
-    input_placement: PyTreeDef
-    version: str
+    input_placement_specs: Sequence[PlacementSpec]
+
+
+def plan_to_method(plan: ParallelPlan) -> "ParallelMethod":
+    """Convert a parallel plan to a parallel method."""
+    # pylint: disable=import-outside-toplevel
+    from alpa.parallel_method import ShardParallel, PipeshardParallel
+
+    if plan.pipeline_plan is None:
+        return ShardParallel(num_micro_batches=plan.num_micro_batches,
+                             auto_sharding_option=plan.auto_sharding_option)
+    else:
+        return PipeshardParallel(
+            num_micro_batches=plan.num_micro_batches,
+            default_auto_sharding_option=plan.auto_sharding_option,
+            pipeline_schedule=plan.pipeline_plan.pipeline_schedule,
+            layer_option=plan.pipeline_plan.layer_option,
+            stage_option=plan.pipeline_plan.manual_stage_option)
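`plan_to_method` closes the loop: a plan captured from a compiled executable can be turned back into a `ShardParallel` or `PipeshardParallel` method, depending on whether a pipeline plan is present. A sketch of the intended save/restore flow; `parallel_train_step`, `train_step_fn`, `state`, and `batch` are placeholder names, and `get_executable`/`parallelize(method=...)` follow Alpa's usual API:

import pickle

from alpa import parallelize, plan_to_method

# First run: capture the plan the compiler chose for the parallelized step.
executable = parallel_train_step.get_executable(state, batch)
with open("plan.pkl", "wb") as f:
    pickle.dump(executable.get_parallel_plan(), f)

# Later run: reuse the serialized plan instead of searching again.
with open("plan.pkl", "rb") as f:
    method = plan_to_method(pickle.load(f))
reused_train_step = parallelize(train_step_fn, method=method)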

alpa/pipeline_parallel/compile_executable.py

Lines changed: 6 additions & 3 deletions
@@ -82,6 +82,7 @@ def compile_pipeshard_executable(
         mesh_group=virtual_mesh.launched_physical_mesh_group,
         pipeshard_config=pipeshard_config,
         num_batch=num_microbatch,
+        layer_option=layer_option,
         in_tree=in_tree,
         out_tree=out_tree_thunk(),
         static_argnums=static_argnums)

@@ -146,8 +147,7 @@ def compile_pipeshard_executable_internal(

     # Construct pipeline stages by merging layers
     (jax_pipeline_stages, stage_to_mesh, sliced_virtual_meshes,
-     logical_mesh_shapes,
-     autosharding_option_dicts) = cluster_layers_and_slice_mesh(
+     manual_stage_option) = cluster_layers_and_slice_mesh(
         jax_pipeline_layers, virtual_mesh, donation_mapping, acc_grad_outvars,
         num_microbatch, micro_batch_size, jax_apply_layers,
         apply_grad_global_info, pipeline_schedule, default_as_option,

@@ -202,7 +202,8 @@
     xla_stages, total_flops = shard_each_stage(
         jax_all_stages, sliced_virtual_meshes, schedule, n_stages, num_meshes,
         grad_in_to_out, global_invars, acc_grad_outvars, donate_invars_dict,
-        num_microbatch, logical_mesh_shapes, autosharding_option_dicts,
+        num_microbatch, manual_stage_option.submesh_logical_shapes,
+        manual_stage_option.submesh_autosharding_option_dicts,
         default_as_option, output_sharding_dict, name_base, gensym_func)
     total_flops *= num_microbatch
     debug_compilation_time("shard stages")

@@ -224,6 +225,8 @@
         schedule=schedule,
         is_batch=batch_invars,
         num_batch=num_microbatch,
+        default_auto_sharding_option=default_as_option,
+        manual_stage_option=manual_stage_option,
         flop_count=total_flops).compile()

     debug_compilation_time("runtime emitter")

alpa/pipeline_parallel/computation.py

Lines changed: 4 additions & 2 deletions
@@ -24,7 +24,8 @@
 from alpa.shard_parallel.auto_sharding import (run_auto_sharding_pass,
                                                run_spmd_partitioner_pass,
                                                get_input_output_sharding_specs,
-                                               hlo_sharding_to_sharding_spec)
+                                               hlo_sharding_to_sharding_spec,
+                                               AutoShardingOption)
 from alpa.global_env import global_config
 from alpa.util import (OrderedSet, clone_jaxpr, get_compile_options,
                        jaxpr_to_hlo_module, setup_computation_alias,

@@ -214,7 +215,8 @@ def dummy_computation(cls, name, logical_mesh_shape, gensym_func):
         backend_name = "gpu"
         backend = xb.get_backend(backend_name)
         stage_plan = StagePlan(global_config.compile_random_seed,
-                               logical_mesh_shape, 1, 1, None, 0)
+                               logical_mesh_shape, 1, 1, AutoShardingOption(),
+                               None, 0)
         compiled = compile_dummy_zero_constant(backend,
                                                np.prod(logical_mesh_shape))
         sharding_annotated_module = compiled.hlo_modules()[0]
