Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 164 additions & 28 deletions export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from executorch.exir.program import ExecutorchProgramManager
from executorch.exir.schema import Program
from executorch.extension.export_util.utils import save_pte_program
from executorch.runtime import Runtime, Verification
from tabulate import tabulate
from torch import nn
from torch.export import ExportedProgram
from torch.fx import GraphModule

from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe
from .stages import (
Expand All @@ -36,11 +37,22 @@
"This API and all of its related functionality such as ExportSession and ExportRecipe are experimental."
)
def export(
model: Union[nn.Module, Dict[str, nn.Module]],
example_inputs: Union[
List[tuple[torch.Tensor, ...]], Dict[str, List[tuple[torch.Tensor, ...]]]
model: Union[
nn.Module,
Dict[str, nn.Module],
GraphModule,
Dict[str, GraphModule],
ExportedProgram,
Dict[str, ExportedProgram],
str,
],
export_recipe: ExportRecipe,
example_inputs: Optional[
Union[
List[tuple[torch.Tensor, ...]],
Dict[str, List[tuple[torch.Tensor, ...]]],
]
] = None,
export_recipe: ExportRecipe = None,
name: Optional[str] = None,
dynamic_shapes: Optional[Union[Any, Dict[str, Any]]] = None,
constant_methods: Optional[Union[Dict[str, Callable]]] = None,
Expand All @@ -54,10 +66,16 @@ def export(
optionally run the export process in one step.

Args:
model: The PyTorch model(s) to export, either a single model or a dictionary
mapping method names to models
model: The PyTorch model(s) to export. Can be:
- nn.Module or Dict[str, nn.Module]: Eager PyTorch model(s)
- GraphModule or Dict[str, GraphModule]: Quantized model(s) (e.g., from prepare/convert)
- ExportedProgram or Dict[str, ExportedProgram]: Already exported model(s)
- str: Path to load an ExportedProgram from disk
example_inputs: Example inputs for the model(s), either a list of input tuples
or a dictionary mapping method names to lists of input tuples
or a dictionary mapping method names to lists of input tuples.
First sample (index 0) is used for torch.export.export() to export the model.
All samples are used as calibration dataset in PT2E Quantize stage.
Optional when model is ExportedProgram (not needed).
export_recipe: Contains the configuration for the export process
name: Optional name for the export
dynamic_shapes: Optional dynamic shape specifications
Expand Down Expand Up @@ -99,11 +117,22 @@ class ExportSession:

def __init__(
self,
model: Union[nn.Module, Dict[str, nn.Module]],
example_inputs: Union[
List[tuple[torch.Tensor, ...]], Dict[str, List[tuple[torch.Tensor, ...]]]
model: Union[
nn.Module,
Dict[str, nn.Module],
GraphModule,
Dict[str, GraphModule],
ExportedProgram,
Dict[str, ExportedProgram],
str,
],
export_recipe: ExportRecipe,
example_inputs: Optional[
Union[
List[tuple[torch.Tensor, ...]],
Dict[str, List[tuple[torch.Tensor, ...]]],
]
] = None,
export_recipe: ExportRecipe = None,
name: Optional[str] = None,
dynamic_shapes: Optional[Union[Any, Dict[str, Any]]] = None,
constant_methods: Optional[Union[Dict[str, Callable]]] = None,
Expand All @@ -114,26 +143,52 @@ def __init__(
Initialize the ExportSession with model, inputs, and recipe.

Args:
model: The PyTorch model(s) to export, either a single model or a dictionary
mapping method names to models
model: The PyTorch model(s) to export. Can be:
- nn.Module or Dict[str, nn.Module]: Eager PyTorch model(s)
- GraphModule or Dict[str, GraphModule]: Quantized model(s)
- ExportedProgram or Dict[str, ExportedProgram]: Already exported model(s)
- str: Path to load an ExportedProgram from disk
example_inputs: Example inputs for the model(s), either a list of input tuples
or a dictionary mapping method names to lists of input tuples
or a dictionary mapping method names to lists of input tuples.
First sample (index 0) is used for torch.export.export() to export the model.
All samples are used as calibration dataset in PT2E Quantize stage,
Optional when model is ExportedProgram (not needed).
export_recipe: Contains the configuration for the export process
name: Optional name for the export
dynamic_shapes: Optional dynamic shape specifications
constant_methods: Optional dictionary of constant methods
artifact_dir: Optional directory to store artifacts
generate_etrecord: Optional flag to generate an etrecord
"""
# Load model from file if string path provided
if isinstance(model, str):
model = torch.export.load(model)
logging.info(f"Loaded ExportedProgram from {model}")

# Detect input model type to determine which stages to skip
self._input_model_type = self._detect_model_type(model)

# Standardize model to dictionary format
self._model = model if isinstance(model, dict) else {"forward": model}

# Standardize example_inputs to dictionary format
self._example_inputs = (
example_inputs
if isinstance(example_inputs, dict)
else {"forward": example_inputs}
)
# Validate and standardize example_inputs to dictionary format
# example_inputs not required for ExportedProgram input
if self._input_model_type == "ExportedProgram":
self._example_inputs = example_inputs or {}
if isinstance(self._example_inputs, list):
self._example_inputs = {"forward": self._example_inputs}
else:
# For nn.Module and GraphModule, example_inputs are required
if example_inputs is None:
raise ValueError(
f"example_inputs are required when model is {self._input_model_type}. "
f"Only ExportedProgram inputs can omit example_inputs."
)
self._example_inputs = (
example_inputs
if isinstance(example_inputs, dict)
else {"forward": example_inputs}
)

# Standardize dynamic_shapes to dictionary format
self._dynamic_shapes = {}
Expand Down Expand Up @@ -176,14 +231,64 @@ def __init__(

self._stage_to_artifacts: Dict[StageType, PipelineArtifact] = {}

def _detect_model_type(
self, model: Union[nn.Module, GraphModule, ExportedProgram, Dict]
) -> str:
"""
Detect the type of input model.

Args:
model: Input model in various formats

Returns:
String indicating the model type: "nn.Module", "GraphModule", or "ExportedProgram"
"""
# Handle dict (multi-method) - check first value
if isinstance(model, dict):
first_value = next(iter(model.values()))
return self._detect_model_type(first_value)

# Detect single model type
if isinstance(model, ExportedProgram):
return "ExportedProgram"
elif isinstance(model, GraphModule):
return "GraphModule"
elif isinstance(model, nn.Module):
return "nn.Module"
else:
raise TypeError(f"Unsupported model type: {type(model)}")

def _get_default_pipeline(self) -> List[StageType]:
return [
StageType.SOURCE_TRANSFORM, # Optional stage, returns original model if quant recipe is invalid
StageType.QUANTIZE, # Optional stage, returns original model if quant recipe is invalid
StageType.TORCH_EXPORT,
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
StageType.TO_EXECUTORCH,
]
"""
Get default pipeline stages based on input model type.

Returns:
List of stages appropriate for the input model type
"""
stages = []

# Add quantization stages only for eager nn.Module
if self._input_model_type == "nn.Module":
stages.extend(
[
StageType.SOURCE_TRANSFORM, # Optional stage, returns original model if quant recipe is invalid
StageType.QUANTIZE, # Optional stage, returns original model if quant recipe is invalid
]
)

# Add torch export stage if not already exported
if self._input_model_type != "ExportedProgram":
stages.append(StageType.TORCH_EXPORT)

# Always include edge and executorch stages
stages.extend(
[
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
StageType.TO_EXECUTORCH,
]
)

return stages

def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]:
"""Build the stage registry from the given stages."""
Expand Down Expand Up @@ -259,6 +364,34 @@ def _validate_pipeline_sequence(
if not stages:
raise ValueError("Pipeline stages cannot be empty")

# Validate pipeline compatibility with input model type
if self._input_model_type == "GraphModule":
# GraphModule input should not run quantization stages
incompatible_stages = {StageType.SOURCE_TRANSFORM, StageType.QUANTIZE}
found_incompatible = set(stages) & incompatible_stages
if found_incompatible:
stage_names = ", ".join(s.name for s in found_incompatible)
raise ValueError(
f"Cannot run {stage_names} stage(s) with GraphModule input. "
f"GraphModule is already quantized. "
f"Remove {stage_names} from pipeline_stages or use nn.Module input."
)
elif self._input_model_type == "ExportedProgram":
# ExportedProgram input should not run quantization or torch export stages
incompatible_stages = {
StageType.SOURCE_TRANSFORM,
StageType.QUANTIZE,
StageType.TORCH_EXPORT,
}
found_incompatible = set(stages) & incompatible_stages
if found_incompatible:
stage_names = ", ".join(s.name for s in found_incompatible)
raise ValueError(
f"Cannot run {stage_names} stage(s) with ExportedProgram input. "
f"ExportedProgram is already exported. "
f"Remove {stage_names} from pipeline_stages or use nn.Module/GraphModule input."
)

# Validate that the first stage can start a pipeline
first_stage = stages[0]
first_stage_instance = self._stage_registry.get(first_stage)
Expand Down Expand Up @@ -436,6 +569,9 @@ def run_method(
Raises:
RuntimeError: If the method cannot be loaded
"""
# Lazy import to avoid forcing portable_lib dependency at module load time
from executorch.runtime import Runtime, Verification

et_runtime = Runtime.get()
program = et_runtime.load_program(
self.get_pte_buffer(), verification=Verification.Minimal
Expand Down
10 changes: 7 additions & 3 deletions export/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def __init__(
) = None,
compile_config: Optional[Any] = None,
) -> None:
super().__init__()
self._partitioners = partitioners
self._transform_passes = transform_passes
self._compile_config = compile_config
Expand Down Expand Up @@ -206,7 +207,7 @@ def valid_predecessor_stages(self) -> List["StageType"]:

@property
def can_start_pipeline(self) -> bool:
return False
return True

def run(self, artifact: PipelineArtifact) -> None:
"""
Expand Down Expand Up @@ -271,7 +272,10 @@ def stage_type(self) -> str:

@property
def valid_predecessor_stages(self) -> List["StageType"]:
return [StageType.TO_EDGE_TRANSFORM_AND_LOWER, StageType.TO_BACKEND]
return [
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
StageType.TO_BACKEND,
]

@property
def can_start_pipeline(self) -> bool:
Expand Down Expand Up @@ -468,7 +472,7 @@ def valid_predecessor_stages(self) -> List["StageType"]:

@property
def can_start_pipeline(self) -> bool:
return False
return True

def run(self, artifact: PipelineArtifact) -> None:
"""
Expand Down
Loading
Loading