diff --git a/export/export.py b/export/export.py index 1e9cdbde7c0..226a3a06eda 100644 --- a/export/export.py +++ b/export/export.py @@ -13,9 +13,10 @@ from executorch.exir.program import ExecutorchProgramManager from executorch.exir.schema import Program from executorch.extension.export_util.utils import save_pte_program -from executorch.runtime import Runtime, Verification from tabulate import tabulate from torch import nn +from torch.export import ExportedProgram +from torch.fx import GraphModule from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe from .stages import ( @@ -36,11 +37,22 @@ "This API and all of its related functionality such as ExportSession and ExportRecipe are experimental." ) def export( - model: Union[nn.Module, Dict[str, nn.Module]], - example_inputs: Union[ - List[tuple[torch.Tensor, ...]], Dict[str, List[tuple[torch.Tensor, ...]]] + model: Union[ + nn.Module, + Dict[str, nn.Module], + GraphModule, + Dict[str, GraphModule], + ExportedProgram, + Dict[str, ExportedProgram], + str, ], - export_recipe: ExportRecipe, + example_inputs: Optional[ + Union[ + List[tuple[torch.Tensor, ...]], + Dict[str, List[tuple[torch.Tensor, ...]]], + ] + ] = None, + export_recipe: Optional[ExportRecipe] = None, name: Optional[str] = None, dynamic_shapes: Optional[Union[Any, Dict[str, Any]]] = None, constant_methods: Optional[Union[Dict[str, Callable]]] = None, @@ -54,10 +66,16 @@ def export( optionally run the export process in one step. Args: - model: The PyTorch model(s) to export, either a single model or a dictionary - mapping method names to models + model: The PyTorch model(s) to export. 
Can be: + - nn.Module or Dict[str, nn.Module]: Eager PyTorch model(s) + - GraphModule or Dict[str, GraphModule]: Quantized model(s) (e.g., from prepare/convert) + - ExportedProgram or Dict[str, ExportedProgram]: Already exported model(s) + - str: Path to load an ExportedProgram from disk example_inputs: Example inputs for the model(s), either a list of input tuples - or a dictionary mapping method names to lists of input tuples + or a dictionary mapping method names to lists of input tuples. + First sample (index 0) is used for torch.export.export() to export the model. + All samples are used as calibration dataset in PT2E Quantize stage. + Optional when model is ExportedProgram (not needed). export_recipe: Contains the configuration for the export process name: Optional name for the export dynamic_shapes: Optional dynamic shape specifications @@ -99,11 +117,22 @@ class ExportSession: def __init__( self, - model: Union[nn.Module, Dict[str, nn.Module]], - example_inputs: Union[ - List[tuple[torch.Tensor, ...]], Dict[str, List[tuple[torch.Tensor, ...]]] + model: Union[ + nn.Module, + Dict[str, nn.Module], + GraphModule, + Dict[str, GraphModule], + ExportedProgram, + Dict[str, ExportedProgram], + str, ], - export_recipe: ExportRecipe, + example_inputs: Optional[ + Union[ + List[tuple[torch.Tensor, ...]], + Dict[str, List[tuple[torch.Tensor, ...]]], + ] + ] = None, + export_recipe: Optional[ExportRecipe] = None, name: Optional[str] = None, dynamic_shapes: Optional[Union[Any, Dict[str, Any]]] = None, constant_methods: Optional[Union[Dict[str, Callable]]] = None, @@ -114,10 +143,16 @@ def __init__( Initialize the ExportSession with model, inputs, and recipe. Args: - model: The PyTorch model(s) to export, either a single model or a dictionary - mapping method names to models + model: The PyTorch model(s) to export. 
Can be: + - nn.Module or Dict[str, nn.Module]: Eager PyTorch model(s) + - GraphModule or Dict[str, GraphModule]: Quantized model(s) + - ExportedProgram or Dict[str, ExportedProgram]: Already exported model(s) + - str: Path to load an ExportedProgram from disk example_inputs: Example inputs for the model(s), either a list of input tuples - or a dictionary mapping method names to lists of input tuples + or a dictionary mapping method names to lists of input tuples. + First sample (index 0) is used for torch.export.export() to export the model. + All samples are used as calibration dataset in PT2E Quantize stage. + Optional when model is ExportedProgram (not needed). export_recipe: Contains the configuration for the export process name: Optional name for the export dynamic_shapes: Optional dynamic shape specifications @@ -125,15 +160,35 @@ def __init__( artifact_dir: Optional directory to store artifacts generate_etrecord: Optional flag to generate an etrecord """ + # Load model from file if string path provided (log the path before `model` is rebound to the loaded program) + if isinstance(model, str): + logging.info(f"Loading ExportedProgram from {model}") + model = torch.export.load(model) + + # Detect input model type to determine which stages to skip + self._input_model_type = self._detect_model_type(model) + # Standardize model to dictionary format self._model = model if isinstance(model, dict) else {"forward": model} - # Standardize example_inputs to dictionary format - self._example_inputs = ( - example_inputs - if isinstance(example_inputs, dict) - else {"forward": example_inputs} - ) + # Validate and standardize example_inputs to dictionary format + # example_inputs not required for ExportedProgram input + if self._input_model_type == "ExportedProgram": + self._example_inputs = example_inputs or {} + if isinstance(self._example_inputs, list): + self._example_inputs = {"forward": self._example_inputs} + else: + # For nn.Module and GraphModule, example_inputs are required + if example_inputs is None: + raise ValueError( + 
f"example_inputs are required when model is {self._input_model_type}. " + f"Only ExportedProgram inputs can omit example_inputs." + ) + self._example_inputs = ( + example_inputs + if isinstance(example_inputs, dict) + else {"forward": example_inputs} + ) # Standardize dynamic_shapes to dictionary format self._dynamic_shapes = {} @@ -176,14 +231,64 @@ def __init__( self._stage_to_artifacts: Dict[StageType, PipelineArtifact] = {} + def _detect_model_type( + self, model: Union[nn.Module, GraphModule, ExportedProgram, Dict] + ) -> str: + """ + Detect the type of input model. + + Args: + model: Input model in various formats + + Returns: + String indicating the model type: "nn.Module", "GraphModule", or "ExportedProgram" + """ + # Handle dict (multi-method) - check first value; an empty dict yields None and is rejected with TypeError + if isinstance(model, dict): + first_value = next(iter(model.values()), None) + return self._detect_model_type(first_value) + + # Detect single model type + if isinstance(model, ExportedProgram): + return "ExportedProgram" + elif isinstance(model, GraphModule): + return "GraphModule" + elif isinstance(model, nn.Module): + return "nn.Module" + else: + raise TypeError(f"Unsupported model type: {type(model)}") + def _get_default_pipeline(self) -> List[StageType]: - return [ - StageType.SOURCE_TRANSFORM, # Optional stage, returns original model if quant recipe is invalid - StageType.QUANTIZE, # Optional stage, returns original model if quant recipe is invalid - StageType.TORCH_EXPORT, - StageType.TO_EDGE_TRANSFORM_AND_LOWER, - StageType.TO_EXECUTORCH, - ] + """ + Get default pipeline stages based on input model type. 
+ + Returns: + List of stages appropriate for the input model type + """ + stages = [] + + # Add quantization stages only for eager nn.Module + if self._input_model_type == "nn.Module": + stages.extend( + [ + StageType.SOURCE_TRANSFORM, # Optional stage, returns original model if quant recipe is invalid + StageType.QUANTIZE, # Optional stage, returns original model if quant recipe is invalid + ] + ) + + # Add torch export stage if not already exported + if self._input_model_type != "ExportedProgram": + stages.append(StageType.TORCH_EXPORT) + + # Always include edge and executorch stages + stages.extend( + [ + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ] + ) + + return stages def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]: """Build the stage registry from the given stages.""" @@ -259,6 +364,34 @@ def _validate_pipeline_sequence( if not stages: raise ValueError("Pipeline stages cannot be empty") + # Validate pipeline compatibility with input model type + if self._input_model_type == "GraphModule": + # GraphModule input should not run quantization stages + incompatible_stages = {StageType.SOURCE_TRANSFORM, StageType.QUANTIZE} + found_incompatible = set(stages) & incompatible_stages + if found_incompatible: + stage_names = ", ".join(sorted(s.name for s in found_incompatible)) + raise ValueError( + f"Cannot run {stage_names} stage(s) with GraphModule input. " + f"GraphModule is already quantized. " + f"Remove {stage_names} from pipeline_stages or use nn.Module input." 
+ ) + elif self._input_model_type == "ExportedProgram": + # ExportedProgram input should not run quantization or torch export stages + incompatible_stages = { + StageType.SOURCE_TRANSFORM, + StageType.QUANTIZE, + StageType.TORCH_EXPORT, + } + found_incompatible = set(stages) & incompatible_stages + if found_incompatible: + stage_names = ", ".join(sorted(s.name for s in found_incompatible)) + raise ValueError( + f"Cannot run {stage_names} stage(s) with ExportedProgram input. " + f"ExportedProgram is already exported. " + f"Remove {stage_names} from pipeline_stages or use nn.Module/GraphModule input." + ) + # Validate that the first stage can start a pipeline first_stage = stages[0] first_stage_instance = self._stage_registry.get(first_stage) @@ -436,6 +569,9 @@ def run_method( Raises: RuntimeError: If the method cannot be loaded """ + # Lazy import to avoid forcing portable_lib dependency at module load time + from executorch.runtime import Runtime, Verification + et_runtime = Runtime.get() program = et_runtime.load_program( self.get_pte_buffer(), verification=Verification.Minimal diff --git a/export/stages.py b/export/stages.py index 3be801c6a14..051f6c64ec0 100644 --- a/export/stages.py +++ b/export/stages.py @@ -179,6 +179,7 @@ def __init__( ) = None, compile_config: Optional[Any] = None, ) -> None: + super().__init__() self._partitioners = partitioners self._transform_passes = transform_passes self._compile_config = compile_config @@ -206,7 +207,7 @@ def valid_predecessor_stages(self) -> List["StageType"]: @property def can_start_pipeline(self) -> bool: - return False + return True def run(self, artifact: PipelineArtifact) -> None: """ @@ -271,7 +272,10 @@ def stage_type(self) -> str: @property def valid_predecessor_stages(self) -> List["StageType"]: - return [StageType.TO_EDGE_TRANSFORM_AND_LOWER, StageType.TO_BACKEND] + return [ + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_BACKEND, + ] @property def can_start_pipeline(self) -> bool: @@ -468,7 +472,7 @@ def 
valid_predecessor_stages(self) -> List["StageType"]: @property def can_start_pipeline(self) -> bool: - return False + return True def run(self, artifact: PipelineArtifact) -> None: """ diff --git a/export/tests/test_export_session.py b/export/tests/test_export_session.py index fcec1b7a59a..d28c369eaa6 100644 --- a/export/tests/test_export_session.py +++ b/export/tests/test_export_session.py @@ -61,10 +61,11 @@ def _create_mock_stage(self, stage_type: StageType) -> Mock: mock_stage.can_start_pipeline = True elif stage_type == StageType.TO_EDGE_TRANSFORM_AND_LOWER: mock_stage.valid_predecessor_stages = [StageType.TORCH_EXPORT] - mock_stage.can_start_pipeline = False + mock_stage.can_start_pipeline = True elif stage_type == StageType.TO_EXECUTORCH: mock_stage.valid_predecessor_stages = [ - StageType.TO_EDGE_TRANSFORM_AND_LOWER + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_BACKEND, ] mock_stage.can_start_pipeline = True else: @@ -279,7 +280,7 @@ def test_valid_pipeline_sequences(self) -> None: StageType.TO_EDGE_TRANSFORM_AND_LOWER, StageType.TO_EXECUTORCH, ], - # Skip source transform and tart with quantize + # Skip source transform and start with quantize [ StageType.QUANTIZE, StageType.TORCH_EXPORT, @@ -292,6 +293,17 @@ def test_valid_pipeline_sequences(self) -> None: StageType.TO_EDGE_TRANSFORM_AND_LOWER, StageType.TO_EXECUTORCH, ], + # Start with edge transform and lower (ExportedProgram input) + [ + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ], + # Start with to_edge and to_backend + [ + StageType.TO_EDGE, + StageType.TO_BACKEND, + StageType.TO_EXECUTORCH, + ], ] for i, stages in enumerate(valid_sequences): @@ -306,9 +318,11 @@ def test_valid_pipeline_sequences(self) -> None: def test_invalid_pipeline_start_stages(self) -> None: """Test stages that cannot start a pipeline.""" invalid_stage_sequence = [ - # Edge stage cannot start pipeline - [StageType.TO_EDGE_TRANSFORM_AND_LOWER], - [StageType.TO_EDGE_TRANSFORM_AND_LOWER, 
StageType.TO_EXECUTORCH], + # Executorch stage cannot start pipeline (requires edge stage first) + [StageType.TO_EXECUTORCH], + # Backend stage cannot start pipeline (requires TO_EDGE first) + [StageType.TO_BACKEND], + [StageType.TO_BACKEND, StageType.TO_EXECUTORCH], ] for i, stages in enumerate(invalid_stage_sequence): @@ -485,3 +499,321 @@ def test_pipeline_building_with_all_recipes(self) -> None: StageType.TO_EXECUTORCH, ] self.assertListEqual(list(registered_stages.keys()), expected_types) + + +class TestExportSessionExtendedInputTypes(unittest.TestCase): + """Test extended input type support (GraphModule, ExportedProgram, etc.)""" + + def setUp(self) -> None: + self.model = SimpleTestModel() + self.example_inputs = (torch.randn(2, 10),) + self.recipe = ExportRecipe(name="test") + + def test_nn_module_input_type_detection(self) -> None: + """Test that nn.Module input is detected correctly.""" + session = ExportSession( + model=self.model, + example_inputs=[self.example_inputs], + export_recipe=self.recipe, + ) + + self.assertEqual(session._input_model_type, "nn.Module") + + # Verify default pipeline includes quantization stages + pipeline = session._get_default_pipeline() + self.assertIn(StageType.SOURCE_TRANSFORM, pipeline) + self.assertIn(StageType.QUANTIZE, pipeline) + self.assertIn(StageType.TORCH_EXPORT, pipeline) + self.assertIn(StageType.TO_EDGE_TRANSFORM_AND_LOWER, pipeline) + self.assertIn(StageType.TO_EXECUTORCH, pipeline) + + def test_graph_module_input_type_detection(self) -> None: + """Test that GraphModule input is detected correctly.""" + # Create a GraphModule using fx.symbolic_trace + graph_module = torch.fx.symbolic_trace(self.model) + + session = ExportSession( + model=graph_module, + example_inputs=[self.example_inputs], + export_recipe=self.recipe, + ) + + self.assertEqual(session._input_model_type, "GraphModule") + + # Verify default pipeline skips quantization stages + pipeline = session._get_default_pipeline() + 
self.assertNotIn(StageType.SOURCE_TRANSFORM, pipeline) + self.assertNotIn(StageType.QUANTIZE, pipeline) + self.assertIn(StageType.TORCH_EXPORT, pipeline) + self.assertIn(StageType.TO_EDGE_TRANSFORM_AND_LOWER, pipeline) + self.assertIn(StageType.TO_EXECUTORCH, pipeline) + + def test_exported_program_input_type_detection(self) -> None: + """Test that ExportedProgram input is detected correctly.""" + # Create an ExportedProgram + exported_program = torch.export.export(self.model, self.example_inputs) + + # ExportedProgram should not require example_inputs + session = ExportSession( + model=exported_program, + export_recipe=self.recipe, + ) + + self.assertEqual(session._input_model_type, "ExportedProgram") + + # Verify default pipeline skips quantization and torch export stages + pipeline = session._get_default_pipeline() + self.assertNotIn(StageType.SOURCE_TRANSFORM, pipeline) + self.assertNotIn(StageType.QUANTIZE, pipeline) + self.assertNotIn(StageType.TORCH_EXPORT, pipeline) + self.assertIn(StageType.TO_EDGE_TRANSFORM_AND_LOWER, pipeline) + self.assertIn(StageType.TO_EXECUTORCH, pipeline) + + def test_dict_nn_module_input_type_detection(self) -> None: + """Test that Dict[str, nn.Module] input is detected correctly.""" + model_dict = { + "forward": self.model, + "method2": SimpleTestModel(), + } + inputs_dict = { + "forward": [self.example_inputs], + "method2": [(torch.randn(1, 10),)], + } + + session = ExportSession( + model=model_dict, + example_inputs=inputs_dict, + export_recipe=self.recipe, + ) + + # Should detect type based on first value + self.assertEqual(session._input_model_type, "nn.Module") + + def test_dict_graph_module_input_type_detection(self) -> None: + """Test that Dict[str, GraphModule] input is detected correctly.""" + graph_module1 = torch.fx.symbolic_trace(self.model) + graph_module2 = torch.fx.symbolic_trace(SimpleTestModel()) + + model_dict = { + "forward": graph_module1, + "method2": graph_module2, + } + inputs_dict = { + "forward": 
[self.example_inputs], + "method2": [(torch.randn(1, 10),)], + } + + session = ExportSession( + model=model_dict, + example_inputs=inputs_dict, + export_recipe=self.recipe, + ) + + # Should detect GraphModule type + self.assertEqual(session._input_model_type, "GraphModule") + + # Verify pipeline skips quantization + pipeline = session._get_default_pipeline() + self.assertNotIn(StageType.QUANTIZE, pipeline) + + def test_dict_exported_program_input_type_detection(self) -> None: + """Test that Dict[str, ExportedProgram] input is detected correctly.""" + ep1 = torch.export.export(self.model, self.example_inputs) + ep2 = torch.export.export(SimpleTestModel(), (torch.randn(1, 10),)) + + model_dict = { + "forward": ep1, + "method2": ep2, + } + + session = ExportSession( + model=model_dict, + export_recipe=self.recipe, + ) + + # Should detect ExportedProgram type + self.assertEqual(session._input_model_type, "ExportedProgram") + + # Verify pipeline skips export stages + pipeline = session._get_default_pipeline() + self.assertNotIn(StageType.TORCH_EXPORT, pipeline) + + def test_example_inputs_required_for_nn_module(self) -> None: + """Test that example_inputs are required for nn.Module.""" + with self.assertRaises(ValueError) as cm: + ExportSession( + model=self.model, + export_recipe=self.recipe, + ) + self.assertIn("example_inputs are required", str(cm.exception)) + self.assertIn("nn.Module", str(cm.exception)) + + def test_example_inputs_required_for_graph_module(self) -> None: + """Test that example_inputs are required for GraphModule.""" + graph_module = torch.fx.symbolic_trace(self.model) + + with self.assertRaises(ValueError) as cm: + ExportSession( + model=graph_module, + export_recipe=self.recipe, + ) + self.assertIn("example_inputs are required", str(cm.exception)) + self.assertIn("GraphModule", str(cm.exception)) + + def test_example_inputs_optional_for_exported_program(self) -> None: + """Test that example_inputs are optional for ExportedProgram.""" + 
exported_program = torch.export.export(self.model, self.example_inputs) + + # Should not raise + session = ExportSession( + model=exported_program, + export_recipe=self.recipe, + ) + + self.assertEqual(session._input_model_type, "ExportedProgram") + + def test_validation_graph_module_cannot_run_quantization(self) -> None: + """Test that GraphModule input cannot run quantization stages.""" + graph_module = torch.fx.symbolic_trace(self.model) + + # Try to force quantization stages + recipe = ExportRecipe( + pipeline_stages=[ + StageType.QUANTIZE, + StageType.TORCH_EXPORT, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ] + ) + + session = ExportSession( + model=graph_module, + example_inputs=[self.example_inputs], + export_recipe=recipe, + ) + + with self.assertRaises(ValueError) as cm: + session.export() + self.assertIn("Cannot run", str(cm.exception)) + self.assertIn("stage(s)", str(cm.exception)) + self.assertIn("QUANTIZE", str(cm.exception)) + self.assertIn("GraphModule", str(cm.exception)) + + def test_validation_graph_module_cannot_run_source_transform(self) -> None: + """Test that GraphModule input cannot run source transform stage.""" + graph_module = torch.fx.symbolic_trace(self.model) + + # Try to force source transform stage + recipe = ExportRecipe( + pipeline_stages=[ + StageType.SOURCE_TRANSFORM, + StageType.TORCH_EXPORT, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ] + ) + + session = ExportSession( + model=graph_module, + example_inputs=[self.example_inputs], + export_recipe=recipe, + ) + + with self.assertRaises(ValueError) as cm: + session.export() + self.assertIn("Cannot run", str(cm.exception)) + self.assertIn("stage(s)", str(cm.exception)) + self.assertIn("SOURCE_TRANSFORM", str(cm.exception)) + self.assertIn("GraphModule", str(cm.exception)) + + def test_validation_exported_program_cannot_run_torch_export(self) -> None: + """Test that ExportedProgram input cannot run torch export stage.""" + 
exported_program = torch.export.export(self.model, self.example_inputs) + + # Try to force torch export stage + recipe = ExportRecipe( + pipeline_stages=[ + StageType.TORCH_EXPORT, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ] + ) + + session = ExportSession( + model=exported_program, + export_recipe=recipe, + ) + + with self.assertRaises(ValueError) as cm: + session.export() + self.assertIn("Cannot run", str(cm.exception)) + self.assertIn("stage(s)", str(cm.exception)) + self.assertIn("TORCH_EXPORT", str(cm.exception)) + self.assertIn("ExportedProgram", str(cm.exception)) + + def test_validation_exported_program_cannot_run_quantization(self) -> None: + """Test that ExportedProgram input cannot run quantization stages.""" + exported_program = torch.export.export(self.model, self.example_inputs) + + # Try to force quantization stages + recipe = ExportRecipe( + pipeline_stages=[ + StageType.QUANTIZE, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ] + ) + + session = ExportSession( + model=exported_program, + export_recipe=recipe, + ) + + with self.assertRaises(ValueError) as cm: + session.export() + self.assertIn("Cannot run", str(cm.exception)) + self.assertIn("stage(s)", str(cm.exception)) + self.assertIn("QUANTIZE", str(cm.exception)) + self.assertIn("ExportedProgram", str(cm.exception)) + + def test_graph_module_valid_pipeline(self) -> None: + """Test valid pipeline for GraphModule input.""" + graph_module = torch.fx.symbolic_trace(self.model) + + # Valid pipeline starting from torch export + recipe = ExportRecipe( + pipeline_stages=[ + StageType.TORCH_EXPORT, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ] + ) + + session = ExportSession( + model=graph_module, + example_inputs=[self.example_inputs], + export_recipe=recipe, + ) + + # Should not raise during validation + session._validate_pipeline_sequence(recipe.pipeline_stages) + + def test_exported_program_valid_pipeline(self) -> None: + 
"""Test valid pipeline for ExportedProgram input.""" + exported_program = torch.export.export(self.model, self.example_inputs) + + # Valid pipeline starting from edge stages + recipe = ExportRecipe( + pipeline_stages=[ + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ] + ) + + session = ExportSession( + model=exported_program, + export_recipe=recipe, + ) + + # Should not raise during validation + session._validate_pipeline_sequence(recipe.pipeline_stages)