Skip to content
This repository was archived by the owner on Oct 19, 2024. It is now read-only.

Commit 7b2a023

Browse files
authored
[Fix] Improve compilation speed by using Set instead of List for query (#567)
1 parent 14b3153 commit 7b2a023

File tree

1 file changed

+23
-24
lines changed

1 file changed

+23
-24
lines changed

alpa/pipeline_parallel/runtime_emitter.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import dataclass
44
import enum
55
import logging
6-
from typing import Any, Callable, Dict, Optional, Sequence, Union
6+
from typing import Any, Callable, Dict, Optional, Sequence, Union, Set
77

88
from jax._src.tree_util import PyTreeDef, tree_unflatten
99
from jax.core import Var
@@ -160,19 +160,20 @@ def flatten_uuid_set(container):
160160
class PipelineInstEmitterHelper:
161161
"""Environment for PipelineInstEmitter."""
162162

163-
def __init__(self, global_invars, grad_dummy_invars, is_batch,
164-
schedule: PipelineSchedule):
165-
self.global_invars = global_invars
166-
self.global_batch_invars = OrderedSet(
167-
v for v, b in zip(global_invars, is_batch) if b)
163+
def __init__(self, global_invar_set: Set[Var],
164+
global_batch_invar_set: Set[Var],
165+
grad_dummy_invars: Dict[Var, Var], schedule: PipelineSchedule):
166+
self.global_invar_set = global_invar_set
167+
self.global_batch_invar_set = global_batch_invar_set
168168
self.grad_dummy_invars = grad_dummy_invars
169169
self.schedule = schedule
170170
# Dict[var_key -> Dict[mesh_idx -> array_uuid]]
171171
# The shape of the numpy array is [num_hosts, num_devices_per_host]
172172
self.env = {}
173173

174174
def _get_var_key(self, var, batch_idx):
175-
if var in self.global_invars and var not in self.global_batch_invars:
175+
if (var in self.global_invar_set and
176+
var not in self.global_batch_invar_set):
176177
key = (var, 0)
177178
elif (var in self.grad_dummy_invars and
178179
batch_idx != self.schedule.first_backward_batch_index):
@@ -283,8 +284,12 @@ def __init__(self, *, stages: Sequence[XlaShardedPipelineComputation],
283284

284285
##### Internal states #####
285286
self.uuid_counter = 0 # counter for local buffer uuid
286-
self.env = PipelineInstEmitterHelper(global_invars, grad_dummy_invars,
287-
is_batch, schedule)
287+
global_invar_set = OrderedSet(global_invars)
288+
global_batch_invar_set = OrderedSet(
289+
v for v, b in zip(global_invars, is_batch) if b)
290+
self.env = PipelineInstEmitterHelper(global_invar_set,
291+
global_batch_invar_set,
292+
grad_dummy_invars, schedule)
288293
self._communicator = None
289294
self._resharding_tasks = [
290295
[{} for _ in range(self.num_mesh)] for _ in range(self.num_mesh)
@@ -390,12 +395,8 @@ def compile(self):
390395
executable_config_lists)
391396

392397
# Split input into micro batches
393-
global_batch_invar_set = OrderedSet([
394-
var for var, batch in zip(self.global_invars, self.is_batch)
395-
if batch
396-
])
397-
(input_config, input_shard_specs
398-
) = self._compile_split_input_to_microbatches(global_batch_invar_set)
398+
(input_config,
399+
input_shard_specs) = self._compile_split_input_to_microbatches()
399400

400401
# Simulate the pipeline schedule and generate instructions
401402
donation_mapping = [DisjointDict() for _ in range(num_mesh)]
@@ -618,7 +619,7 @@ def _compile_grad_buffer_allocations(self, executable_config_lists):
618619

619620
return grad_uuids, instruction_lists
620621

621-
def _compile_collect_mesh_input(self, mesh_idx, batch_vars):
622+
def _compile_collect_mesh_input(self, mesh_idx):
622623
mesh_arg_set = OrderedSet()
623624
var_to_spec = {}
624625
mesh_batch_vars = OrderedSet()
@@ -630,9 +631,9 @@ def _compile_collect_mesh_input(self, mesh_idx, batch_vars):
630631
for stage_idx in self.schedule.mesh_stage_mapping[mesh_idx]:
631632
stage = self.stages[stage_idx]
632633
for spec, invar in zip(stage.input_sharding_specs, stage.invars):
633-
if invar in self.global_invars:
634+
if invar in self.env.global_invar_set:
634635
var_to_spec[invar] = spec
635-
if invar in batch_vars:
636+
if invar in self.env.global_batch_invar_set:
636637
# Split batch arg
637638
for batch_idx in range(num_batch):
638639
mesh_arg_set.add((invar, batch_idx))
@@ -666,7 +667,7 @@ def _compile_collect_mesh_input(self, mesh_idx, batch_vars):
666667
return (mesh_arg_list, mesh_arg_indices, input_shard_indices,
667668
input_shard_specs, mesh_invar_is_batch)
668669

669-
def _compile_split_input_to_microbatches(self, global_batch_invar_set):
670+
def _compile_split_input_to_microbatches(self):
670671
"""
671672
Split batch arguments into micro batches.
672673
@@ -675,10 +676,9 @@ def _compile_split_input_to_microbatches(self, global_batch_invar_set):
675676
after (b, d are batch args and #mb=2): a, b0, b1, c, d0, d1
676677
"""
677678
donated_invar_set = OrderedSet()
678-
global_invar_set = OrderedSet(self.global_invars)
679679
for stage in self.stages:
680680
for invar, donate in zip(stage.invars, stage.donated_invars):
681-
if donate and invar in global_invar_set:
681+
if donate and invar in self.env.global_invar_set:
682682
donated_invar_set.add(invar)
683683
num_mesh = len(self.mesh_group)
684684
mesh_arg_lists = [None for _ in range(num_mesh)]
@@ -692,13 +692,12 @@ def _compile_split_input_to_microbatches(self, global_batch_invar_set):
692692
batch_invars = []
693693
for mesh_idx in range(num_mesh):
694694
(mesh_arg_list, arg_indices, shard_indices, shard_specs,
695-
is_batch) = self._compile_collect_mesh_input(
696-
mesh_idx, global_batch_invar_set)
695+
is_batch) = self._compile_collect_mesh_input(mesh_idx)
697696

698697
mesh_arg_lists[mesh_idx] = mesh_arg_list
699698
delete_after_run = [
700699
var in donated_invar_set or
701-
(var in global_batch_invar_set and
700+
(var in self.env.global_batch_invar_set and
702701
global_config.always_donate_micro_batch_vars)
703702
for var, _ in mesh_arg_list
704703
]

0 commit comments

Comments (0)