import hashlib
import inspect
import random
+import re
+from collections import defaultdict  # used by _group_sequential_layers below


class AssignmentError(Exception):
@@ -17,6 +18,118 @@ class AssignmentError(Exception):
    pass


+def _create_grouped_entry(parent_path: str, group: list) -> dict:
+    """
+    Create a single config entry for a group of consecutive layers.
+    """
+    if len(group) == 1:
+        # Single layer, return as-is
+        _, path, cfg = group[0]
+        return {path: cfg}
+
+    # Multiple layers - create grouped entry
+    layer_indices = [idx for idx, _, _ in group]
+    paths = [path for _, path, _ in group]
+    configs = [cfg for _, _, cfg in group]
+
+    start_idx = min(layer_indices)
+    end_idx = max(layer_indices)
+
+    # Use range notation in the key
+    grouped_path = f"{parent_path}{start_idx}-{end_idx}"
+
+    # Merge configurations
+    total_memory = sum(cfg.get("memory", 0) for cfg in configs)
+    worker = configs[0]["assigned_workers"][0]
+
+    grouped_config = {
+        "type": "offloaded_group",
+        "name": configs[0].get("name", ""),
+        "assigned_workers": [worker],
+        "layer_range": (start_idx, end_idx),
+        "layer_paths": paths,
+        "memory": total_memory,
+        "module": configs[0].get("module", ""),
+        "training": configs[0].get("training", False),
+        "optimizer_type": configs[0].get("optimizer_type", "adam"),
+        "num_layers": len(group),
+    }
+
+    # Preserve parent_forward_code if present
+    if "parent_forward_code" in configs[0]:
+        grouped_config["parent_forward_code"] = configs[0]["parent_forward_code"]
+        grouped_config["parent_module_path"] = configs[0]["parent_module_path"]
+
+    return {grouped_path: grouped_config}
+
+
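To make the merge concrete, here is a sketch of what `_create_grouped_entry` produces for two consecutive layers on one worker. The paths, worker name, and memory figures are invented for illustration:

```python
group = [
    (0, "model.layers.0", {"type": "offloaded", "name": "layers.0",
                           "assigned_workers": ["worker1"], "memory": 512}),
    (1, "model.layers.1", {"type": "offloaded", "name": "layers.1",
                           "assigned_workers": ["worker1"], "memory": 512}),
]

entry = _create_grouped_entry("model.layers.", group)
# {"model.layers.0-1": {"type": "offloaded_group", "assigned_workers": ["worker1"],
#                       "layer_range": (0, 1), "memory": 1024, "num_layers": 2, ...}}
```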
+def _group_sequential_layers(config: dict) -> dict:
+    """
+    Group consecutive layers assigned to the same worker into single entries.
+
+    For example:
+        model.layers.0 -> worker1
+        model.layers.1 -> worker1
+        model.layers.2 -> worker1
+
+    Becomes:
+        model.layers.0-2 -> worker1
+    """
+    # Group paths by their parent and extract layer patterns
+    layer_groups = defaultdict(list)
+
+    for path, cfg in config.items():
+        if cfg.get("type") != "offloaded":
+            continue
+
+        # Match patterns like "model.layers.0", "model.encoder.layer.5", etc.
+        match = re.match(r'^(.+\.)(\d+)$', path)
+        if match:
+            parent_path = match.group(1)  # e.g., "model.layers."
+            layer_idx = int(match.group(2))
+            layer_groups[parent_path].append((layer_idx, path, cfg))
+
+    # Create new grouped config
+    new_config = {}
+    processed_paths = set()
+
+    for parent_path, layers in layer_groups.items():
+        # Sort by layer index
+        layers.sort(key=lambda x: x[0])
+
+        # Group strictly consecutive layers with the same worker
+        current_group = []
+        current_worker = None
+
+        for layer_idx, path, cfg in layers:
+            worker = cfg["assigned_workers"][0] if cfg["assigned_workers"] else None
+
+            # Extend only when the worker matches AND the index is contiguous;
+            # otherwise a gap (e.g. a non-offloaded layer between two offloaded
+            # ones) would be silently swallowed into the range.
+            if (
+                current_group
+                and worker == current_worker
+                and layer_idx == current_group[-1][0] + 1
+            ):
+                current_group.append((layer_idx, path, cfg))
+            else:
+                # Save previous group if exists
+                if current_group:
+                    new_config.update(_create_grouped_entry(parent_path, current_group))
+                    processed_paths.update(p for _, p, _ in current_group)
+
+                # Start new group
+                current_group = [(layer_idx, path, cfg)]
+                current_worker = worker
+
+        # Don't forget the last group
+        if current_group:
+            new_config.update(_create_grouped_entry(parent_path, current_group))
+            processed_paths.update(p for _, p, _ in current_group)
+
+    # Add all non-layer modules that weren't grouped
+    for path, cfg in config.items():
+        if path not in processed_paths:
+            new_config[path] = cfg
+
+    return new_config
+
+
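End to end, the pass leaves non-layer entries alone and only merges runs of consecutive, same-worker layers; a worker change (or a gap in the indices) starts a new group. A minimal sketch with invented paths and worker names:

```python
config = {
    "model.embed_tokens": {"type": "local"},
    "model.layers.0": {"type": "offloaded", "assigned_workers": ["worker1"], "memory": 1},
    "model.layers.1": {"type": "offloaded", "assigned_workers": ["worker1"], "memory": 1},
    "model.layers.2": {"type": "offloaded", "assigned_workers": ["worker2"], "memory": 1},
}

grouped = _group_sequential_layers(config)
# Keys: "model.embed_tokens" (untouched), "model.layers.0-1" (merged),
# and "model.layers.2" (single layer, kept as-is).
```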
class ModelParser:
    def __init__(self, user_memory: int = 0):
        self.user_memory = user_memory
@@ -76,6 +189,8 @@ def create_distributed_config(
                optimizer_type=optimizer_type,
            )

+            config = _group_sequential_layers(config)
+
        except AssignmentError:
            success = False

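For orientation, the new call sits inside the existing `try` block: grouping runs only after every module has been assigned a worker, so an `AssignmentError` still short-circuits before any merging happens. A simplified sketch of the surrounding flow (the code that builds the initial `config` is outside this hunk and is assumed here):

```python
try:
    config = ...  # per-module worker assignment (not shown in this diff)
    config = _group_sequential_layers(config)  # merge consecutive same-worker layers
except AssignmentError:
    success = False
```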
@@ -101,7 +216,7 @@ def _recurse_module(
        ids = []

        memory, breakdown = estimate_memory(
-            module, training, batch_size=1024, optimizer_type=optimizer_type
+            module, training, seq_length=1024, optimizer_type=optimizer_type
        )

        assigned_worker = self._try_assign_worker(
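The `batch_size=1024` to `seq_length=1024` rename matters because, for transformer layers, activation memory grows with the number of tokens per sequence, not the batch size alone. `estimate_memory`'s internals aren't shown in this diff; the sketch below is only a rough illustration of the kind of per-layer activation estimate such a function might make, with all sizes invented:

```python
def rough_activation_bytes(hidden_size: int, seq_length: int,
                           batch_size: int = 1, bytes_per_elem: int = 4) -> int:
    # One layer's output activations: batch x seq x hidden elements, fp32 by default.
    return batch_size * seq_length * hidden_size * bytes_per_elem

rough_activation_bytes(hidden_size=4096, seq_length=1024)  # 16 MiB per layer output
```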
@@ -241,3 +356,35 @@ def _extract_forward_code(self, module: nn.Module):
241356 f"Could not extract forward code for { module_class .__name__ } : { e } "
242357 )
243358 return None
359+
360+
+class ModelSegmentAnalyzer:
+    """
+    Analyzes the forward method of a model to identify three key segments:
+
+    1. Pre-offload: the chunk of the model executed on the user's device
+       before the first offloaded layer.
+    2. Offload: the chunk of consecutive layers executed on remote workers.
+    3. Post-offload: the chunk executed on the user's device after the
+       offloaded layers return.
+
+    Example workflow:
+
+        def forward(self, x):
+            x = self.layer1(x)
+            for i in range(len(self.layerlist)):
+                x = self.layerlist[i](x)  # if i > 2, worker2 is used instead
+
+    is split into:
+
+        worker1:
+            x = self.layer1(x)
+            for i in range(3):
+                x = self.layerlist[i](x)
+
+        worker2:
+            for i in range(3, 5):
+                x = self.layerlist[i](x)
+    """