 from dataclasses import dataclass
 import enum
 import logging
-from typing import Any, Callable, Dict, Optional, Sequence, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Union, Set

 from jax._src.tree_util import PyTreeDef, tree_unflatten
 from jax.core import Var
@@ -160,19 +160,20 @@ def flatten_uuid_set(container):
 class PipelineInstEmitterHelper:
     """Environment for PipelineInstEmitter."""

-    def __init__(self, global_invars, grad_dummy_invars, is_batch,
-                 schedule: PipelineSchedule):
-        self.global_invars = global_invars
-        self.global_batch_invars = OrderedSet(
-            v for v, b in zip(global_invars, is_batch) if b)
+    def __init__(self, global_invar_set: Set[Var],
+                 global_batch_invar_set: Set[Var],
+                 grad_dummy_invars: Dict[Var, Var], schedule: PipelineSchedule):
+        self.global_invar_set = global_invar_set
+        self.global_batch_invar_set = global_batch_invar_set
         self.grad_dummy_invars = grad_dummy_invars
         self.schedule = schedule
         # Dict[var_key -> Dict[mesh_idx -> array_uuid]]
         # The shape of the numpy array is [num_hosts, num_devices_per_host]
         self.env = {}

     def _get_var_key(self, var, batch_idx):
-        if var in self.global_invars and var not in self.global_batch_invars:
+        if (var in self.global_invar_set and
+                var not in self.global_batch_invar_set):
             key = (var, 0)
         elif (var in self.grad_dummy_invars and
               batch_idx != self.schedule.first_backward_batch_index):
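
Note: below is a minimal standalone sketch of the key dispatch in _get_var_key, using toy strings instead of jax.core.Var and omitting the grad-dummy branch. A global invar that is not a batch invar collapses to batch index 0, so every microbatch reads the same env entry, while batch invars keep a per-microbatch key. Names here are hypothetical, not alpa's real classes.

    def get_var_key(var, batch_idx, global_invar_set, global_batch_invar_set):
        # Non-batch global inputs are replicated across microbatches,
        # so one shared key suffices.
        if var in global_invar_set and var not in global_batch_invar_set:
            return (var, 0)
        # Batch inputs (and other vars) are tracked per microbatch.
        return (var, batch_idx)

    assert get_var_key("w", 3, {"w", "x"}, {"x"}) == ("w", 0)
    assert get_var_key("x", 3, {"w", "x"}, {"x"}) == ("x", 3)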
@@ -283,8 +284,12 @@ def __init__(self, *, stages: Sequence[XlaShardedPipelineComputation],

         ##### Internal states #####
         self.uuid_counter = 0  # counter for local buffer uuid
-        self.env = PipelineInstEmitterHelper(global_invars, grad_dummy_invars,
-                                             is_batch, schedule)
+        global_invar_set = OrderedSet(global_invars)
+        global_batch_invar_set = OrderedSet(
+            v for v, b in zip(global_invars, is_batch) if b)
+        self.env = PipelineInstEmitterHelper(global_invar_set,
+                                             global_batch_invar_set,
+                                             grad_dummy_invars, schedule)
         self._communicator = None
         self._resharding_tasks = [
             [{} for _ in range(self.num_mesh)] for _ in range(self.num_mesh)
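
Note: a hedged illustration of the set construction above, using dict.fromkeys as a stand-in for OrderedSet (both give ordered iteration plus O(1) membership, which is what the emitter relies on); the values are toy strings, not real jax Vars.

    global_invars = ["x", "w", "b"]   # hypothetical inputs
    is_batch = [True, False, False]   # only "x" is a batch argument
    global_invar_set = dict.fromkeys(global_invars)
    global_batch_invar_set = dict.fromkeys(
        v for v, b in zip(global_invars, is_batch) if b)
    assert list(global_batch_invar_set) == ["x"]
    assert "w" in global_invar_set and "w" not in global_batch_invar_set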
@@ -390,12 +395,8 @@ def compile(self):
            executable_config_lists)

         # Split input into micro batches
-        global_batch_invar_set = OrderedSet([
-            var for var, batch in zip(self.global_invars, self.is_batch)
-            if batch
-        ])
-        (input_config, input_shard_specs
-        ) = self._compile_split_input_to_microbatches(global_batch_invar_set)
+        (input_config,
+         input_shard_specs) = self._compile_split_input_to_microbatches()

         # Simulate the pipeline schedule and generate instructions
         donation_mapping = [DisjointDict() for _ in range(num_mesh)]
@@ -618,7 +619,7 @@ def _compile_grad_buffer_allocations(self, executable_config_lists):

         return grad_uuids, instruction_lists

-    def _compile_collect_mesh_input(self, mesh_idx, batch_vars):
+    def _compile_collect_mesh_input(self, mesh_idx):
         mesh_arg_set = OrderedSet()
         var_to_spec = {}
         mesh_batch_vars = OrderedSet()
@@ -630,9 +631,9 @@ def _compile_collect_mesh_input(self, mesh_idx, batch_vars):
         for stage_idx in self.schedule.mesh_stage_mapping[mesh_idx]:
             stage = self.stages[stage_idx]
             for spec, invar in zip(stage.input_sharding_specs, stage.invars):
-                if invar in self.global_invars:
+                if invar in self.env.global_invar_set:
                     var_to_spec[invar] = spec
-                    if invar in batch_vars:
+                    if invar in self.env.global_batch_invar_set:
                         # Split batch arg
                         for batch_idx in range(num_batch):
                             mesh_arg_set.add((invar, batch_idx))
@@ -666,7 +667,7 @@ def _compile_collect_mesh_input(self, mesh_idx, batch_vars):
         return (mesh_arg_list, mesh_arg_indices, input_shard_indices,
                 input_shard_specs, mesh_invar_is_batch)

-    def _compile_split_input_to_microbatches(self, global_batch_invar_set):
+    def _compile_split_input_to_microbatches(self):
         """
         Split batch arguments into micro batches.

@@ -675,10 +676,9 @@ def _compile_split_input_to_microbatches(self, global_batch_invar_set):
         after (b, d are batch args and #mb=2): a, b0, b1, c, d0, d1
         """
         donated_invar_set = OrderedSet()
-        global_invar_set = OrderedSet(self.global_invars)
         for stage in self.stages:
             for invar, donate in zip(stage.invars, stage.donated_invars):
-                if donate and invar in global_invar_set:
+                if donate and invar in self.env.global_invar_set:
                     donated_invar_set.add(invar)
         num_mesh = len(self.mesh_group)
         mesh_arg_lists = [None for _ in range(num_mesh)]
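
Note: a small self-contained sketch (toy strings, hypothetical helper name) of the argument-splitting order stated in the docstring above: batch args expand in place into #mb microbatch slices while non-batch args pass through unchanged.

    def split_args(args, batch_args, num_microbatches=2):
        out = []
        for a in args:
            if a in batch_args:
                # Expand a batch arg into its microbatch slices, in place.
                out.extend(f"{a}{i}" for i in range(num_microbatches))
            else:
                out.append(a)
        return out

    assert split_args(["a", "b", "c", "d"], {"b", "d"}) == [
        "a", "b0", "b1", "c", "d0", "d1"]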
@@ -692,13 +692,12 @@ def _compile_split_input_to_microbatches(self, global_batch_invar_set):
         batch_invars = []
         for mesh_idx in range(num_mesh):
             (mesh_arg_list, arg_indices, shard_indices, shard_specs,
-             is_batch) = self._compile_collect_mesh_input(
-                 mesh_idx, global_batch_invar_set)
+             is_batch) = self._compile_collect_mesh_input(mesh_idx)

             mesh_arg_lists[mesh_idx] = mesh_arg_list
             delete_after_run = [
                 var in donated_invar_set or
-                (var in global_batch_invar_set and
+                (var in self.env.global_batch_invar_set and
                  global_config.always_donate_micro_batch_vars)
                 for var, _ in mesh_arg_list
             ]
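
Note: a hedged restatement of the delete_after_run predicate in the last hunk as a standalone function with toy inputs: a mesh input may be freed after the run when its stage donated it, or when it is a microbatch slice and the config always donates microbatch vars.

    def can_delete_after_run(var, donated_invar_set, global_batch_invar_set,
                             always_donate_micro_batch_vars):
        return (var in donated_invar_set or
                (var in global_batch_invar_set and
                 always_donate_micro_batch_vars))

    assert can_delete_after_run("w", {"w"}, set(), False)
    assert can_delete_after_run("x", set(), {"x"}, True)
    assert not can_delete_after_run("b", set(), set(), True)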