Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe
from .stages import (
EdgeProgramManagerTransformStage,
EdgeTransformAndLowerStage,
ExecutorchStage,
PipelineArtifact,
Expand Down Expand Up @@ -315,6 +316,10 @@ def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]:
stage = EdgeTransformAndLowerStage.from_recipe(self._lowering_recipe)
elif stage_type == StageType.TO_EDGE:
stage = ToEdgeStage.from_recipe(self._lowering_recipe)
elif stage_type == StageType.EDGE_PROGRAM_MANAGER_TRANSFORM:
stage = EdgeProgramManagerTransformStage.from_recipe(
self._lowering_recipe
)
elif stage_type == StageType.TO_BACKEND:
stage = ToBackendStage.from_recipe(self._lowering_recipe)
elif stage_type == StageType.TO_EXECUTORCH:
Expand Down Expand Up @@ -504,7 +509,8 @@ def get_edge_program_manager(self) -> "EdgeProgramManager":
This method checks multiple stages in order of preference:
1. TO_EDGE_TRANSFORM_AND_LOWER (combined stage)
2. TO_BACKEND (separate stage with backend delegation)
3. TO_EDGE (separate stage without backend delegation)
3. EDGE_PROGRAM_MANAGER_TRANSFORM (separate stage after TO_EDGE)
4. TO_EDGE (separate stage without backend delegation)

Returns:
The EdgeProgramManager
Expand All @@ -516,6 +522,7 @@ def get_edge_program_manager(self) -> "EdgeProgramManager":
for stage_type in [
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
StageType.TO_BACKEND,
StageType.EDGE_PROGRAM_MANAGER_TRANSFORM,
StageType.TO_EDGE,
]:
artifact = self._stage_to_artifacts.get(stage_type)
Expand All @@ -525,7 +532,7 @@ def get_edge_program_manager(self) -> "EdgeProgramManager":

raise RuntimeError(
"Edge program manager is not available. "
"Run one of the edge stages first: TO_EDGE_TRANSFORM_AND_LOWER, TO_EDGE, or TO_BACKEND."
"Run one of the edge stages first: TO_EDGE_TRANSFORM_AND_LOWER, TO_EDGE, EDGE_PROGRAM_MANAGER_TRANSFORM, or TO_BACKEND."
)

def get_executorch_program(self) -> Program:
Expand Down
16 changes: 11 additions & 5 deletions export/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
from typing import Callable, List, Optional

import torch
from executorch.exir import ExportedProgram
from executorch.exir import EdgeProgramManager, ExportedProgram

from executorch.exir._warnings import experimental

from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.pass_manager import PassType
from executorch.exir.pass_manager import PassManager, PassType
from torchao.core.config import AOBaseConfig
from torchao.quantization.pt2e.quantizer import Quantizer

Expand Down Expand Up @@ -119,14 +119,20 @@ class LoweringRecipe:

Attributes:
partitioners: Optional list of partitioners for model partitioning
edge_transform_passes: Optional list of callables that take (method_name: str, exported_program: ExportedProgram) as arguments
and return a list of passes (PassType) to be executed during lowering stages.
edge_transform_passes: Optional list of callables that take (method_name: str, exported_program: ExportedProgram)
and return either List[PassType] or PassManager to be applied during edge lowering.
edge_manager_transform_passes: Optional list of callables that take EdgeProgramManager as argument
and return passes to be applied. Applied sequentially after TO_EDGE stage.
edge_compile_config: Optional edge compilation configuration
"""

partitioners: Optional[List[Partitioner]] = None
edge_transform_passes: (
None | List[Callable[[str, ExportedProgram], List[PassType]]]
None | List[Callable[[str, ExportedProgram], List[PassType] | PassManager]]
) = None
# pyre-ignore[11]: Type not defined
edge_manager_transform_passes: (
None | List[Callable[[EdgeProgramManager], List[PassType] | PassManager]]
) = None
# pyre-ignore[11]: Type not defined
edge_compile_config: Optional[EdgeCompileConfig] = None
Expand Down
177 changes: 137 additions & 40 deletions export/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@

import torch
from executorch.devtools.backend_debug import get_delegation_info
from executorch.exir import EdgeCompileConfig, ExportedProgram
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, ExportedProgram
from executorch.exir.backend.backend_api import validation_disabled
from executorch.exir.pass_manager import PassManager
from executorch.exir.program import to_edge, to_edge_transform_and_lower
from executorch.export.recipe import LoweringRecipe, QuantizationRecipe
from executorch.export.types import StageType
Expand Down Expand Up @@ -175,7 +176,7 @@ def __init__(
self,
partitioners: Optional[List[Any]] = None,
transform_passes: (
None | List[Callable[[str, ExportedProgram], List[PassType]]]
None | List[Callable[[str, ExportedProgram], List[PassType] | PassManager]]
) = None,
compile_config: Optional[Any] = None,
) -> None:
Expand Down Expand Up @@ -217,28 +218,33 @@ def run(self, artifact: PipelineArtifact) -> None:
constant_methods = artifact.get_context("constant_methods")
generate_etrecord = artifact.get_context("generate_etrecord", False)

# per method transform passes
# Detect if any callable returns PassManager
pass_manager = None
transform_passes = defaultdict(list)
for method_name, ep in exported_programs.items():
# Resolve transform passes from callable
for pass_ in self._transform_passes or []:
if not callable(pass_):
for pass_callable in self._transform_passes or []:
if not callable(pass_callable):
raise ValueError(
"Transform passes must be a callable that resolves to a list of passes"
"Transform passes must be a callable that resolves to passes"
)
passes = pass_(method_name, ep)
if isinstance(passes, list):
transform_passes[method_name].extend(passes)
passes = pass_callable(method_name, ep)
if isinstance(passes, PassManager):
pass_manager = passes
break
else:
raise ValueError(
"Transform passes must be a callable that resolves to a list of passes"
)
transform_passes[method_name].extend(passes)
if pass_manager:
break

# Use PassManager directly if found, otherwise use dict
final_passes = pass_manager if pass_manager else transform_passes

with validation_disabled():
edge_program_manager = to_edge_transform_and_lower(
exported_programs,
partitioner=self._partitioners,
transform_passes=transform_passes,
transform_passes=final_passes,
constant_methods=constant_methods,
compile_config=self._compile_config,
generate_etrecord=generate_etrecord,
Expand Down Expand Up @@ -275,6 +281,7 @@ def valid_predecessor_stages(self) -> List["StageType"]:
return [
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
StageType.TO_BACKEND,
StageType.EDGE_PROGRAM_MANAGER_TRANSFORM, # Added for server model generation (skipping TO_BACKEND)
]

@property
Expand Down Expand Up @@ -495,76 +502,166 @@ def run(self, artifact: PipelineArtifact) -> None:
self._artifact = artifact.copy_with_new_data(edge_program_manager)


class ToBackendStage(Stage):
class EdgeProgramManagerTransformStage(Stage):
"""
Stage: Apply transformations and partitioning to EdgeProgramManager.
Stage: Apply transformation passes that require EdgeProgramManager.

This stage enables dynamic pass generation where passes need access to the
EdgeProgramManager instance. Passes are applied sequentially, allowing
to control order and dependencies between pass groups.
"""

def __init__(
self,
partitioners: Optional[List[Any]] = None,
transform_passes: (
None | List[Callable[[str, ExportedProgram], List[PassType]]]
edge_transform_passes: (
None | List[Callable[[str, ExportedProgram], List[PassType] | PassManager]]
) = None,
edge_manager_transform_passes: (
None | List[Callable[[EdgeProgramManager], List[PassType] | PassManager]]
) = None,
) -> None:
"""
Initialize the EdgeProgramManagerTransformStage.

Args:
edge_manager_transform_passes: List of callables that take EdgeProgramManager
and return either List[PassType] or PassManager.
Each callable is applied sequentially, allowing
backends to control pass ordering and dependencies.
"""
super().__init__()
self._partitioners = partitioners
self._transform_passes = transform_passes
self._edge_transform_passes = edge_transform_passes or []
self._edge_manager_transform_passes = edge_manager_transform_passes or []

@classmethod
def from_recipe(
cls, lowering_recipe: Optional["LoweringRecipe"]
) -> "ToBackendStage":
cls, lowering_recipe: Optional[LoweringRecipe]
) -> "EdgeProgramManagerTransformStage":
if lowering_recipe is None:
return cls()

return cls(
partitioners=lowering_recipe.partitioners,
transform_passes=lowering_recipe.edge_transform_passes,
edge_transform_passes=lowering_recipe.edge_transform_passes,
edge_manager_transform_passes=lowering_recipe.edge_manager_transform_passes,
)

@property
def stage_type(self) -> str:
return StageType.TO_BACKEND
return StageType.EDGE_PROGRAM_MANAGER_TRANSFORM

@property
def valid_predecessor_stages(self) -> List["StageType"]:
return [StageType.TO_EDGE]
return [
StageType.TO_EDGE,
# StageType.TO_EDGE_TRANSFORM_AND_LOWER, # TODO
]

@property
def can_start_pipeline(self) -> bool:
return False

def run(self, artifact: PipelineArtifact) -> None:
"""
Apply transformations and partitioning to EdgeProgramManager.
Apply transformation passes sequentially.

Args:
artifact: Contains edge program manager and context
artifact: Pipeline artifact containing EdgeProgramManager
"""
edge_program_manager = artifact.data

if edge_program_manager is None:
raise RuntimeError("Edge program manager is not set.")
if not isinstance(edge_program_manager, EdgeProgramManager):
raise TypeError(
f"Expected EdgeProgramManager but got {type(edge_program_manager)}"
)

# per method transform passes
if not self._edge_transform_passes and not self._edge_manager_transform_passes:
self._artifact = artifact
return

# Detect if any callable returns PassManager
pass_manager = None
transform_passes = defaultdict(list)
for method_name in edge_program_manager.methods:
# Resolve transform passes if it's a callable
ep = edge_program_manager.exported_program(method_name)
for pass_ in self._transform_passes or []:
if not callable(pass_):
for pass_callable in self._edge_transform_passes or []:
if not callable(pass_callable):
raise ValueError(
"Transform passes must be a callable that resolves to a list of passes"
"Transform passes must be a callable that resolves to passes"
)
passes = pass_(method_name, ep)
if isinstance(passes, list):
transform_passes[method_name].extend(passes)
passes = pass_callable(method_name, ep)
if isinstance(passes, PassManager):
pass_manager = passes
break
else:
raise ValueError("Transform passes must return list of passes")
transform_passes[method_name].extend(passes)
if pass_manager:
break

# Use PassManager directly if found, otherwise use dict
final_passes = pass_manager if pass_manager else transform_passes

# Apply edge transform passes
edge_program_manager = edge_program_manager.transform(final_passes)

# Run edge manager transform passes
for pass_callable in self._edge_manager_transform_passes:
passes = pass_callable(edge_program_manager)
if passes:
edge_program_manager = edge_program_manager.transform(passes)

self._artifact = artifact.copy_with_new_data(edge_program_manager)


class ToBackendStage(Stage):
"""
Stage: Apply partitioning to EdgeProgramManager.
"""

# Apply transform passes
edge_program_manager = edge_program_manager.transform(transform_passes)
def __init__(
self,
partitioners: Optional[List[Any]] = None,
) -> None:
super().__init__()
self._partitioners = partitioners

@classmethod
def from_recipe(
cls, lowering_recipe: Optional["LoweringRecipe"]
) -> "ToBackendStage":
if lowering_recipe is None:
return cls()

return cls(
partitioners=lowering_recipe.partitioners,
)

@property
def stage_type(self) -> str:
return StageType.TO_BACKEND

@property
def valid_predecessor_stages(self) -> List["StageType"]:
return [
StageType.TO_EDGE,
StageType.EDGE_PROGRAM_MANAGER_TRANSFORM,
]

@property
def can_start_pipeline(self) -> bool:
return False

def run(self, artifact: PipelineArtifact) -> None:
"""
Apply partitioning to EdgeProgramManager.

Args:
artifact: Contains edge program manager and context
"""
edge_program_manager = artifact.data

if edge_program_manager is None:
raise RuntimeError("Edge program manager is not set.")

# Apply partitioners if available
if self._partitioners is not None and len(self._partitioners) > 0:
Expand Down
Loading
Loading