Commit 78f03bd

[Executorch][Export][1/n] Add exported program and GraphModule as inputs, validation and tests
Pull Request resolved: #16126

Add support for `ExportedProgram` and `GraphModule` as inputs to `ExportSession`, with proper validation.

- This allows users to pass already quantized or already exported programs directly to `ExportSession`, skipping the quantization or export stages as needed.
- Includes comprehensive test coverage for all input types.

ghstack-source-id: 327669279
@exported-using-ghexport

Differential Revision: [D87576716](https://our.internmc.facebook.com/intern/diff/D87576716/)
1 parent: 53c6ad2
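For context, a minimal usage sketch of what this change enables, in hedged form: the `executorch.export` import path, the `ExportRecipe` construction, and the return value of `export()` are assumptions based on the file paths and signatures in the diff below, not something the commit page itself shows.

    # Sketch only; names marked as assumptions are not defined by this commit.
    import torch
    from executorch.export import export  # assumed import path (export/export.py below)

    class AddOne(torch.nn.Module):  # hypothetical toy model
        def forward(self, x):
            return x + 1

    example_inputs = [(torch.randn(4),)]
    recipe = ...  # placeholder: an ExportRecipe instance; its construction is not part of this diff

    # Already-exported input: the quantize and TORCH_EXPORT stages are skipped,
    # and example_inputs may be omitted.
    ep = torch.export.export(AddOne(), example_inputs[0])
    et_session = export(model=ep, export_recipe=recipe)  # return value not shown in this excerpt

    # Already-quantized GraphModule input (e.g. from PT2E prepare/convert): the
    # quantize stages are skipped, but example_inputs are still required.
    # et_session = export(model=quantized_gm, example_inputs=example_inputs, export_recipe=recipe)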

File tree: 3 files changed, +509 / -37 lines


export/export.py

Lines changed: 164 additions & 28 deletions
@@ -13,9 +13,10 @@
 from executorch.exir.program import ExecutorchProgramManager
 from executorch.exir.schema import Program
 from executorch.extension.export_util.utils import save_pte_program
-from executorch.runtime import Runtime, Verification
 from tabulate import tabulate
 from torch import nn
+from torch.export import ExportedProgram
+from torch.fx import GraphModule

 from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe
 from .stages import (
@@ -36,11 +37,22 @@
     "This API and all of its related functionality such as ExportSession and ExportRecipe are experimental."
 )
 def export(
-    model: Union[nn.Module, Dict[str, nn.Module]],
-    example_inputs: Union[
-        List[tuple[torch.Tensor, ...]], Dict[str, List[tuple[torch.Tensor, ...]]]
+    model: Union[
+        nn.Module,
+        Dict[str, nn.Module],
+        GraphModule,
+        Dict[str, GraphModule],
+        ExportedProgram,
+        Dict[str, ExportedProgram],
+        str,
     ],
-    export_recipe: ExportRecipe,
+    example_inputs: Optional[
+        Union[
+            List[tuple[torch.Tensor, ...]],
+            Dict[str, List[tuple[torch.Tensor, ...]]],
+        ]
+    ] = None,
+    export_recipe: ExportRecipe = None,
     name: Optional[str] = None,
     dynamic_shapes: Optional[Union[Any, Dict[str, Any]]] = None,
     constant_methods: Optional[Union[Dict[str, Callable]]] = None,
@@ -54,10 +66,16 @@ def export(
     optionally run the export process in one step.

     Args:
-        model: The PyTorch model(s) to export, either a single model or a dictionary
-            mapping method names to models
+        model: The PyTorch model(s) to export. Can be:
+            - nn.Module or Dict[str, nn.Module]: Eager PyTorch model(s)
+            - GraphModule or Dict[str, GraphModule]: Quantized model(s) (e.g., from prepare/convert)
+            - ExportedProgram or Dict[str, ExportedProgram]: Already exported model(s)
+            - str: Path to load an ExportedProgram from disk
         example_inputs: Example inputs for the model(s), either a list of input tuples
-            or a dictionary mapping method names to lists of input tuples
+            or a dictionary mapping method names to lists of input tuples.
+            First sample (index 0) is used for torch.export.export() to export the model.
+            All samples are used as calibration dataset in PT2E Quantize stage.
+            Optional when model is ExportedProgram (not needed).
         export_recipe: Contains the configuration for the export process
         name: Optional name for the export
         dynamic_shapes: Optional dynamic shape specifications
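A short illustration of the example_inputs convention documented in the docstring above (a sketch, not taken from the diff): one list of input tuples per method, where index 0 drives torch.export.export() and the full list serves as the PT2E calibration set.

    import torch

    # Hypothetical inputs for a single-method ("forward") model.
    example_inputs = {
        "forward": [
            (torch.randn(1, 3, 224, 224),),  # sample 0: used by torch.export.export()
            (torch.randn(1, 3, 224, 224),),  # remaining samples: PT2E calibration only
            (torch.randn(1, 3, 224, 224),),
        ]
    }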
@@ -99,11 +117,22 @@ class ExportSession:

     def __init__(
         self,
-        model: Union[nn.Module, Dict[str, nn.Module]],
-        example_inputs: Union[
-            List[tuple[torch.Tensor, ...]], Dict[str, List[tuple[torch.Tensor, ...]]]
+        model: Union[
+            nn.Module,
+            Dict[str, nn.Module],
+            GraphModule,
+            Dict[str, GraphModule],
+            ExportedProgram,
+            Dict[str, ExportedProgram],
+            str,
         ],
-        export_recipe: ExportRecipe,
+        example_inputs: Optional[
+            Union[
+                List[tuple[torch.Tensor, ...]],
+                Dict[str, List[tuple[torch.Tensor, ...]]],
+            ]
+        ] = None,
+        export_recipe: ExportRecipe = None,
         name: Optional[str] = None,
         dynamic_shapes: Optional[Union[Any, Dict[str, Any]]] = None,
         constant_methods: Optional[Union[Dict[str, Callable]]] = None,
@@ -114,26 +143,52 @@ def __init__(
         Initialize the ExportSession with model, inputs, and recipe.

         Args:
-            model: The PyTorch model(s) to export, either a single model or a dictionary
-                mapping method names to models
+            model: The PyTorch model(s) to export. Can be:
+                - nn.Module or Dict[str, nn.Module]: Eager PyTorch model(s)
+                - GraphModule or Dict[str, GraphModule]: Quantized model(s)
+                - ExportedProgram or Dict[str, ExportedProgram]: Already exported model(s)
+                - str: Path to load an ExportedProgram from disk
             example_inputs: Example inputs for the model(s), either a list of input tuples
-                or a dictionary mapping method names to lists of input tuples
+                or a dictionary mapping method names to lists of input tuples.
+                First sample (index 0) is used for torch.export.export() to export the model.
+                All samples are used as calibration dataset in PT2E Quantize stage,
+                Optional when model is ExportedProgram (not needed).
             export_recipe: Contains the configuration for the export process
             name: Optional name for the export
             dynamic_shapes: Optional dynamic shape specifications
             constant_methods: Optional dictionary of constant methods
             artifact_dir: Optional directory to store artifacts
             generate_etrecord: Optional flag to generate an etrecord
         """
+        # Load model from file if string path provided
+        if isinstance(model, str):
+            model = torch.export.load(model)
+            logging.info(f"Loaded ExportedProgram from {model}")
+
+        # Detect input model type to determine which stages to skip
+        self._input_model_type = self._detect_model_type(model)
+
         # Standardize model to dictionary format
         self._model = model if isinstance(model, dict) else {"forward": model}

-        # Standardize example_inputs to dictionary format
-        self._example_inputs = (
-            example_inputs
-            if isinstance(example_inputs, dict)
-            else {"forward": example_inputs}
-        )
+        # Validate and standardize example_inputs to dictionary format
+        # example_inputs not required for ExportedProgram input
+        if self._input_model_type == "ExportedProgram":
+            self._example_inputs = example_inputs or {}
+            if isinstance(self._example_inputs, list):
+                self._example_inputs = {"forward": self._example_inputs}
+        else:
+            # For nn.Module and GraphModule, example_inputs are required
+            if example_inputs is None:
+                raise ValueError(
+                    f"example_inputs are required when model is {self._input_model_type}. "
+                    f"Only ExportedProgram inputs can omit example_inputs."
+                )
+            self._example_inputs = (
+                example_inputs
+                if isinstance(example_inputs, dict)
+                else {"forward": example_inputs}
+            )

         # Standardize dynamic_shapes to dictionary format
         self._dynamic_shapes = {}
@@ -176,14 +231,64 @@ def __init__(

         self._stage_to_artifacts: Dict[StageType, PipelineArtifact] = {}

+    def _detect_model_type(
+        self, model: Union[nn.Module, GraphModule, ExportedProgram, Dict]
+    ) -> str:
+        """
+        Detect the type of input model.
+
+        Args:
+            model: Input model in various formats
+
+        Returns:
+            String indicating the model type: "nn.Module", "GraphModule", or "ExportedProgram"
+        """
+        # Handle dict (multi-method) - check first value
+        if isinstance(model, dict):
+            first_value = next(iter(model.values()))
+            return self._detect_model_type(first_value)
+
+        # Detect single model type
+        if isinstance(model, ExportedProgram):
+            return "ExportedProgram"
+        elif isinstance(model, GraphModule):
+            return "GraphModule"
+        elif isinstance(model, nn.Module):
+            return "nn.Module"
+        else:
+            raise TypeError(f"Unsupported model type: {type(model)}")
+
     def _get_default_pipeline(self) -> List[StageType]:
-        return [
-            StageType.SOURCE_TRANSFORM,  # Optional stage, returns original model if quant recipe is invalid
-            StageType.QUANTIZE,  # Optional stage, returns original model if quant recipe is invalid
-            StageType.TORCH_EXPORT,
-            StageType.TO_EDGE_TRANSFORM_AND_LOWER,
-            StageType.TO_EXECUTORCH,
-        ]
+        """
+        Get default pipeline stages based on input model type.
+
+        Returns:
+            List of stages appropriate for the input model type
+        """
+        stages = []
+
+        # Add quantization stages only for eager nn.Module
+        if self._input_model_type == "nn.Module":
+            stages.extend(
+                [
+                    StageType.SOURCE_TRANSFORM,  # Optional stage, returns original model if quant recipe is invalid
+                    StageType.QUANTIZE,  # Optional stage, returns original model if quant recipe is invalid
+                ]
+            )
+
+        # Add torch export stage if not already exported
+        if self._input_model_type != "ExportedProgram":
+            stages.append(StageType.TORCH_EXPORT)
+
+        # Always include edge and executorch stages
+        stages.extend(
+            [
+                StageType.TO_EDGE_TRANSFORM_AND_LOWER,
+                StageType.TO_EXECUTORCH,
+            ]
+        )
+
+        return stages

     def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]:
         """Build the stage registry from the given stages."""
@@ -259,6 +364,34 @@ def _validate_pipeline_sequence(
         if not stages:
             raise ValueError("Pipeline stages cannot be empty")

+        # Validate pipeline compatibility with input model type
+        if self._input_model_type == "GraphModule":
+            # GraphModule input should not run quantization stages
+            incompatible_stages = {StageType.SOURCE_TRANSFORM, StageType.QUANTIZE}
+            found_incompatible = set(stages) & incompatible_stages
+            if found_incompatible:
+                stage_names = ", ".join(s.name for s in found_incompatible)
+                raise ValueError(
+                    f"Cannot run {stage_names} stage(s) with GraphModule input. "
+                    f"GraphModule is already quantized. "
+                    f"Remove {stage_names} from pipeline_stages or use nn.Module input."
+                )
+        elif self._input_model_type == "ExportedProgram":
+            # ExportedProgram input should not run quantization or torch export stages
+            incompatible_stages = {
+                StageType.SOURCE_TRANSFORM,
+                StageType.QUANTIZE,
+                StageType.TORCH_EXPORT,
+            }
+            found_incompatible = set(stages) & incompatible_stages
+            if found_incompatible:
+                stage_names = ", ".join(s.name for s in found_incompatible)
+                raise ValueError(
+                    f"Cannot run {stage_names} stage(s) with ExportedProgram input. "
+                    f"ExportedProgram is already exported. "
+                    f"Remove {stage_names} from pipeline_stages or use nn.Module/GraphModule input."
+                )
+
         # Validate that the first stage can start a pipeline
         first_stage = stages[0]
         first_stage_instance = self._stage_registry.get(first_stage)
@@ -436,6 +569,9 @@ def run_method(
         Raises:
             RuntimeError: If the method cannot be loaded
         """
+        # Lazy import to avoid forcing portable_lib dependency at module load time
+        from executorch.runtime import Runtime, Verification
+
         et_runtime = Runtime.get()
         program = et_runtime.load_program(
             self.get_pte_buffer(), verification=Verification.Minimal
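Tying the __init__ and _get_default_pipeline changes above together, a hedged sketch of driving ExportSession from a saved ExportedProgram path: the ExportSession import path and the recipe placeholder are assumptions; only torch.export.export/save are standard PyTorch APIs.

    import torch
    from executorch.export import ExportSession  # assumed import path (export/export.py)

    class AddOne(torch.nn.Module):  # hypothetical toy model
        def forward(self, x):
            return x + 1

    ep = torch.export.export(AddOne(), (torch.randn(4),))
    torch.export.save(ep, "/tmp/add_one.pt2")

    recipe = ...  # placeholder: an ExportRecipe instance; construction not shown in this diff
    session = ExportSession(model="/tmp/add_one.pt2", export_recipe=recipe)
    # With an ExportedProgram input, the default pipeline built by _get_default_pipeline is
    # [TO_EDGE_TRANSFORM_AND_LOWER, TO_EXECUTORCH]; requesting QUANTIZE or TORCH_EXPORT
    # explicitly raises the ValueError added in _validate_pipeline_sequence.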

export/stages.py

Lines changed: 7 additions & 3 deletions
@@ -179,6 +179,7 @@ def __init__(
         ) = None,
         compile_config: Optional[Any] = None,
     ) -> None:
+        super().__init__()
         self._partitioners = partitioners
         self._transform_passes = transform_passes
         self._compile_config = compile_config
@@ -206,7 +207,7 @@ def valid_predecessor_stages(self) -> List["StageType"]:

     @property
     def can_start_pipeline(self) -> bool:
-        return False
+        return True

     def run(self, artifact: PipelineArtifact) -> None:
         """
@@ -271,7 +272,10 @@ def stage_type(self) -> str:

     @property
     def valid_predecessor_stages(self) -> List["StageType"]:
-        return [StageType.TO_EDGE_TRANSFORM_AND_LOWER, StageType.TO_BACKEND]
+        return [
+            StageType.TO_EDGE_TRANSFORM_AND_LOWER,
+            StageType.TO_BACKEND,
+        ]

     @property
     def can_start_pipeline(self) -> bool:
@@ -468,7 +472,7 @@ def valid_predecessor_stages(self) -> List["StageType"]:

     @property
     def can_start_pipeline(self) -> bool:
-        return False
+        return True

     def run(self, artifact: PipelineArtifact) -> None:
         """
