|
13 | 13 |
|
14 | 14 | import torch |
15 | 15 | from executorch.devtools.backend_debug import get_delegation_info |
16 | | -from executorch.exir import EdgeCompileConfig, ExportedProgram |
| 16 | +from executorch.exir import EdgeCompileConfig, EdgeProgramManager, ExportedProgram |
17 | 17 | from executorch.exir.backend.backend_api import validation_disabled |
| 18 | +from executorch.exir.pass_manager import PassManager |
18 | 19 | from executorch.exir.program import to_edge, to_edge_transform_and_lower |
19 | 20 | from executorch.export.recipe import LoweringRecipe, QuantizationRecipe |
20 | 21 | from executorch.export.types import StageType |
@@ -175,7 +176,7 @@ def __init__( |
175 | 176 | self, |
176 | 177 | partitioners: Optional[List[Any]] = None, |
177 | 178 | transform_passes: ( |
178 | | - None | List[Callable[[str, ExportedProgram], List[PassType]]] |
| 179 | + None | List[Callable[[str, ExportedProgram], List[PassType] | PassManager]] |
179 | 180 | ) = None, |
180 | 181 | compile_config: Optional[Any] = None, |
181 | 182 | ) -> None: |
@@ -217,28 +218,33 @@ def run(self, artifact: PipelineArtifact) -> None: |
217 | 218 | constant_methods = artifact.get_context("constant_methods") |
218 | 219 | generate_etrecord = artifact.get_context("generate_etrecord", False) |
219 | 220 |
|
220 | | - # per method transform passes |
| 221 | + # Detect if any callable returns PassManager |
| 222 | + pass_manager = None |
221 | 223 | transform_passes = defaultdict(list) |
222 | 224 | for method_name, ep in exported_programs.items(): |
223 | 225 | # Resolve transform passes from callable |
224 | | - for pass_ in self._transform_passes or []: |
225 | | - if not callable(pass_): |
| 226 | + for pass_callable in self._transform_passes or []: |
| 227 | + if not callable(pass_callable): |
226 | 228 | raise ValueError( |
227 | | - "Transform passes must be a callable that resolves to a list of passes" |
| 229 | + "Transform passes must be a callable that resolves to passes" |
228 | 230 | ) |
229 | | - passes = pass_(method_name, ep) |
230 | | - if isinstance(passes, list): |
231 | | - transform_passes[method_name].extend(passes) |
| 231 | + passes = pass_callable(method_name, ep) |
| 232 | + if isinstance(passes, PassManager): |
| 233 | + pass_manager = passes |
| 234 | + break |
232 | 235 | else: |
233 | | - raise ValueError( |
234 | | - "Transform passes must be a callable that resolves to a list of passes" |
235 | | - ) |
| 236 | + transform_passes[method_name].extend(passes) |
| 237 | + if pass_manager: |
| 238 | + break |
| 239 | + |
| 240 | + # Use PassManager directly if found, otherwise use dict |
| 241 | + final_passes = pass_manager if pass_manager else transform_passes |
236 | 242 |
|
237 | 243 | with validation_disabled(): |
238 | 244 | edge_program_manager = to_edge_transform_and_lower( |
239 | 245 | exported_programs, |
240 | 246 | partitioner=self._partitioners, |
241 | | - transform_passes=transform_passes, |
| 247 | + transform_passes=final_passes, |
242 | 248 | constant_methods=constant_methods, |
243 | 249 | compile_config=self._compile_config, |
244 | 250 | generate_etrecord=generate_etrecord, |
@@ -275,6 +281,7 @@ def valid_predecessor_stages(self) -> List["StageType"]: |
275 | 281 | return [ |
276 | 282 | StageType.TO_EDGE_TRANSFORM_AND_LOWER, |
277 | 283 | StageType.TO_BACKEND, |
| 284 | + StageType.EDGE_PROGRAM_MANAGER_TRANSFORM, # Added for server model generation (skipping TO_BACKEND) |
278 | 285 | ] |
279 | 286 |
|
280 | 287 | @property |
@@ -495,76 +502,166 @@ def run(self, artifact: PipelineArtifact) -> None: |
495 | 502 | self._artifact = artifact.copy_with_new_data(edge_program_manager) |
496 | 503 |
|
497 | 504 |
|
498 | | -class ToBackendStage(Stage): |
| 505 | +class EdgeProgramManagerTransformStage(Stage): |
499 | 506 | """ |
500 | | - Stage: Apply transformations and partitioning to EdgeProgramManager. |
| 507 | + Stage: Apply transformation passes that require EdgeProgramManager. |
| 508 | +
|
| 509 | + This stage enables dynamic pass generation where passes need access to the |
| 510 | + EdgeProgramManager instance. Passes are applied sequentially, allowing |
| 511 | + to control order and dependencies between pass groups. |
501 | 512 | """ |
502 | 513 |
|
503 | 514 | def __init__( |
504 | 515 | self, |
505 | | - partitioners: Optional[List[Any]] = None, |
506 | | - transform_passes: ( |
507 | | - None | List[Callable[[str, ExportedProgram], List[PassType]]] |
| 516 | + edge_transform_passes: ( |
| 517 | + None | List[Callable[[str, ExportedProgram], List[PassType] | PassManager]] |
| 518 | + ) = None, |
| 519 | + edge_manager_transform_passes: ( |
| 520 | + None | List[Callable[[EdgeProgramManager], List[PassType] | PassManager]] |
508 | 521 | ) = None, |
509 | 522 | ) -> None: |
| 523 | + """ |
| 524 | + Initialize the EdgeProgramManagerTransformStage. |
| 525 | +
|
| 526 | + Args: |
| 527 | + edge_manager_transform_passes: List of callables that take EdgeProgramManager |
| 528 | + and return either List[PassType] or PassManager. |
| 529 | + Each callable is applied sequentially, allowing |
| 530 | + backends to control pass ordering and dependencies. |
| 531 | + """ |
510 | 532 | super().__init__() |
511 | | - self._partitioners = partitioners |
512 | | - self._transform_passes = transform_passes |
| 533 | + self._edge_transform_passes = edge_transform_passes or [] |
| 534 | + self._edge_manager_transform_passes = edge_manager_transform_passes or [] |
513 | 535 |
|
514 | 536 | @classmethod |
515 | 537 | def from_recipe( |
516 | | - cls, lowering_recipe: Optional["LoweringRecipe"] |
517 | | - ) -> "ToBackendStage": |
| 538 | + cls, lowering_recipe: Optional[LoweringRecipe] |
| 539 | + ) -> "EdgeProgramManagerTransformStage": |
518 | 540 | if lowering_recipe is None: |
519 | 541 | return cls() |
520 | 542 |
|
521 | 543 | return cls( |
522 | | - partitioners=lowering_recipe.partitioners, |
523 | | - transform_passes=lowering_recipe.edge_transform_passes, |
| 544 | + edge_transform_passes=lowering_recipe.edge_transform_passes, |
| 545 | + edge_manager_transform_passes=lowering_recipe.edge_manager_transform_passes, |
524 | 546 | ) |
525 | 547 |
|
526 | 548 | @property |
527 | 549 | def stage_type(self) -> str: |
528 | | - return StageType.TO_BACKEND |
| 550 | + return StageType.EDGE_PROGRAM_MANAGER_TRANSFORM |
529 | 551 |
|
530 | 552 | @property |
531 | 553 | def valid_predecessor_stages(self) -> List["StageType"]: |
532 | | - return [StageType.TO_EDGE] |
| 554 | + return [ |
| 555 | + StageType.TO_EDGE, |
| 556 | + # StageType.TO_EDGE_TRANSFORM_AND_LOWER, # TODO |
| 557 | + ] |
533 | 558 |
|
534 | 559 | @property |
535 | 560 | def can_start_pipeline(self) -> bool: |
536 | 561 | return False |
537 | 562 |
|
538 | 563 | def run(self, artifact: PipelineArtifact) -> None: |
539 | 564 | """ |
540 | | - Apply transformations and partitioning to EdgeProgramManager. |
| 565 | + Apply transformation passes sequentially. |
541 | 566 |
|
542 | 567 | Args: |
543 | | - artifact: Contains edge program manager and context |
| 568 | + artifact: Pipeline artifact containing EdgeProgramManager |
544 | 569 | """ |
545 | 570 | edge_program_manager = artifact.data |
546 | 571 |
|
547 | | - if edge_program_manager is None: |
548 | | - raise RuntimeError("Edge program manager is not set.") |
| 572 | + if not isinstance(edge_program_manager, EdgeProgramManager): |
| 573 | + raise TypeError( |
| 574 | + f"Expected EdgeProgramManager but got {type(edge_program_manager)}" |
| 575 | + ) |
549 | 576 |
|
550 | | - # per method transform passes |
| 577 | + if not self._edge_transform_passes and not self._edge_manager_transform_passes: |
| 578 | + self._artifact = artifact |
| 579 | + return |
| 580 | + |
| 581 | + # Detect if any callable returns PassManager |
| 582 | + pass_manager = None |
551 | 583 | transform_passes = defaultdict(list) |
552 | 584 | for method_name in edge_program_manager.methods: |
553 | 585 | # Resolve transform passes if it's a callable |
554 | 586 | ep = edge_program_manager.exported_program(method_name) |
555 | | - for pass_ in self._transform_passes or []: |
556 | | - if not callable(pass_): |
| 587 | + for pass_callable in self._edge_transform_passes or []: |
| 588 | + if not callable(pass_callable): |
557 | 589 | raise ValueError( |
558 | | - "Transform passes must be a callable that resolves to a list of passes" |
| 590 | + "Transform passes must be a callable that resolves to passes" |
559 | 591 | ) |
560 | | - passes = pass_(method_name, ep) |
561 | | - if isinstance(passes, list): |
562 | | - transform_passes[method_name].extend(passes) |
| 592 | + passes = pass_callable(method_name, ep) |
| 593 | + if isinstance(passes, PassManager): |
| 594 | + pass_manager = passes |
| 595 | + break |
563 | 596 | else: |
564 | | - raise ValueError("Transform passes must return list of passes") |
| 597 | + transform_passes[method_name].extend(passes) |
| 598 | + if pass_manager: |
| 599 | + break |
| 600 | + |
| 601 | + # Use PassManager directly if found, otherwise use dict |
| 602 | + final_passes = pass_manager if pass_manager else transform_passes |
| 603 | + |
| 604 | + # Apply edge transform passes |
| 605 | + edge_program_manager = edge_program_manager.transform(final_passes) |
| 606 | + |
| 607 | + # Run edge manager transform passes |
| 608 | + for pass_callable in self._edge_manager_transform_passes: |
| 609 | + passes = pass_callable(edge_program_manager) |
| 610 | + if passes: |
| 611 | + edge_program_manager = edge_program_manager.transform(passes) |
| 612 | + |
| 613 | + self._artifact = artifact.copy_with_new_data(edge_program_manager) |
| 614 | + |
| 615 | + |
| 616 | +class ToBackendStage(Stage): |
| 617 | + """ |
| 618 | + Stage: Apply partitioning to EdgeProgramManager. |
| 619 | + """ |
565 | 620 |
|
566 | | - # Apply transform passes |
567 | | - edge_program_manager = edge_program_manager.transform(transform_passes) |
| 621 | + def __init__( |
| 622 | + self, |
| 623 | + partitioners: Optional[List[Any]] = None, |
| 624 | + ) -> None: |
| 625 | + super().__init__() |
| 626 | + self._partitioners = partitioners |
| 627 | + |
| 628 | + @classmethod |
| 629 | + def from_recipe( |
| 630 | + cls, lowering_recipe: Optional["LoweringRecipe"] |
| 631 | + ) -> "ToBackendStage": |
| 632 | + if lowering_recipe is None: |
| 633 | + return cls() |
| 634 | + |
| 635 | + return cls( |
| 636 | + partitioners=lowering_recipe.partitioners, |
| 637 | + ) |
| 638 | + |
| 639 | + @property |
| 640 | + def stage_type(self) -> str: |
| 641 | + return StageType.TO_BACKEND |
| 642 | + |
| 643 | + @property |
| 644 | + def valid_predecessor_stages(self) -> List["StageType"]: |
| 645 | + return [ |
| 646 | + StageType.TO_EDGE, |
| 647 | + StageType.EDGE_PROGRAM_MANAGER_TRANSFORM, |
| 648 | + ] |
| 649 | + |
| 650 | + @property |
| 651 | + def can_start_pipeline(self) -> bool: |
| 652 | + return False |
| 653 | + |
| 654 | + def run(self, artifact: PipelineArtifact) -> None: |
| 655 | + """ |
| 656 | + Apply partitioning to EdgeProgramManager. |
| 657 | +
|
| 658 | + Args: |
| 659 | + artifact: Contains edge program manager and context |
| 660 | + """ |
| 661 | + edge_program_manager = artifact.data |
| 662 | + |
| 663 | + if edge_program_manager is None: |
| 664 | + raise RuntimeError("Edge program manager is not set.") |
568 | 665 |
|
569 | 666 | # Apply partitioners if available |
570 | 667 | if self._partitioners is not None and len(self._partitioners) > 0: |
|
0 commit comments