
Commit 2f1d2a4 (1 parent: aff47e0)

Fixed module loop segments and forward pass for partially distributed models. Added relevant layer and module info for worker job requests.

File tree: 7 files changed (+608, -241 lines)

tensorlink/ml/graphing.py

Lines changed: 148 additions & 1 deletion
@@ -9,6 +9,7 @@
 import hashlib
 import inspect
 import random
+import re
 
 
 class AssignmentError(Exception):
@@ -17,6 +18,118 @@ class AssignmentError(Exception):
     pass
 
 
+def _create_grouped_entry(parent_path: str, group: list) -> dict:
+    """
+    Create a single config entry for a group of consecutive layers.
+    """
+    if len(group) == 1:
+        # Single layer, return as-is
+        _, path, cfg = group[0]
+        return {path: cfg}
+
+    # Multiple layers - create grouped entry
+    layer_indices = [idx for idx, _, _ in group]
+    paths = [path for _, path, _ in group]
+    configs = [cfg for _, _, cfg in group]
+
+    start_idx = min(layer_indices)
+    end_idx = max(layer_indices)
+
+    # Use range notation in the key
+    grouped_path = f"{parent_path}{start_idx}-{end_idx}"
+
+    # Merge configurations
+    total_memory = sum(cfg.get("memory", 0) for cfg in configs)
+    worker = configs[0]["assigned_workers"][0]
+
+    grouped_config = {
+        "type": "offloaded_group",
+        "name": configs[0].get("name", ""),
+        "assigned_workers": [worker],
+        "layer_range": (start_idx, end_idx),
+        "layer_paths": paths,
+        "memory": total_memory,
+        "module": configs[0].get("module", ""),
+        "training": configs[0].get("training", False),
+        "optimizer_type": configs[0].get("optimizer_type", "adam"),
+        "num_layers": len(group),
+    }
+
+    # Preserve parent_forward_code if present
+    if "parent_forward_code" in configs[0]:
+        grouped_config["parent_forward_code"] = configs[0]["parent_forward_code"]
+        grouped_config["parent_module_path"] = configs[0]["parent_module_path"]
+
+    return {grouped_path: grouped_config}
+
+
+def _group_sequential_layers(config: dict) -> dict:
+    """
+    Group consecutive layers assigned to the same worker into single entries.
+
+    For example:
+        model.layers.0 -> worker1
+        model.layers.1 -> worker1
+        model.layers.2 -> worker1
+
+    Becomes:
+        model.layers.0-2 -> worker1
+    """
+    # Group paths by their parent and extract layer patterns
+    layer_groups = defaultdict(list)
+
+    for path, cfg in config.items():
+        if cfg.get("type") != "offloaded":
+            continue
+
+        # Match patterns like "model.layers.0", "model.encoder.layer.5", etc.
+        match = re.match(r'^(.+\.)(\d+)$', path)
+        if match:
+            parent_path = match.group(1)  # e.g., "model.layers."
+            layer_idx = int(match.group(2))
+            layer_groups[parent_path].append((layer_idx, path, cfg))
+
+    # Create new grouped config
+    new_config = {}
+    processed_paths = set()
+
+    for parent_path, layers in layer_groups.items():
+        # Sort by layer index
+        layers.sort(key=lambda x: x[0])
+
+        # Group consecutive layers with same worker
+        current_group = []
+        current_worker = None
+
+        for layer_idx, path, cfg in layers:
+            worker = cfg["assigned_workers"][0] if cfg["assigned_workers"] else None
+
+            if worker == current_worker and current_group:
+                # Extend current group
+                current_group.append((layer_idx, path, cfg))
+            else:
+                # Save previous group if it exists
+                if current_group:
+                    new_config.update(_create_grouped_entry(parent_path, current_group))
+                    processed_paths.update(p for _, p, _ in current_group)
+
+                # Start new group
+                current_group = [(layer_idx, path, cfg)]
+                current_worker = worker
+
+        # Don't forget the last group
+        if current_group:
+            new_config.update(_create_grouped_entry(parent_path, current_group))
+            processed_paths.update(p for _, p, _ in current_group)
+
+    # Add all non-layer modules that weren't grouped
+    for path, cfg in config.items():
+        if path not in processed_paths:
+            new_config[path] = cfg
+
+    return new_config
+
+
 class ModelParser:
     def __init__(self, user_memory: int = 0):
         self.user_memory = user_memory
@@ -76,6 +189,8 @@ def create_distributed_config(
                 optimizer_type=optimizer_type,
             )
 
+            config = _group_sequential_layers(config)
+
         except AssignmentError:
             success = False
 
@@ -101,7 +216,7 @@ def _recurse_module(
         ids = []
 
         memory, breakdown = estimate_memory(
-            module, training, batch_size=1024, optimizer_type=optimizer_type
+            module, training, seq_length=1024, optimizer_type=optimizer_type
         )
 
         assigned_worker = self._try_assign_worker(
@@ -241,3 +356,35 @@ def _extract_forward_code(self, module: nn.Module):
                 f"Could not extract forward code for {module_class.__name__}: {e}"
             )
             return None
+
+
+class ModelSegmentAnalyzer:
+    """
+    Analyzes the forward method of a model to identify three key segments:
+    1. Pre-offload: Model chunk executed on
+    """
+
+    """
+    Example workflow
+
+
+    def forward(self, x):
+        x = self.layer1(x)
+
+        for i in range(len(self.layerlist)):
+            x = self.layerlist[i](x)  # if i > 2, worker 2 is used instead
+
+
+    worker1:
+        x = self.layer1(x)
+        for i in range(3):
+            x = self.layerlist[i](x)
+
+
+    worker2:
+
+        for i in range(3, 5):
+            x = self.layerlist[i](x)
+
+    """

tensorlink/ml/module.py

Lines changed: 9 additions & 12 deletions
@@ -543,7 +543,8 @@ def distribute_model(self, config=None):
             self._load_model_skeleton()
             self._wrap_hf_model(config)
         else:
-            self.wrap_module(config)
+            raise NotImplementedError("Custom models are currently not supported.")
+            # self.wrap_module(config)
 
         if len(config) == 1:
             module, module_name = access_module(self.model, [-1])
@@ -672,16 +673,16 @@ def wrap_module(self, module_id: list, worker_id):
 
     def _wrap_hf_model(self, config: dict):
         # Iterate through each worker and their assigned modules
-        for module_id, worker_modules in config.items():
-            worker_id = next(iter(worker_modules.values())).get("assigned_workers")[0]
+        for module_id, module_info in config.items():
+            worker_id = module_info["assigned_workers"][0]
             file_name = f"{module_id}_{worker_id}.pt"
-            module_info = str(PreTrainedModel)
-            offloaded_module = OffloadedModule(self, module_info, worker_id, module_id)
+            module_name = module_info["module"]
+            offloaded_module = OffloadedModule(self, module_name, worker_id, module_id)
             with open(file_name, "wb") as f:
                 f.close()
 
             # Spawn a worker thread for the offloaded module
-            offloaded_module.spawn_worker(file_name)
+            offloaded_module.spawn_worker(file_name, module_info)
             setattr(self, "model", offloaded_module)
 
     def send_request(self, request_type, args):
@@ -793,10 +794,6 @@ def __init__(
 
         self.entire_model = False
         self.module_name = module_name.split("(")[0]
-        try:
-            self.module_info = module_name.split("(")[1][:-1]
-        except:
-            self.module_info = self.module_name
 
         self.parent_model = parent_model
         self.worker_id = worker_id
@@ -809,7 +806,7 @@ def children(self):
         # Return an empty iterator to hide deeper children
         return iter([])
 
-    def spawn_worker(self, name):
+    def spawn_worker(self, name: str, module_info: dict):
         # # Initialize a threading Timer to monitor the loading process
         # timer = threading.Timer(MAX_WAIT_TIME, self.handle_timeout)
         # timer.start()
@@ -818,7 +815,7 @@ def spawn_worker(self, name):
         # Send the module to the worker roles
 
         self.parent_model.send_request(
-            "send_model", (name, self.worker_id, self.module_id)
+            "send_model", (name, self.worker_id, self.module_id, module_info)
         )
 
         # Wait for the module to be loaded on worker
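
The upshot of these changes: _wrap_hf_model now receives a flat {module_id: module_info} config, and the "send_model" request carries the module's config entry alongside the serialized file. A rough sketch of the new request payload; all values below are placeholders, not real IDs or layer names:

# Hypothetical stand-ins for a real module/worker pairing.
module_id = "9f2c"  # placeholder module hash
worker_id = "worker1"
file_name = f"{module_id}_{worker_id}.pt"

module_info = {                     # one entry of the distributed config
    "module": "LlamaDecoderLayer",  # hypothetical module name
    "assigned_workers": [worker_id],
    "type": "offloaded_group",
    "layer_range": (0, 2),
}

# spawn_worker() now forwards module_info, turning the old 3-tuple into a 4-tuple:
request = ("send_model", (file_name, worker_id, module_id, module_info))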

tensorlink/ml/validator.py

Lines changed: 11 additions & 11 deletions
@@ -284,7 +284,6 @@ def _manage_auto_loaded_models(self):
                     "debug_print",
                     (f"Auto-loading model: {model_name}", "green", logging.INFO),
                 )
-                self.models_initializing.add(model_name)
                 self._initialize_hosted_job(model_name)
 
         # Try to finalize models that are initializing
@@ -313,7 +312,7 @@ def _manage_auto_loaded_models(self):
                 )
                 self._remove_hosted_job(model_name)
 
-    def inspect_model(self, model_name: str, job_data: dict):
+    def inspect_model(self, model_name: str, job_data: dict) -> dict:
         """Inspect a model to determine network requirements and store distribution in JSON cache"""
         parser = ModelParser()
         model_name: str = job_data.get("model_name", model_name)
@@ -350,7 +349,9 @@ def inspect_model(self, model_name: str, job_data: dict):
 
         # Send out job request
         try:
-            self.send_request("send_job_request", job_data)
+            new_job_data = self.send_request("send_job_request", job_data)
+            return new_job_data
+
         except Exception as e:
             print(str(e))
 
@@ -541,12 +542,12 @@ def _handle_generate_request(self, request: GenerationRequest):
 
     def _try_finalize_initializing_models(self):
         """Attempt to finalize all models that are currently initializing."""
-        for model_name in list(self.models_initializing):
-            if self._finalize_hosted_job(model_name):
+        for job_id in list(self.models_initializing):
+            if self._finalize_hosted_job(job_id):
                 self.send_request(
                     "debug_print",
                     (
-                        f"Successfully finalized model: {model_name}",
+                        f"Successfully finalized model: {job_id}",
                         "green",
                         logging.INFO,
                     ),
@@ -596,8 +597,8 @@ def _initialize_hosted_job(
         }
 
         # Inspect model to determine network requirements
-        self.inspect_model(model_name, job_data)
-
+        job_data = self.inspect_model(model_name, job_data)
+        self.models_initializing.add(job_data.get("id"))
         self.send_request(
             "debug_print",
             (f"Initialized hosted job for {model_name}", "green", logging.INFO),
@@ -611,11 +612,11 @@ def _initialize_hosted_job(
         if model_name in self.model_state:
             del self.model_state[model_name]
 
-    def _finalize_hosted_job(self, model_name: str):
+    def _finalize_hosted_job(self, job_id: str):
         """Finalize a hosted job by setting up the distributed model with workers."""
         try:
             # Check if we have module info ready
-            args = self.send_request("check_module", None)
+            args = self.send_request("check_module", job_id)
 
             if not args or not isinstance(args, dict):
                 # Module not ready yet
@@ -638,7 +639,6 @@ def _finalize_hosted_job(self, model_name: str):
 
         # Register the distributed model's modules
         for module_id, module_info in distribution.items():
-            module_id = hashlib.sha256(json.dumps(module_info).encode()).hexdigest()
             self.modules[module_id] = module_info
 
         # Distribute the model across workers
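
Taken together, the validator changes re-key the hosted-job pipeline from model names to the job id returned by send_job_request. A condensed sketch of the resulting flow (method bodies elided; this summarizes the diff rather than documenting a public API):

# Pseudocode-level summary of the re-keyed lifecycle; names mirror the diff above.
def _initialize_hosted_job(self, model_name, job_data):
    job_data = self.inspect_model(model_name, job_data)  # now returns the updated job data
    self.models_initializing.add(job_data.get("id"))     # track by job id, not model name

def _try_finalize_initializing_models(self):
    for job_id in list(self.models_initializing):
        # check_module is now queried with the specific job id instead of None
        self._finalize_hosted_job(job_id)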
