1313from executorch .exir .program import ExecutorchProgramManager
1414from executorch .exir .schema import Program
1515from executorch .extension .export_util .utils import save_pte_program
16- from executorch .runtime import Runtime , Verification
1716from tabulate import tabulate
1817from torch import nn
18+ from torch .export import ExportedProgram
19+ from torch .fx import GraphModule
1920
2021from .recipe import ExportRecipe , LoweringRecipe , QuantizationRecipe
2122from .stages import (
3637 "This API and all of its related functionality such as ExportSession and ExportRecipe are experimental."
3738)
3839def export (
39- model : Union [nn .Module , Dict [str , nn .Module ]],
40- example_inputs : Union [
41- List [tuple [torch .Tensor , ...]], Dict [str , List [tuple [torch .Tensor , ...]]]
40+ model : Union [
41+ nn .Module ,
42+ Dict [str , nn .Module ],
43+ GraphModule ,
44+ Dict [str , GraphModule ],
45+ ExportedProgram ,
46+ Dict [str , ExportedProgram ],
47+ str ,
4248 ],
43- export_recipe : ExportRecipe ,
49+ example_inputs : Optional [
50+ Union [
51+ List [tuple [torch .Tensor , ...]],
52+ Dict [str , List [tuple [torch .Tensor , ...]]],
53+ ]
54+ ] = None ,
55+ export_recipe : ExportRecipe = None ,
4456 name : Optional [str ] = None ,
4557 dynamic_shapes : Optional [Union [Any , Dict [str , Any ]]] = None ,
4658 constant_methods : Optional [Union [Dict [str , Callable ]]] = None ,
@@ -54,10 +66,16 @@ def export(
5466 optionally run the export process in one step.
5567
5668 Args:
57- model: The PyTorch model(s) to export, either a single model or a dictionary
58- mapping method names to models
69+ model: The PyTorch model(s) to export. Can be:
70+ - nn.Module or Dict[str, nn.Module]: Eager PyTorch model(s)
71+ - GraphModule or Dict[str, GraphModule]: Quantized model(s) (e.g., from prepare/convert)
72+ - ExportedProgram or Dict[str, ExportedProgram]: Already exported model(s)
73+ - str: Path to load an ExportedProgram from disk
5974 example_inputs: Example inputs for the model(s), either a list of input tuples
60- or a dictionary mapping method names to lists of input tuples
75+ or a dictionary mapping method names to lists of input tuples.
76+ First sample (index 0) is used for torch.export.export() to export the model.
77+ All samples are used as calibration dataset in PT2E Quantize stage.
78+ Optional when model is ExportedProgram (not needed).
6179 export_recipe: Contains the configuration for the export process
6280 name: Optional name for the export
6381 dynamic_shapes: Optional dynamic shape specifications
@@ -99,11 +117,22 @@ class ExportSession:
99117
100118 def __init__ (
101119 self ,
102- model : Union [nn .Module , Dict [str , nn .Module ]],
103- example_inputs : Union [
104- List [tuple [torch .Tensor , ...]], Dict [str , List [tuple [torch .Tensor , ...]]]
120+ model : Union [
121+ nn .Module ,
122+ Dict [str , nn .Module ],
123+ GraphModule ,
124+ Dict [str , GraphModule ],
125+ ExportedProgram ,
126+ Dict [str , ExportedProgram ],
127+ str ,
105128 ],
106- export_recipe : ExportRecipe ,
129+ example_inputs : Optional [
130+ Union [
131+ List [tuple [torch .Tensor , ...]],
132+ Dict [str , List [tuple [torch .Tensor , ...]]],
133+ ]
134+ ] = None ,
135+ export_recipe : ExportRecipe = None ,
107136 name : Optional [str ] = None ,
108137 dynamic_shapes : Optional [Union [Any , Dict [str , Any ]]] = None ,
109138 constant_methods : Optional [Union [Dict [str , Callable ]]] = None ,
@@ -114,26 +143,52 @@ def __init__(
114143 Initialize the ExportSession with model, inputs, and recipe.
115144
116145 Args:
117- model: The PyTorch model(s) to export, either a single model or a dictionary
118- mapping method names to models
146+ model: The PyTorch model(s) to export. Can be:
147+ - nn.Module or Dict[str, nn.Module]: Eager PyTorch model(s)
148+ - GraphModule or Dict[str, GraphModule]: Quantized model(s)
149+ - ExportedProgram or Dict[str, ExportedProgram]: Already exported model(s)
150+ - str: Path to load an ExportedProgram from disk
119151 example_inputs: Example inputs for the model(s), either a list of input tuples
120- or a dictionary mapping method names to lists of input tuples
152+ or a dictionary mapping method names to lists of input tuples.
153+ First sample (index 0) is used for torch.export.export() to export the model.
154+ All samples are used as calibration dataset in PT2E Quantize stage,
155+ Optional when model is ExportedProgram (not needed).
121156 export_recipe: Contains the configuration for the export process
122157 name: Optional name for the export
123158 dynamic_shapes: Optional dynamic shape specifications
124159 constant_methods: Optional dictionary of constant methods
125160 artifact_dir: Optional directory to store artifacts
126161 generate_etrecord: Optional flag to generate an etrecord
127162 """
163+ # Load model from file if string path provided
164+ if isinstance (model , str ):
165+ model = torch .export .load (model )
166+ logging .info (f"Loaded ExportedProgram from { model } " )
167+
168+ # Detect input model type to determine which stages to skip
169+ self ._input_model_type = self ._detect_model_type (model )
170+
128171 # Standardize model to dictionary format
129172 self ._model = model if isinstance (model , dict ) else {"forward" : model }
130173
131- # Standardize example_inputs to dictionary format
132- self ._example_inputs = (
133- example_inputs
134- if isinstance (example_inputs , dict )
135- else {"forward" : example_inputs }
136- )
174+ # Validate and standardize example_inputs to dictionary format
175+ # example_inputs not required for ExportedProgram input
176+ if self ._input_model_type == "ExportedProgram" :
177+ self ._example_inputs = example_inputs or {}
178+ if isinstance (self ._example_inputs , list ):
179+ self ._example_inputs = {"forward" : self ._example_inputs }
180+ else :
181+ # For nn.Module and GraphModule, example_inputs are required
182+ if example_inputs is None :
183+ raise ValueError (
184+ f"example_inputs are required when model is { self ._input_model_type } . "
185+ f"Only ExportedProgram inputs can omit example_inputs."
186+ )
187+ self ._example_inputs = (
188+ example_inputs
189+ if isinstance (example_inputs , dict )
190+ else {"forward" : example_inputs }
191+ )
137192
138193 # Standardize dynamic_shapes to dictionary format
139194 self ._dynamic_shapes = {}
@@ -176,14 +231,64 @@ def __init__(
176231
177232 self ._stage_to_artifacts : Dict [StageType , PipelineArtifact ] = {}
178233
234+ def _detect_model_type (
235+ self , model : Union [nn .Module , GraphModule , ExportedProgram , Dict ]
236+ ) -> str :
237+ """
238+ Detect the type of input model.
239+
240+ Args:
241+ model: Input model in various formats
242+
243+ Returns:
244+ String indicating the model type: "nn.Module", "GraphModule", or "ExportedProgram"
245+ """
246+ # Handle dict (multi-method) - check first value
247+ if isinstance (model , dict ):
248+ first_value = next (iter (model .values ()))
249+ return self ._detect_model_type (first_value )
250+
251+ # Detect single model type
252+ if isinstance (model , ExportedProgram ):
253+ return "ExportedProgram"
254+ elif isinstance (model , GraphModule ):
255+ return "GraphModule"
256+ elif isinstance (model , nn .Module ):
257+ return "nn.Module"
258+ else :
259+ raise TypeError (f"Unsupported model type: { type (model )} " )
260+
179261 def _get_default_pipeline (self ) -> List [StageType ]:
180- return [
181- StageType .SOURCE_TRANSFORM , # Optional stage, returns original model if quant recipe is invalid
182- StageType .QUANTIZE , # Optional stage, returns original model if quant recipe is invalid
183- StageType .TORCH_EXPORT ,
184- StageType .TO_EDGE_TRANSFORM_AND_LOWER ,
185- StageType .TO_EXECUTORCH ,
186- ]
262+ """
263+ Get default pipeline stages based on input model type.
264+
265+ Returns:
266+ List of stages appropriate for the input model type
267+ """
268+ stages = []
269+
270+ # Add quantization stages only for eager nn.Module
271+ if self ._input_model_type == "nn.Module" :
272+ stages .extend (
273+ [
274+ StageType .SOURCE_TRANSFORM , # Optional stage, returns original model if quant recipe is invalid
275+ StageType .QUANTIZE , # Optional stage, returns original model if quant recipe is invalid
276+ ]
277+ )
278+
279+ # Add torch export stage if not already exported
280+ if self ._input_model_type != "ExportedProgram" :
281+ stages .append (StageType .TORCH_EXPORT )
282+
283+ # Always include edge and executorch stages
284+ stages .extend (
285+ [
286+ StageType .TO_EDGE_TRANSFORM_AND_LOWER ,
287+ StageType .TO_EXECUTORCH ,
288+ ]
289+ )
290+
291+ return stages
187292
188293 def _build_stages (self , stages : List [StageType ]) -> Dict [StageType , Stage ]:
189294 """Build the stage registry from the given stages."""
@@ -259,6 +364,34 @@ def _validate_pipeline_sequence(
259364 if not stages :
260365 raise ValueError ("Pipeline stages cannot be empty" )
261366
367+ # Validate pipeline compatibility with input model type
368+ if self ._input_model_type == "GraphModule" :
369+ # GraphModule input should not run quantization stages
370+ incompatible_stages = {StageType .SOURCE_TRANSFORM , StageType .QUANTIZE }
371+ found_incompatible = set (stages ) & incompatible_stages
372+ if found_incompatible :
373+ stage_names = ", " .join (s .name for s in found_incompatible )
374+ raise ValueError (
375+ f"Cannot run { stage_names } stage(s) with GraphModule input. "
376+ f"GraphModule is already quantized. "
377+ f"Remove { stage_names } from pipeline_stages or use nn.Module input."
378+ )
379+ elif self ._input_model_type == "ExportedProgram" :
380+ # ExportedProgram input should not run quantization or torch export stages
381+ incompatible_stages = {
382+ StageType .SOURCE_TRANSFORM ,
383+ StageType .QUANTIZE ,
384+ StageType .TORCH_EXPORT ,
385+ }
386+ found_incompatible = set (stages ) & incompatible_stages
387+ if found_incompatible :
388+ stage_names = ", " .join (s .name for s in found_incompatible )
389+ raise ValueError (
390+ f"Cannot run { stage_names } stage(s) with ExportedProgram input. "
391+ f"ExportedProgram is already exported. "
392+ f"Remove { stage_names } from pipeline_stages or use nn.Module/GraphModule input."
393+ )
394+
262395 # Validate that the first stage can start a pipeline
263396 first_stage = stages [0 ]
264397 first_stage_instance = self ._stage_registry .get (first_stage )
@@ -436,6 +569,9 @@ def run_method(
436569 Raises:
437570 RuntimeError: If the method cannot be loaded
438571 """
572+ # Lazy import to avoid forcing portable_lib dependency at module load time
573+ from executorch .runtime import Runtime , Verification
574+
439575 et_runtime = Runtime .get ()
440576 program = et_runtime .load_program (
441577 self .get_pte_buffer (), verification = Verification .Minimal
0 commit comments