Skip to content

Commit 4890eef

Browse files
[Executorch][Export][4/N] Add edge program transform stage which runs after to_edge() stage
Differential Revision: D87576717 Pull Request resolved: #16129
1 parent acabda8 commit 4890eef

File tree

5 files changed

+159
-103
lines changed

5 files changed

+159
-103
lines changed

export/export.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe
2424
from .stages import (
25+
EdgeProgramManagerTransformStage,
2526
EdgeTransformAndLowerStage,
2627
ExecutorchStage,
2728
PipelineArtifact,
@@ -315,6 +316,10 @@ def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]:
315316
stage = EdgeTransformAndLowerStage.from_recipe(self._lowering_recipe)
316317
elif stage_type == StageType.TO_EDGE:
317318
stage = ToEdgeStage.from_recipe(self._lowering_recipe)
319+
elif stage_type == StageType.EDGE_PROGRAM_MANAGER_TRANSFORM:
320+
stage = EdgeProgramManagerTransformStage.from_recipe(
321+
self._lowering_recipe
322+
)
318323
elif stage_type == StageType.TO_BACKEND:
319324
stage = ToBackendStage.from_recipe(self._lowering_recipe)
320325
elif stage_type == StageType.TO_EXECUTORCH:
@@ -504,7 +509,8 @@ def get_edge_program_manager(self) -> "EdgeProgramManager":
504509
This method checks multiple stages in order of preference:
505510
1. TO_EDGE_TRANSFORM_AND_LOWER (combined stage)
506511
2. TO_BACKEND (separate stage with backend delegation)
507-
3. TO_EDGE (separate stage without backend delegation)
512+
3. EDGE_PROGRAM_MANAGER_TRANSFORM (separate stage after TO_EDGE)
513+
4. TO_EDGE (separate stage without backend delegation)
508514
509515
Returns:
510516
The EdgeProgramManager
@@ -516,6 +522,7 @@ def get_edge_program_manager(self) -> "EdgeProgramManager":
516522
for stage_type in [
517523
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
518524
StageType.TO_BACKEND,
525+
StageType.EDGE_PROGRAM_MANAGER_TRANSFORM,
519526
StageType.TO_EDGE,
520527
]:
521528
artifact = self._stage_to_artifacts.get(stage_type)
@@ -525,7 +532,7 @@ def get_edge_program_manager(self) -> "EdgeProgramManager":
525532

526533
raise RuntimeError(
527534
"Edge program manager is not available. "
528-
"Run one of the edge stages first: TO_EDGE_TRANSFORM_AND_LOWER, TO_EDGE, or TO_BACKEND."
535+
"Run one of the edge stages first: TO_EDGE_TRANSFORM_AND_LOWER, TO_EDGE, EDGE_PROGRAM_MANAGER_TRANSFORM, or TO_BACKEND."
529536
)
530537

531538
def get_executorch_program(self) -> Program:

export/recipe.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
from typing import Callable, List, Optional
1212

1313
import torch
14-
from executorch.exir import ExportedProgram
14+
from executorch.exir import EdgeProgramManager, ExportedProgram
1515

1616
from executorch.exir._warnings import experimental
1717

1818
from executorch.exir.backend.partitioner import Partitioner
1919
from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig
20-
from executorch.exir.pass_manager import PassType
20+
from executorch.exir.pass_manager import PassManager, PassType
2121
from torchao.core.config import AOBaseConfig
2222
from torchao.quantization.pt2e.quantizer import Quantizer
2323

@@ -119,14 +119,20 @@ class LoweringRecipe:
119119
120120
Attributes:
121121
partitioners: Optional list of partitioners for model partitioning
122-
edge_transform_passes: Optional list of callables that take (method_name: str, exported_program: ExportedProgram) as arguments
123-
and return a list of passes (PassType) to be executed during lowering stages.
122+
edge_transform_passes: Optional list of callables that take (method_name: str, exported_program: ExportedProgram)
123+
and return either List[PassType] or PassManager to be applied during edge lowering.
124+
edge_manager_transform_passes: Optional list of callables that take EdgeProgramManager as argument
125+
and return passes to be applied. Applied sequentially after TO_EDGE stage.
124126
edge_compile_config: Optional edge compilation configuration
125127
"""
126128

127129
partitioners: Optional[List[Partitioner]] = None
128130
edge_transform_passes: (
129-
None | List[Callable[[str, ExportedProgram], List[PassType]]]
131+
None | List[Callable[[str, ExportedProgram], List[PassType] | PassManager]]
132+
) = None
133+
# pyre-ignore[11]: Type not defined
134+
edge_manager_transform_passes: (
135+
None | List[Callable[[EdgeProgramManager], List[PassType] | PassManager]]
130136
) = None
131137
# pyre-ignore[11]: Type not defined
132138
edge_compile_config: Optional[EdgeCompileConfig] = None

export/stages.py

Lines changed: 137 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313

1414
import torch
1515
from executorch.devtools.backend_debug import get_delegation_info
16-
from executorch.exir import EdgeCompileConfig, ExportedProgram
16+
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, ExportedProgram
1717
from executorch.exir.backend.backend_api import validation_disabled
18+
from executorch.exir.pass_manager import PassManager
1819
from executorch.exir.program import to_edge, to_edge_transform_and_lower
1920
from executorch.export.recipe import LoweringRecipe, QuantizationRecipe
2021
from executorch.export.types import StageType
@@ -175,7 +176,7 @@ def __init__(
175176
self,
176177
partitioners: Optional[List[Any]] = None,
177178
transform_passes: (
178-
None | List[Callable[[str, ExportedProgram], List[PassType]]]
179+
None | List[Callable[[str, ExportedProgram], List[PassType] | PassManager]]
179180
) = None,
180181
compile_config: Optional[Any] = None,
181182
) -> None:
@@ -217,28 +218,33 @@ def run(self, artifact: PipelineArtifact) -> None:
217218
constant_methods = artifact.get_context("constant_methods")
218219
generate_etrecord = artifact.get_context("generate_etrecord", False)
219220

220-
# per method transform passes
221+
# Detect if any callable returns PassManager
222+
pass_manager = None
221223
transform_passes = defaultdict(list)
222224
for method_name, ep in exported_programs.items():
223225
# Resolve transform passes from callable
224-
for pass_ in self._transform_passes or []:
225-
if not callable(pass_):
226+
for pass_callable in self._transform_passes or []:
227+
if not callable(pass_callable):
226228
raise ValueError(
227-
"Transform passes must be a callable that resolves to a list of passes"
229+
"Transform passes must be a callable that resolves to passes"
228230
)
229-
passes = pass_(method_name, ep)
230-
if isinstance(passes, list):
231-
transform_passes[method_name].extend(passes)
231+
passes = pass_callable(method_name, ep)
232+
if isinstance(passes, PassManager):
233+
pass_manager = passes
234+
break
232235
else:
233-
raise ValueError(
234-
"Transform passes must be a callable that resolves to a list of passes"
235-
)
236+
transform_passes[method_name].extend(passes)
237+
if pass_manager:
238+
break
239+
240+
# Use PassManager directly if found, otherwise use dict
241+
final_passes = pass_manager if pass_manager else transform_passes
236242

237243
with validation_disabled():
238244
edge_program_manager = to_edge_transform_and_lower(
239245
exported_programs,
240246
partitioner=self._partitioners,
241-
transform_passes=transform_passes,
247+
transform_passes=final_passes,
242248
constant_methods=constant_methods,
243249
compile_config=self._compile_config,
244250
generate_etrecord=generate_etrecord,
@@ -275,6 +281,7 @@ def valid_predecessor_stages(self) -> List["StageType"]:
275281
return [
276282
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
277283
StageType.TO_BACKEND,
284+
StageType.EDGE_PROGRAM_MANAGER_TRANSFORM, # Added for server model generation (skipping TO_BACKEND)
278285
]
279286

280287
@property
@@ -495,76 +502,166 @@ def run(self, artifact: PipelineArtifact) -> None:
495502
self._artifact = artifact.copy_with_new_data(edge_program_manager)
496503

497504

498-
class ToBackendStage(Stage):
505+
class EdgeProgramManagerTransformStage(Stage):
499506
"""
500-
Stage: Apply transformations and partitioning to EdgeProgramManager.
507+
Stage: Apply transformation passes that require EdgeProgramManager.
508+
509+
This stage enables dynamic pass generation where passes need access to the
510+
EdgeProgramManager instance. Passes are applied sequentially, allowing
511+
to control order and dependencies between pass groups.
501512
"""
502513

503514
def __init__(
504515
self,
505-
partitioners: Optional[List[Any]] = None,
506-
transform_passes: (
507-
None | List[Callable[[str, ExportedProgram], List[PassType]]]
516+
edge_transform_passes: (
517+
None | List[Callable[[str, ExportedProgram], List[PassType] | PassManager]]
518+
) = None,
519+
edge_manager_transform_passes: (
520+
None | List[Callable[[EdgeProgramManager], List[PassType] | PassManager]]
508521
) = None,
509522
) -> None:
523+
"""
524+
Initialize the EdgeProgramManagerTransformStage.
525+
526+
Args:
527+
edge_manager_transform_passes: List of callables that take EdgeProgramManager
528+
and return either List[PassType] or PassManager.
529+
Each callable is applied sequentially, allowing
530+
backends to control pass ordering and dependencies.
531+
"""
510532
super().__init__()
511-
self._partitioners = partitioners
512-
self._transform_passes = transform_passes
533+
self._edge_transform_passes = edge_transform_passes or []
534+
self._edge_manager_transform_passes = edge_manager_transform_passes or []
513535

514536
@classmethod
515537
def from_recipe(
516-
cls, lowering_recipe: Optional["LoweringRecipe"]
517-
) -> "ToBackendStage":
538+
cls, lowering_recipe: Optional[LoweringRecipe]
539+
) -> "EdgeProgramManagerTransformStage":
518540
if lowering_recipe is None:
519541
return cls()
520542

521543
return cls(
522-
partitioners=lowering_recipe.partitioners,
523-
transform_passes=lowering_recipe.edge_transform_passes,
544+
edge_transform_passes=lowering_recipe.edge_transform_passes,
545+
edge_manager_transform_passes=lowering_recipe.edge_manager_transform_passes,
524546
)
525547

526548
@property
527549
def stage_type(self) -> str:
528-
return StageType.TO_BACKEND
550+
return StageType.EDGE_PROGRAM_MANAGER_TRANSFORM
529551

530552
@property
531553
def valid_predecessor_stages(self) -> List["StageType"]:
532-
return [StageType.TO_EDGE]
554+
return [
555+
StageType.TO_EDGE,
556+
# StageType.TO_EDGE_TRANSFORM_AND_LOWER, # TODO
557+
]
533558

534559
@property
535560
def can_start_pipeline(self) -> bool:
536561
return False
537562

538563
def run(self, artifact: PipelineArtifact) -> None:
539564
"""
540-
Apply transformations and partitioning to EdgeProgramManager.
565+
Apply transformation passes sequentially.
541566
542567
Args:
543-
artifact: Contains edge program manager and context
568+
artifact: Pipeline artifact containing EdgeProgramManager
544569
"""
545570
edge_program_manager = artifact.data
546571

547-
if edge_program_manager is None:
548-
raise RuntimeError("Edge program manager is not set.")
572+
if not isinstance(edge_program_manager, EdgeProgramManager):
573+
raise TypeError(
574+
f"Expected EdgeProgramManager but got {type(edge_program_manager)}"
575+
)
549576

550-
# per method transform passes
577+
if not self._edge_transform_passes and not self._edge_manager_transform_passes:
578+
self._artifact = artifact
579+
return
580+
581+
# Detect if any callable returns PassManager
582+
pass_manager = None
551583
transform_passes = defaultdict(list)
552584
for method_name in edge_program_manager.methods:
553585
# Resolve transform passes if it's a callable
554586
ep = edge_program_manager.exported_program(method_name)
555-
for pass_ in self._transform_passes or []:
556-
if not callable(pass_):
587+
for pass_callable in self._edge_transform_passes or []:
588+
if not callable(pass_callable):
557589
raise ValueError(
558-
"Transform passes must be a callable that resolves to a list of passes"
590+
"Transform passes must be a callable that resolves to passes"
559591
)
560-
passes = pass_(method_name, ep)
561-
if isinstance(passes, list):
562-
transform_passes[method_name].extend(passes)
592+
passes = pass_callable(method_name, ep)
593+
if isinstance(passes, PassManager):
594+
pass_manager = passes
595+
break
563596
else:
564-
raise ValueError("Transform passes must return list of passes")
597+
transform_passes[method_name].extend(passes)
598+
if pass_manager:
599+
break
600+
601+
# Use PassManager directly if found, otherwise use dict
602+
final_passes = pass_manager if pass_manager else transform_passes
603+
604+
# Apply edge transform passes
605+
edge_program_manager = edge_program_manager.transform(final_passes)
606+
607+
# Run edge manager transform passes
608+
for pass_callable in self._edge_manager_transform_passes:
609+
passes = pass_callable(edge_program_manager)
610+
if passes:
611+
edge_program_manager = edge_program_manager.transform(passes)
612+
613+
self._artifact = artifact.copy_with_new_data(edge_program_manager)
614+
615+
616+
class ToBackendStage(Stage):
617+
"""
618+
Stage: Apply partitioning to EdgeProgramManager.
619+
"""
565620

566-
# Apply transform passes
567-
edge_program_manager = edge_program_manager.transform(transform_passes)
621+
def __init__(
622+
self,
623+
partitioners: Optional[List[Any]] = None,
624+
) -> None:
625+
super().__init__()
626+
self._partitioners = partitioners
627+
628+
@classmethod
629+
def from_recipe(
630+
cls, lowering_recipe: Optional["LoweringRecipe"]
631+
) -> "ToBackendStage":
632+
if lowering_recipe is None:
633+
return cls()
634+
635+
return cls(
636+
partitioners=lowering_recipe.partitioners,
637+
)
638+
639+
@property
640+
def stage_type(self) -> str:
641+
return StageType.TO_BACKEND
642+
643+
@property
644+
def valid_predecessor_stages(self) -> List["StageType"]:
645+
return [
646+
StageType.TO_EDGE,
647+
StageType.EDGE_PROGRAM_MANAGER_TRANSFORM,
648+
]
649+
650+
@property
651+
def can_start_pipeline(self) -> bool:
652+
return False
653+
654+
def run(self, artifact: PipelineArtifact) -> None:
655+
"""
656+
Apply partitioning to EdgeProgramManager.
657+
658+
Args:
659+
artifact: Contains edge program manager and context
660+
"""
661+
edge_program_manager = artifact.data
662+
663+
if edge_program_manager is None:
664+
raise RuntimeError("Edge program manager is not set.")
568665

569666
# Apply partitioners if available
570667
if self._partitioners is not None and len(self._partitioners) > 0:

0 commit comments

Comments
 (0)