From 513cf6d49d1c3fd852e099e97b714bbdce43f6e6 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Wed, 30 Jul 2025 23:14:19 -0400 Subject: [PATCH 1/6] adding dialect with value semantics to track state --- .../shuttle/dialects/measure/__init__.py | 2 - src/bloqade/shuttle/dialects/measure/stmts.py | 4 +- src/bloqade/shuttle/dialects/measure/types.py | 8 +- src/bloqade/shuttle/dialects/spec/stmts.py | 2 +- .../shuttle/dialects/tracking/__init__.py | 1 + .../shuttle/dialects/tracking/_dialect.py | 3 + .../shuttle/dialects/tracking/stmts.py | 120 ++++++++++++++++++ .../shuttle/dialects/tracking/types.py | 8 ++ src/bloqade/shuttle/passes/path2tracking.py | 59 +++++++++ 9 files changed, 195 insertions(+), 12 deletions(-) create mode 100644 src/bloqade/shuttle/dialects/tracking/__init__.py create mode 100644 src/bloqade/shuttle/dialects/tracking/_dialect.py create mode 100644 src/bloqade/shuttle/dialects/tracking/stmts.py create mode 100644 src/bloqade/shuttle/dialects/tracking/types.py create mode 100644 src/bloqade/shuttle/passes/path2tracking.py diff --git a/src/bloqade/shuttle/dialects/measure/__init__.py b/src/bloqade/shuttle/dialects/measure/__init__.py index 9b1f59b5..3c765d49 100644 --- a/src/bloqade/shuttle/dialects/measure/__init__.py +++ b/src/bloqade/shuttle/dialects/measure/__init__.py @@ -4,6 +4,4 @@ from .types import ( MeasurementArray as MeasurementArray, MeasurementArrayType as MeasurementArrayType, - MeasurementResult as MeasurementResult, - MeasurementResultType as MeasurementResultType, ) diff --git a/src/bloqade/shuttle/dialects/measure/stmts.py b/src/bloqade/shuttle/dialects/measure/stmts.py index 218b0bd4..d5981b7f 100644 --- a/src/bloqade/shuttle/dialects/measure/stmts.py +++ b/src/bloqade/shuttle/dialects/measure/stmts.py @@ -1,12 +1,13 @@ from typing import cast from bloqade.geometry.dialects import grid +from bloqade.squin.qubit import MeasurementResultType from kirin import decl, ir, lowering, types from kirin.decl import info from kirin.dialects import ilist from ._dialect import dialect -from .types import MeasurementArrayType, MeasurementResultType +from .types import MeasurementArrayType NumX = types.TypeVar("NumX") NumY = types.TypeVar("NumY") @@ -19,7 +20,6 @@ class Measure(ir.Statement): name = "measure" traits = frozenset({lowering.FromPythonCall()}) - grids: tuple[ir.SSAValue, ...] = info.argument(grid.GridType[NumX, NumY]) def __init__(self, grids: tuple[ir.SSAValue, ...]): diff --git a/src/bloqade/shuttle/dialects/measure/types.py b/src/bloqade/shuttle/dialects/measure/types.py index e4e11016..8d604169 100644 --- a/src/bloqade/shuttle/dialects/measure/types.py +++ b/src/bloqade/shuttle/dialects/measure/types.py @@ -1,13 +1,8 @@ from typing import Generic, TypeVar +from bloqade.squin.qubit import MeasurementResult from kirin import types - -# TODO: replace this with the squin dialect's MeasurementResultType -class MeasurementResult: - pass - - NumRows = TypeVar("NumRows") NumCols = TypeVar("NumCols") @@ -23,7 +18,6 @@ def __getitem__(self, indices: tuple[int, int]) -> MeasurementResult: ) -MeasurementResultType = types.PyClass(MeasurementResult) MeasurementArrayType = types.Generic( MeasurementArray, types.TypeVar("NumRows"), types.TypeVar("NumCols") ) diff --git a/src/bloqade/shuttle/dialects/spec/stmts.py b/src/bloqade/shuttle/dialects/spec/stmts.py index b4890932..a0505404 100644 --- a/src/bloqade/shuttle/dialects/spec/stmts.py +++ b/src/bloqade/shuttle/dialects/spec/stmts.py @@ -2,7 +2,7 @@ from kirin import ir, lowering, types from kirin.decl import info, statement -from bloqade.shuttle.dialects.spec._dialect import dialect +from ._dialect import dialect @statement(dialect=dialect) diff --git a/src/bloqade/shuttle/dialects/tracking/__init__.py b/src/bloqade/shuttle/dialects/tracking/__init__.py new file mode 100644 index 00000000..cc6869c5 --- /dev/null +++ b/src/bloqade/shuttle/dialects/tracking/__init__.py @@ -0,0 +1 @@ +from ._dialect import dialect as dialect diff --git a/src/bloqade/shuttle/dialects/tracking/_dialect.py b/src/bloqade/shuttle/dialects/tracking/_dialect.py new file mode 100644 index 00000000..d2ef8a4c --- /dev/null +++ b/src/bloqade/shuttle/dialects/tracking/_dialect.py @@ -0,0 +1,3 @@ +from kirin import ir + +dialect = ir.Dialect("bloqade.shuttle.state") diff --git a/src/bloqade/shuttle/dialects/tracking/stmts.py b/src/bloqade/shuttle/dialects/tracking/stmts.py new file mode 100644 index 00000000..c886fa0e --- /dev/null +++ b/src/bloqade/shuttle/dialects/tracking/stmts.py @@ -0,0 +1,120 @@ +from typing import cast + +from bloqade.geometry.dialects import grid +from kirin import ir, types +from kirin.decl import info, statement +from kirin.dialects import ilist + +from .. import measure, path as path_dialect +from ._dialect import dialect +from .types import SystemStateType + + +@statement(dialect=dialect) +class Fill(ir.Statement): + name = "fill" + + traits = frozenset({}) + + locations: ir.SSAValue = info.argument( + ilist.IListType[grid.GridType[types.Any, types.Any], types.Any] + ) + result: ir.ResultValue = info.result(SystemStateType) + + +@statement(dialect=dialect) +class Play(ir.Statement): + name = "play" + + traits = frozenset({}) + + state: ir.SSAValue = info.argument(SystemStateType) + path: ir.SSAValue = info.argument(path_dialect.PathType) + result: ir.ResultValue = info.result(SystemStateType) + + +@statement(dialect=dialect) +class TopHatCZ(ir.Statement): + name = "tophat_cz" + + traits = frozenset({}) + + state: ir.SSAValue = info.argument(SystemStateType) + zone: ir.SSAValue = info.argument(grid.GridType[types.Any, types.Any]) + result: ir.ResultValue = info.result(SystemStateType) + + +@statement(dialect=dialect) +class GlobalR(ir.Statement): + name = "global_r" + + traits = frozenset({}) + + state: ir.SSAValue = info.argument(SystemStateType) + axis_angle: ir.SSAValue = info.argument(types.Float) + rotation_angle: ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(SystemStateType) + + +@statement(dialect=dialect) +class LocalR(ir.Statement): + name = "local_r" + + traits = frozenset({}) + + state: ir.SSAValue = info.argument(SystemStateType) + axis_angle: ir.SSAValue = info.argument(types.Float) + rotation_angle: ir.SSAValue = info.argument(types.Float) + zone: ir.SSAValue = info.argument(grid.GridType[types.Any, types.Any]) + result: ir.ResultValue = info.result(SystemStateType) + + +@statement(dialect=dialect) +class GlobalRz(ir.Statement): + name = "global_rz" + + traits = frozenset({}) + + state: ir.SSAValue = info.argument(SystemStateType) + rotation_angle: ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(SystemStateType) + + +@statement(dialect=dialect) +class LocalRz(ir.Statement): + name = "local_rz" + + traits = frozenset({}) + + state: ir.SSAValue = info.argument(SystemStateType) + rotation_angle: ir.SSAValue = info.argument(types.Float) + zone: ir.SSAValue = info.argument(grid.GridType[types.Any, types.Any]) + result: ir.ResultValue = info.result(SystemStateType) + + +@statement(dialect=dialect) +class Measure(ir.Statement): + name = "measure" + + traits = frozenset({}) + + state: ir.SSAValue = info.argument(SystemStateType) + grids: tuple[ir.SSAValue, ...] = info.argument(grid.GridType[types.Any, types.Any]) + + def __init__(self, state: ir.SSAValue, grids: tuple[ir.SSAValue, ...]): + result_types: list[types.TypeAttribute] = [SystemStateType] + + for grid_ssa in grids: + grid_type = grid_ssa.type + if (grid_type := cast(types.Generic, grid_type)).is_subseteq(grid.GridType): + NumX, NumY = grid_type.vars + else: + NumX, NumY = types.Any, types.Any + + result_types.append(measure.MeasurementArrayType[NumX, NumY]) + + super().__init__( + args=(state,) + grids, + result_types=tuple(result_types), + args_slice={"state": 0, "grids": slice(1, len(grids) + 1)}, + ) diff --git a/src/bloqade/shuttle/dialects/tracking/types.py b/src/bloqade/shuttle/dialects/tracking/types.py new file mode 100644 index 00000000..d4f1bc11 --- /dev/null +++ b/src/bloqade/shuttle/dialects/tracking/types.py @@ -0,0 +1,8 @@ +from kirin import types + + +class SystemState: + pass + + +SystemStateType = types.PyClass(SystemState) diff --git a/src/bloqade/shuttle/passes/path2tracking.py b/src/bloqade/shuttle/passes/path2tracking.py new file mode 100644 index 00000000..777b852e --- /dev/null +++ b/src/bloqade/shuttle/passes/path2tracking.py @@ -0,0 +1,59 @@ +from dataclasses import dataclass, field + +from kirin import ir +from kirin.dialects import cf, func +from kirin.rewrite.abc import RewriteResult, RewriteRule + +from bloqade.shuttle.dialects import gate, init, measure, path +from bloqade.shuttle.dialects.tracking.types import SystemStateType + + +@dataclass +class Path2TrackingRewrite(RewriteRule): + entry_code: func.Function + curr_state: ir.SSAValue | None = None + callgraph: dict[ir.Method, ir.Method] = field(default_factory=dict) + + stmt_types = ( + path.Play, + gate.TopHatCZ, + gate.GlobalR, + gate.LocalR, + measure.Measure, + cf.ConditionalBranch, + cf.Branch, + func.Function, + func.Return, + func.Lambda, + init.Fill, + ) + + def default_rewrite(self, node: ir.Statement) -> RewriteResult: + raise RuntimeError(f"missing rewrite method for statement type {type(node)!r}") + + def rewrite_Block(self, node: ir.Block) -> RewriteResult: + if (region := node.parent_node) is None: + return RewriteResult() + + if node.parent_node is self.entry_code and (region._block_idx[node] == 0): + return RewriteResult() + + if self.curr_state is None: + return RewriteResult() + + self.curr_state = node.args.insert_from(0, SystemStateType, "system_state") + return RewriteResult(has_done_something=True) + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + if not isinstance(node, self.stmt_types): + return RewriteResult() + + method = getattr(self, f"rewrite_{type(node).__name__}", self.default_rewrite) + return method(node) + + def rewrite_Fill(self, node: init.Fill) -> RewriteResult: + assert ( + self.curr_state is not None + ), "curr_state should not be set before Fill is rewritten" + + return RewriteResult() From 84775182015d5bcb65e675ef4852e149e48c2b90 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Fri, 1 Aug 2025 11:40:51 -0400 Subject: [PATCH 2/6] adding tracking dialect and analysis --- .../shuttle/analysis/aod_state/__init__.py | 0 .../shuttle/analysis/aod_state/analysis.py | 40 +++ .../shuttle/analysis/aod_state/impl.py | 30 +++ .../shuttle/analysis/aod_state/lattice.py | 84 ++++++ .../shuttle/dialects/tracking/__init__.py | 12 + src/bloqade/shuttle/passes/path2tracking.py | 251 +++++++++++++++++- 6 files changed, 406 insertions(+), 11 deletions(-) create mode 100644 src/bloqade/shuttle/analysis/aod_state/__init__.py create mode 100644 src/bloqade/shuttle/analysis/aod_state/analysis.py create mode 100644 src/bloqade/shuttle/analysis/aod_state/impl.py create mode 100644 src/bloqade/shuttle/analysis/aod_state/lattice.py diff --git a/src/bloqade/shuttle/analysis/aod_state/__init__.py b/src/bloqade/shuttle/analysis/aod_state/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bloqade/shuttle/analysis/aod_state/analysis.py b/src/bloqade/shuttle/analysis/aod_state/analysis.py new file mode 100644 index 00000000..48c3f2fe --- /dev/null +++ b/src/bloqade/shuttle/analysis/aod_state/analysis.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass + +from kirin import interp, ir +from kirin.analysis.forward import Forward, ForwardFrame + +from bloqade.shuttle.dialects import tracking + +from .lattice import AODState + + +@dataclass +class AODStateAnalysis(Forward[AODState]): + + keys = ["aod.analysis"] + lattice = AODState + + max_x_tones: int + max_y_tones: int + + def is_pure(self, stmt: ir.Statement) -> bool: + # Check if the statement is pure by looking at its attributes + + return ( + stmt.has_trait(ir.Pure) + or (maybe_pure := stmt.get_trait(ir.MaybePure)) is not None + and maybe_pure.is_pure(stmt) + ) + + def eval_stmt_fallback( + self, frame: ForwardFrame[AODState], stmt: ir.Statement + ) -> tuple[AODState, ...] | interp.SpecialValue[AODState]: + return tuple( + AODState.top() + for result in stmt.results + if result.type.is_subseteq(tracking.SystemStateType) + ) + + def run_method(self, method: ir.Method, args: tuple[AODState, ...]): + # NOTE: we do not support dynamic calls here, thus no need to propagate method object + return self.run_callable(method.code, (self.lattice.bottom(),) + args) diff --git a/src/bloqade/shuttle/analysis/aod_state/impl.py b/src/bloqade/shuttle/analysis/aod_state/impl.py new file mode 100644 index 00000000..59b25d24 --- /dev/null +++ b/src/bloqade/shuttle/analysis/aod_state/impl.py @@ -0,0 +1,30 @@ +from kirin import interp +from kirin.analysis.forward import ForwardFrame +from kirin.dialects import scf + +from .analysis import AODStateAnalysis +from .lattice import AODState, PythonRuntime + + +class ScfMethods(scf.absint.Methods): + + @interp.impl(scf.IfElse) + def if_else( + self, _interp: AODStateAnalysis, frame: ForwardFrame[AODState], stmt: scf.IfElse + ): + cond = frame.get(stmt.cond) + + if isinstance(cond, PythonRuntime): + if cond.value: + with _interp.new_frame(stmt, has_parent_access=True) as new_frame: + return _interp.run_ssacfg_region( + new_frame, stmt.then_body, (PythonRuntime(True),) + ) + else: + with _interp.new_frame(stmt, has_parent_access=True) as new_frame: + return _interp.run_ssacfg_region( + new_frame, stmt.else_body, (PythonRuntime(False),) + ) + + else: + super().if_else(_interp, frame, stmt) diff --git a/src/bloqade/shuttle/analysis/aod_state/lattice.py b/src/bloqade/shuttle/analysis/aod_state/lattice.py new file mode 100644 index 00000000..670357ec --- /dev/null +++ b/src/bloqade/shuttle/analysis/aod_state/lattice.py @@ -0,0 +1,84 @@ +from dataclasses import dataclass +from typing import Any + +from bloqade.geometry.dialects import grid +from kirin import ir +from kirin.ir.attrs.abc import LatticeAttributeMeta +from kirin.lattice.abc import BoundedLattice +from kirin.lattice.mixin import SimpleJoinMixin, SimpleMeetMixin +from kirin.print.printer import Printer + + +@dataclass +class AODState( + ir.Attribute, + SimpleJoinMixin["AODState"], + SimpleMeetMixin["AODState"], + BoundedLattice["AODState"], + metaclass=LatticeAttributeMeta, +): + + @classmethod + def bottom(cls) -> "AODState": + return NotAOD() + + @classmethod + def top(cls) -> "AODState": + return Unknown() + + def print_impl(self, printer: Printer) -> None: + printer.print(self.__class__.__name__ + "()") + + +@dataclass +class NotAOD(AODState): + + def is_subseteq(self, other: AODState) -> bool: + return True + + +@dataclass +class Unknown(AODState): + def is_subseteq(self, other: AODState) -> bool: + return isinstance(other, Unknown) + + +@dataclass +class PythonRuntime(AODState): + """ + AODState that represents a Python runtime value. + This is used to represent values that are not known at compile time. + """ + + value: Any + + def is_subseteq(self, other: AODState) -> bool: + return isinstance(other, PythonRuntime) and self.value == other.value + + +@dataclass +class AOD(AODState): + x_tones: frozenset[int] + y_tones: frozenset[int] + pos: grid.Grid + + def is_subseteq(self, other: AODState) -> bool: + return self == other + + +@dataclass +class AODCollision(AODState): + x_tones: dict[int, int] + y_tones: dict[int, int] + + def is_subseteq(self, other: AODState) -> bool: + return self == other + + +@dataclass +class AODJump(AODState): + x_tones: dict[int, tuple[float, float]] + y_tones: dict[int, tuple[float, float]] + + def is_subseteq(self, other: AODState) -> bool: + return self == other diff --git a/src/bloqade/shuttle/dialects/tracking/__init__.py b/src/bloqade/shuttle/dialects/tracking/__init__.py index cc6869c5..d3b53923 100644 --- a/src/bloqade/shuttle/dialects/tracking/__init__.py +++ b/src/bloqade/shuttle/dialects/tracking/__init__.py @@ -1 +1,13 @@ from ._dialect import dialect as dialect +from .stmts import ( + Fill as Fill, + GlobalR as GlobalR, + LocalR as LocalR, + Measure as Measure, + Play as Play, + TopHatCZ as TopHatCZ, +) +from .types import ( + SystemState as SystemState, + SystemStateType as SystemStateType, +) diff --git a/src/bloqade/shuttle/passes/path2tracking.py b/src/bloqade/shuttle/passes/path2tracking.py index 777b852e..82b5043e 100644 --- a/src/bloqade/shuttle/passes/path2tracking.py +++ b/src/bloqade/shuttle/passes/path2tracking.py @@ -1,18 +1,19 @@ from dataclasses import dataclass, field -from kirin import ir -from kirin.dialects import cf, func +from kirin import ir, rewrite, types +from kirin.dialects import cf, func, scf +from kirin.passes import Pass from kirin.rewrite.abc import RewriteResult, RewriteRule -from bloqade.shuttle.dialects import gate, init, measure, path +from bloqade.shuttle.dialects import gate, init, measure, path, tracking from bloqade.shuttle.dialects.tracking.types import SystemStateType @dataclass -class Path2TrackingRewrite(RewriteRule): +class PathToTrackingRewrite(RewriteRule): entry_code: func.Function curr_state: ir.SSAValue | None = None - callgraph: dict[ir.Method, ir.Method] = field(default_factory=dict) + call_graph: dict[ir.Method, ir.Method] = field(default_factory=dict) stmt_types = ( path.Play, @@ -20,25 +21,29 @@ class Path2TrackingRewrite(RewriteRule): gate.GlobalR, gate.LocalR, measure.Measure, + init.Fill, cf.ConditionalBranch, cf.Branch, + scf.For, + scf.IfElse, + scf.Yield, func.Function, func.Return, func.Lambda, - init.Fill, + func.Invoke, ) def default_rewrite(self, node: ir.Statement) -> RewriteResult: raise RuntimeError(f"missing rewrite method for statement type {type(node)!r}") def rewrite_Block(self, node: ir.Block) -> RewriteResult: - if (region := node.parent_node) is None: + if self.curr_state is None: return RewriteResult() - if node.parent_node is self.entry_code and (region._block_idx[node] == 0): + if (region := node.parent_node) is None: return RewriteResult() - if self.curr_state is None: + if node.parent_node is self.entry_code and (region._block_idx[node] == 0): return RewriteResult() self.curr_state = node.args.insert_from(0, SystemStateType, "system_state") @@ -53,7 +58,231 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: def rewrite_Fill(self, node: init.Fill) -> RewriteResult: assert ( - self.curr_state is not None + self.curr_state is None ), "curr_state should not be set before Fill is rewritten" - return RewriteResult() + node.replace_by(new_node := tracking.Fill(node.locations)) + + self.curr_state = new_node.result + + return RewriteResult(has_done_something=True) + + def rewrite_Play(self, node: path.Play) -> RewriteResult: + assert ( + self.curr_state is not None + ), "curr_state should be set before Play is rewritten" + + node.replace_by( + new_node := tracking.Play(state=self.curr_state, path=node.path) + ) + self.curr_state = new_node.result + return RewriteResult(has_done_something=True) + + def rewrite_TopHatCZ(self, node: gate.TopHatCZ) -> RewriteResult: + assert ( + self.curr_state is not None + ), "curr_state should be set before TopHatCZ is rewritten" + + node.replace_by( + new_node := tracking.TopHatCZ(state=self.curr_state, zone=node.zone) + ) + self.curr_state = new_node.result + + return RewriteResult(has_done_something=True) + + def rewrite_GlobalR(self, node: gate.GlobalR) -> RewriteResult: + assert ( + self.curr_state is not None + ), "curr_state should be set before GlobalR is rewritten" + + node.replace_by( + new_node := tracking.GlobalR( + state=self.curr_state, + axis_angle=node.axis_angle, + rotation_angle=node.rotation_angle, + ) + ) + self.curr_state = new_node.result + + return RewriteResult(has_done_something=True) + + def rewrite_LocalR(self, node: gate.LocalR) -> RewriteResult: + assert ( + self.curr_state is not None + ), "curr_state should be set before LocalR is rewritten" + + node.replace_by( + new_node := tracking.LocalR( + state=self.curr_state, + axis_angle=node.axis_angle, + rotation_angle=node.rotation_angle, + zone=node.zone, + ) + ) + self.curr_state = new_node.result + + return RewriteResult(has_done_something=True) + + def rewrite_Measure(self, node: measure.Measure) -> RewriteResult: + if self.curr_state is None: + return RewriteResult() + + node.insert_before( + new_node := tracking.Measure(state=self.curr_state, grids=node.grids) + ) + self.curr_state = new_node.results[0] + + for old_result, new_result in zip(node.results, new_node.results[1:]): + old_result.replace_by(new_result) + + node.delete() + + return RewriteResult(has_done_something=True) + + def rewrite_Branch(self, node: cf.Branch) -> RewriteResult: + if self.curr_state is None: + return RewriteResult() + + node.replace_by( + cf.Branch( + arguments=(self.curr_state, *node.arguments), + successor=node.successor, + ) + ) + return RewriteResult(has_done_something=True) + + def rewrite_ConditionalBranch(self, node: cf.ConditionalBranch) -> RewriteResult: + if self.curr_state is None: + return RewriteResult() + + node.replace_by( + cf.ConditionalBranch( + node.cond, + (self.curr_state, *node.then_arguments), + (self.curr_state, *node.else_arguments), + then_successor=node.then_successor, + else_successor=node.else_successor, + ) + ) + return RewriteResult(has_done_something=True) + + def rewrite_For(self, node: scf.For) -> RewriteResult: + if self.curr_state is None: + return RewriteResult() + + node.replace_by( + scf.For( + node.iterable, + node.body, # body will be rewritten in the `rewrite_Block` method + self.curr_state, + *node.initializers, + ) + ) + return RewriteResult(has_done_something=True) + + def rewrite_Yield(self, node: scf.Yield) -> RewriteResult: + if self.curr_state is None: + return RewriteResult() + + node.replace_by( + scf.Yield( + self.curr_state, + *node.values, + ) + ) + + return RewriteResult(has_done_something=True) + + def rewrite_Return(self, node: func.Return) -> RewriteResult: + if node.parent_stmt is self.entry_code or node.parent_stmt is None: + return RewriteResult() + + raise RuntimeError("missing rewrite method for Return statement") + + def rewrite_Function(self, node: func.Function) -> RewriteResult: + if node is self.entry_code: + return RewriteResult() + + old_signature = node.signature + new_signature = func.Signature( + (SystemStateType, *old_signature.inputs), + types.Tuple[SystemStateType, old_signature.output], + ) + + node.replace_by( + func.Function( + sym_name=node.sym_name, + signature=new_signature, + body=node.body, + ) + ) + + return RewriteResult(has_done_something=True) + + def rewrite_Lambda(self, node: func.Lambda) -> RewriteResult: + if node is self.entry_code: + return RewriteResult() + + old_signature = node.signature + new_signature = func.Signature( + (SystemStateType, *old_signature.inputs), + types.Tuple[SystemStateType, old_signature.output], + ) + + node.replace_by( + func.Lambda( + node.captured, + sym_name=node.sym_name, + signature=new_signature, + body=node.body, + ) + ) + + return RewriteResult(has_done_something=True) + + def rewrite_Call(self, node: func.Call) -> RewriteResult: + if self.curr_state is None: + return RewriteResult() + + node.replace_by( + func.Call( + node.callee, + (self.curr_state, *node.inputs), + kwargs=node.kwargs, + ) + ) + + return RewriteResult(has_done_something=True) + + def rewrite_Invoke(self, node: func.Invoke) -> RewriteResult: + if self.curr_state is None: + return RewriteResult() + + callee = node.callee + if callee not in self.call_graph: + new_callee = callee.similar() + self.call_graph[callee] = new_callee + new_callee.arg_names = ["system_state", *callee.arg_names] + + rewrite.Walk(self).rewrite(new_callee.code) + else: + new_callee = self.call_graph[callee] + + node.replace_by( + func.Invoke( + (self.curr_state, *node.inputs), + callee=new_callee, + kwargs=node.kwargs, + ) + ) + + return RewriteResult(has_done_something=True) + + +@dataclass +class PathToTracking(Pass): + def unsafe_run(self, mt: ir.Method) -> RewriteResult: + if not isinstance(mt.code, func.Function): + return RewriteResult() + + return rewrite.Walk(PathToTrackingRewrite(mt.code)).rewrite(mt.code) From a038d96f13f74e0ef7d3284218c69ba57f9190f4 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Fri, 1 Aug 2025 13:42:56 -0400 Subject: [PATCH 3/6] renaming analysis --- src/bloqade/shuttle/analysis/aod/__init__.py | 7 ++++ .../analysis/{aod_state => aod}/analysis.py | 13 +++----- .../analysis/{aod_state => aod}/lattice.py | 32 ------------------- .../shuttle/analysis/aod_state/impl.py | 30 ----------------- .../__init__.py => dialects/path/aod.py} | 0 5 files changed, 11 insertions(+), 71 deletions(-) create mode 100644 src/bloqade/shuttle/analysis/aod/__init__.py rename src/bloqade/shuttle/analysis/{aod_state => aod}/analysis.py (71%) rename src/bloqade/shuttle/analysis/{aod_state => aod}/lattice.py (61%) delete mode 100644 src/bloqade/shuttle/analysis/aod_state/impl.py rename src/bloqade/shuttle/{analysis/aod_state/__init__.py => dialects/path/aod.py} (100%) diff --git a/src/bloqade/shuttle/analysis/aod/__init__.py b/src/bloqade/shuttle/analysis/aod/__init__.py new file mode 100644 index 00000000..fde8e8ad --- /dev/null +++ b/src/bloqade/shuttle/analysis/aod/__init__.py @@ -0,0 +1,7 @@ +from .analysis import AODAnalysis as AODAnalysis +from .lattice import ( + AOD as AOD, + AODState as AODState, + NotAOD as NotAOD, + Unknown as Unknown, +) diff --git a/src/bloqade/shuttle/analysis/aod_state/analysis.py b/src/bloqade/shuttle/analysis/aod/analysis.py similarity index 71% rename from src/bloqade/shuttle/analysis/aod_state/analysis.py rename to src/bloqade/shuttle/analysis/aod/analysis.py index 48c3f2fe..ecd2c71b 100644 --- a/src/bloqade/shuttle/analysis/aod_state/analysis.py +++ b/src/bloqade/shuttle/analysis/aod/analysis.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Any from kirin import interp, ir from kirin.analysis.forward import Forward, ForwardFrame @@ -9,7 +10,7 @@ @dataclass -class AODStateAnalysis(Forward[AODState]): +class AODAnalysis(Forward[AODState]): keys = ["aod.analysis"] lattice = AODState @@ -17,14 +18,8 @@ class AODStateAnalysis(Forward[AODState]): max_x_tones: int max_y_tones: int - def is_pure(self, stmt: ir.Statement) -> bool: - # Check if the statement is pure by looking at its attributes - - return ( - stmt.has_trait(ir.Pure) - or (maybe_pure := stmt.get_trait(ir.MaybePure)) is not None - and maybe_pure.is_pure(stmt) - ) + def get_const_value(self, typ, ssa: ir.SSAValue) -> Any | None: + raise NotImplementedError def eval_stmt_fallback( self, frame: ForwardFrame[AODState], stmt: ir.Statement diff --git a/src/bloqade/shuttle/analysis/aod_state/lattice.py b/src/bloqade/shuttle/analysis/aod/lattice.py similarity index 61% rename from src/bloqade/shuttle/analysis/aod_state/lattice.py rename to src/bloqade/shuttle/analysis/aod/lattice.py index 670357ec..2d1c901f 100644 --- a/src/bloqade/shuttle/analysis/aod_state/lattice.py +++ b/src/bloqade/shuttle/analysis/aod/lattice.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Any from bloqade.geometry.dialects import grid from kirin import ir @@ -43,19 +42,6 @@ def is_subseteq(self, other: AODState) -> bool: return isinstance(other, Unknown) -@dataclass -class PythonRuntime(AODState): - """ - AODState that represents a Python runtime value. - This is used to represent values that are not known at compile time. - """ - - value: Any - - def is_subseteq(self, other: AODState) -> bool: - return isinstance(other, PythonRuntime) and self.value == other.value - - @dataclass class AOD(AODState): x_tones: frozenset[int] @@ -64,21 +50,3 @@ class AOD(AODState): def is_subseteq(self, other: AODState) -> bool: return self == other - - -@dataclass -class AODCollision(AODState): - x_tones: dict[int, int] - y_tones: dict[int, int] - - def is_subseteq(self, other: AODState) -> bool: - return self == other - - -@dataclass -class AODJump(AODState): - x_tones: dict[int, tuple[float, float]] - y_tones: dict[int, tuple[float, float]] - - def is_subseteq(self, other: AODState) -> bool: - return self == other diff --git a/src/bloqade/shuttle/analysis/aod_state/impl.py b/src/bloqade/shuttle/analysis/aod_state/impl.py deleted file mode 100644 index 59b25d24..00000000 --- a/src/bloqade/shuttle/analysis/aod_state/impl.py +++ /dev/null @@ -1,30 +0,0 @@ -from kirin import interp -from kirin.analysis.forward import ForwardFrame -from kirin.dialects import scf - -from .analysis import AODStateAnalysis -from .lattice import AODState, PythonRuntime - - -class ScfMethods(scf.absint.Methods): - - @interp.impl(scf.IfElse) - def if_else( - self, _interp: AODStateAnalysis, frame: ForwardFrame[AODState], stmt: scf.IfElse - ): - cond = frame.get(stmt.cond) - - if isinstance(cond, PythonRuntime): - if cond.value: - with _interp.new_frame(stmt, has_parent_access=True) as new_frame: - return _interp.run_ssacfg_region( - new_frame, stmt.then_body, (PythonRuntime(True),) - ) - else: - with _interp.new_frame(stmt, has_parent_access=True) as new_frame: - return _interp.run_ssacfg_region( - new_frame, stmt.else_body, (PythonRuntime(False),) - ) - - else: - super().if_else(_interp, frame, stmt) diff --git a/src/bloqade/shuttle/analysis/aod_state/__init__.py b/src/bloqade/shuttle/dialects/path/aod.py similarity index 100% rename from src/bloqade/shuttle/analysis/aod_state/__init__.py rename to src/bloqade/shuttle/dialects/path/aod.py From 9eab301ee3520964110eaff759f99c7f938a19f0 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Mon, 4 Aug 2025 13:33:42 -0400 Subject: [PATCH 4/6] adding forward dataflow arch spec base class --- src/bloqade/shuttle/analysis/aod/analysis.py | 25 +++++++++++++------ src/bloqade/shuttle/analysis/aod/lattice.py | 14 +++++++++-- src/bloqade/shuttle/arch.py | 4 ++- src/bloqade/shuttle/dialects/path/aod.py | 0 src/bloqade/shuttle/dialects/path/concrete.py | 4 ++- .../shuttle/dialects/path/constprop.py | 2 +- .../shuttle/dialects/path/spec_interp.py | 2 +- test/codegen/test_taskgen.py | 4 ++- test/unit/codegen/test_taskgen.py | 6 ++--- 9 files changed, 44 insertions(+), 17 deletions(-) delete mode 100644 src/bloqade/shuttle/dialects/path/aod.py diff --git a/src/bloqade/shuttle/analysis/aod/analysis.py b/src/bloqade/shuttle/analysis/aod/analysis.py index ecd2c71b..d6aa3678 100644 --- a/src/bloqade/shuttle/analysis/aod/analysis.py +++ b/src/bloqade/shuttle/analysis/aod/analysis.py @@ -1,25 +1,36 @@ from dataclasses import dataclass -from typing import Any +from typing import Type, TypeVar from kirin import interp, ir +from kirin.analysis import const from kirin.analysis.forward import Forward, ForwardFrame +from bloqade.shuttle.arch import ArchSpecMixin from bloqade.shuttle.dialects import tracking from .lattice import AODState @dataclass -class AODAnalysis(Forward[AODState]): +class AODAnalysis(Forward[AODState], ArchSpecMixin): - keys = ["aod.analysis"] + keys = ["aod.analysis", "spec.interp"] lattice = AODState - max_x_tones: int - max_y_tones: int + T = TypeVar("T") - def get_const_value(self, typ, ssa: ir.SSAValue) -> Any | None: - raise NotImplementedError + def get_const_value(self, typ: Type[T], ssa: ir.SSAValue) -> T: + if not isinstance(value := ssa.hints.get("const"), const.Value): + raise interp.InterpreterError( + "Non-constant value encountered in AOD analysis." + ) + + if not isinstance(data := value.data, typ): + raise interp.InterpreterError( + f"Expected constant of type {typ}, got {type(data)}." + ) + + return data def eval_stmt_fallback( self, frame: ForwardFrame[AODState], stmt: ir.Statement diff --git a/src/bloqade/shuttle/analysis/aod/lattice.py b/src/bloqade/shuttle/analysis/aod/lattice.py index 2d1c901f..3c6d5bd4 100644 --- a/src/bloqade/shuttle/analysis/aod/lattice.py +++ b/src/bloqade/shuttle/analysis/aod/lattice.py @@ -2,6 +2,7 @@ from bloqade.geometry.dialects import grid from kirin import ir +from kirin.dialects import ilist from kirin.ir.attrs.abc import LatticeAttributeMeta from kirin.lattice.abc import BoundedLattice from kirin.lattice.mixin import SimpleJoinMixin, SimpleMeetMixin @@ -31,7 +32,6 @@ def print_impl(self, printer: Printer) -> None: @dataclass class NotAOD(AODState): - def is_subseteq(self, other: AODState) -> bool: return True @@ -48,5 +48,15 @@ class AOD(AODState): y_tones: frozenset[int] pos: grid.Grid + def active_positions(self) -> grid.Grid: + x_indices = ilist.IList(sorted(self.x_tones)) + y_indices = ilist.IList(sorted(self.y_tones)) + return self.pos.get_view(x_indices, y_indices) + def is_subseteq(self, other: AODState) -> bool: - return self == other + # only check of the active AOD positions are equal not necessarily + # the exact positions + return ( + isinstance(other, AOD) + and self.active_positions() == other.active_positions() + ) diff --git a/src/bloqade/shuttle/arch.py b/src/bloqade/shuttle/arch.py index 469d06e1..4b6983e7 100644 --- a/src/bloqade/shuttle/arch.py +++ b/src/bloqade/shuttle/arch.py @@ -77,13 +77,15 @@ def _default_layout(): @dataclass(frozen=True) class ArchSpec: layout: Layout = field(default_factory=_default_layout) # type: ignore + max_x_tones: int = 16 + max_y_tones: int = 16 @dataclass class ArchSpecMixin: """Base class for interpreters that require an architecture specification.""" - arch_spec: ArchSpec + arch_spec: ArchSpec = field(kw_only=True) @dataclass diff --git a/src/bloqade/shuttle/dialects/path/aod.py b/src/bloqade/shuttle/dialects/path/aod.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/bloqade/shuttle/dialects/path/concrete.py b/src/bloqade/shuttle/dialects/path/concrete.py index 6eafae94..8d17e958 100644 --- a/src/bloqade/shuttle/dialects/path/concrete.py +++ b/src/bloqade/shuttle/dialects/path/concrete.py @@ -31,7 +31,9 @@ def gen(self, interp: Interpreter, frame: Frame, stmt: stmts.Gen): inputs = frame.get_values(stmt.inputs) kwargs = stmt.kwargs args = interp.permute_values(device_task.move_fn.arg_names, inputs, kwargs) - path = TraceInterpreter(stmt.arch_spec).run_trace(device_task.move_fn, args, {}) + path = TraceInterpreter(arch_spec=stmt.arch_spec).run_trace( + device_task.move_fn, args, {} + ) if reverse: path = reverse_path(path) diff --git a/src/bloqade/shuttle/dialects/path/constprop.py b/src/bloqade/shuttle/dialects/path/constprop.py index 565c5c05..566c6119 100644 --- a/src/bloqade/shuttle/dialects/path/constprop.py +++ b/src/bloqade/shuttle/dialects/path/constprop.py @@ -44,7 +44,7 @@ def gen( ) try: - path = TraceInterpreter(stmt.arch_spec).run_trace( + path = TraceInterpreter(arch_spec=stmt.arch_spec).run_trace( device_task.move_fn, tuple( cast(const.Value, arg).data if isinstance(arg, const.Value) else arg diff --git a/src/bloqade/shuttle/dialects/path/spec_interp.py b/src/bloqade/shuttle/dialects/path/spec_interp.py index 59842242..bf30cb0a 100644 --- a/src/bloqade/shuttle/dialects/path/spec_interp.py +++ b/src/bloqade/shuttle/dialects/path/spec_interp.py @@ -29,7 +29,7 @@ def gen(self, interp: ArchSpecInterpreter, frame: Frame, stmt: stmts.Gen): inputs = frame.get_values(stmt.inputs) kwargs = stmt.kwargs args = interp.permute_values(device_task.move_fn.arg_names, inputs, kwargs) - path = TraceInterpreter(interp.arch_spec).run_trace( + path = TraceInterpreter(arch_spec=interp.arch_spec).run_trace( device_task.move_fn, args, {} ) diff --git a/test/codegen/test_taskgen.py b/test/codegen/test_taskgen.py index 4e52ee61..a1812060 100644 --- a/test/codegen/test_taskgen.py +++ b/test/codegen/test_taskgen.py @@ -25,7 +25,9 @@ def move_fn(x: float, y: float): move_fn.print() - action_list = TraceInterpreter(ArchSpec()).run_trace(move_fn, (1.0, 2.0), {}) + action_list = TraceInterpreter(arch_spec=ArchSpec()).run_trace( + move_fn, (1.0, 2.0), {} + ) assert isinstance(action_list, list) diff --git a/test/unit/codegen/test_taskgen.py b/test/unit/codegen/test_taskgen.py index 3974714a..e9dca879 100644 --- a/test/unit/codegen/test_taskgen.py +++ b/test/unit/codegen/test_taskgen.py @@ -83,7 +83,7 @@ def start_pos(self): return grid.Grid.from_positions([1, 2], [3, 4]) def init_interpreter(self): - interpreter = taskgen.TraceInterpreter(ArchSpec()) + interpreter = taskgen.TraceInterpreter(arch_spec=ArchSpec()) interpreter.initialize() return interpreter @@ -230,7 +230,7 @@ def test_action(self): action.turn_on(action.ALL, action.ALL) action.turn_off(action.ALL, action.ALL) - interpreter = taskgen.TraceInterpreter(ArchSpec()) + interpreter = taskgen.TraceInterpreter(arch_spec=ArchSpec()) assert interpreter.run_trace(test_action, (), {}) == [ taskgen.WayPointsAction([grid.Grid.from_positions([1, 2], [3, 4])]), taskgen.TurnOnXYSliceAction(action.ALL, action.ALL), @@ -245,7 +245,7 @@ def test_interpreter_run_trace_error(): def test_bad_method(): return None - interpreter = taskgen.TraceInterpreter(ArchSpec()) + interpreter = taskgen.TraceInterpreter(arch_spec=ArchSpec()) with pytest.raises(ValueError): interpreter.run_trace(test_bad_method, (), {}) From 4571cf7bfa1950910665819f1632f9a0197fbe0c Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Mon, 4 Aug 2025 15:11:48 -0400 Subject: [PATCH 5/6] refactoring trace --- src/bloqade/shuttle/codegen/taskgen.py | 72 ----------------- .../shuttle/dialects/action/__init__.py | 1 + src/bloqade/shuttle/dialects/action/trace.py | 78 +++++++++++++++++++ src/bloqade/shuttle/dialects/tracking/aod.py | 21 +++++ 4 files changed, 100 insertions(+), 72 deletions(-) create mode 100644 src/bloqade/shuttle/dialects/action/trace.py create mode 100644 src/bloqade/shuttle/dialects/tracking/aod.py diff --git a/src/bloqade/shuttle/codegen/taskgen.py b/src/bloqade/shuttle/codegen/taskgen.py index 663de16d..9ce97c22 100644 --- a/src/bloqade/shuttle/codegen/taskgen.py +++ b/src/bloqade/shuttle/codegen/taskgen.py @@ -6,7 +6,6 @@ from bloqade.geometry.dialects import grid from kirin import ir from kirin.dialects import func, ilist -from kirin.interp import Frame, InterpreterError, MethodTable, impl from kirin.ir.method import Method from typing_extensions import Self @@ -152,74 +151,3 @@ def run_trace( # TODO: use permute_values to get correct order. super().run(mt, args=args, kwargs=kwargs) return self.trace.copy() - - -@action.dialect.register(key="action.tracer") -class ActionTracer(MethodTable): - - intensity_actions = { - action.TurnOnXY: TurnOnXYAction, - action.TurnOffXY: TurnOffXYAction, - action.TurnOnXSlice: TurnOnXSliceAction, - action.TurnOffXSlice: TurnOffXSliceAction, - action.TurnOnYSlice: TurnOnYSliceAction, - action.TurnOffYSlice: TurnOffYSliceAction, - action.TurnOnXYSlice: TurnOnXYSliceAction, - action.TurnOffXYSlice: TurnOffXYSliceAction, - } - - @impl(action.TurnOnXY) - @impl(action.TurnOffXY) - @impl(action.TurnOnXSlice) - @impl(action.TurnOffXSlice) - @impl(action.TurnOnYSlice) - @impl(action.TurnOffYSlice) - @impl(action.TurnOnXYSlice) - @impl(action.TurnOffXYSlice) - def construct_intensity_actions( - self, - interp: TraceInterpreter, - frame: Frame, - stmt: action.IntensityStatement, - ): - if interp.curr_pos is None: - raise InterpreterError( - "Position of AOD not set before turning on/off tones" - ) - - x_tone_indices = frame.get(stmt.x_tones) - y_tone_indices = frame.get(stmt.y_tones) - - interp.trace.append( - self.intensity_actions[type(stmt)]( - x_tone_indices if isinstance(x_tone_indices, slice) else x_tone_indices, - y_tone_indices if isinstance(y_tone_indices, slice) else y_tone_indices, - ) - ) - interp.trace.append(WayPointsAction(way_points=[interp.curr_pos])) - return () - - @impl(action.Move) - def move(self, interp: TraceInterpreter, frame: Frame, stmt: action.Move): - if interp.curr_pos is None: - raise InterpreterError("Position of AOD not set before moving tones") - - assert isinstance(interp.trace[-1], WayPointsAction) - - interp.trace[-1].add_waypoint(pos := frame.get_typed(stmt.grid, grid.Grid)) - if interp.curr_pos.shape != pos.shape: - raise InterpreterError( - f"Position of AOD {interp.curr_pos} and target position {pos} have different shapes" - ) - interp.curr_pos = pos - - return () - - @impl(action.Set) - def set(self, interp: TraceInterpreter, frame: Frame, stmt: action.Set): - pos = frame.get_typed(stmt.grid, grid.Grid) - interp.trace.append(WayPointsAction([pos])) - - interp.curr_pos = pos - - return () diff --git a/src/bloqade/shuttle/dialects/action/__init__.py b/src/bloqade/shuttle/dialects/action/__init__.py index abe070e1..a88f7c06 100644 --- a/src/bloqade/shuttle/dialects/action/__init__.py +++ b/src/bloqade/shuttle/dialects/action/__init__.py @@ -22,3 +22,4 @@ TurnOnYSlice as TurnOnYSlice, TweezerFunction as TweezerFunction, ) +from .trace import ActionTracer as ActionTracer diff --git a/src/bloqade/shuttle/dialects/action/trace.py b/src/bloqade/shuttle/dialects/action/trace.py new file mode 100644 index 00000000..aef1e72f --- /dev/null +++ b/src/bloqade/shuttle/dialects/action/trace.py @@ -0,0 +1,78 @@ +from bloqade.geometry.dialects import grid +from kirin.interp import Frame, InterpreterError, MethodTable, impl + +from bloqade.shuttle.codegen import taskgen + +from . import stmts +from ._dialect import dialect + + +@dialect.register(key="action.tracer") +class ActionTracer(MethodTable): + + intensity_actions = { + stmts.TurnOnXY: taskgen.TurnOnXYAction, + stmts.TurnOffXY: taskgen.TurnOffXYAction, + stmts.TurnOnXSlice: taskgen.TurnOnXSliceAction, + stmts.TurnOffXSlice: taskgen.TurnOffXSliceAction, + stmts.TurnOnYSlice: taskgen.TurnOnYSliceAction, + stmts.TurnOffYSlice: taskgen.TurnOffYSliceAction, + stmts.TurnOnXYSlice: taskgen.TurnOnXYSliceAction, + stmts.TurnOffXYSlice: taskgen.TurnOffXYSliceAction, + } + + @impl(stmts.TurnOnXY) + @impl(stmts.TurnOffXY) + @impl(stmts.TurnOnXSlice) + @impl(stmts.TurnOffXSlice) + @impl(stmts.TurnOnYSlice) + @impl(stmts.TurnOffYSlice) + @impl(stmts.TurnOnXYSlice) + @impl(stmts.TurnOffXYSlice) + def construct_intensity_actions( + self, + interp: taskgen.TraceInterpreter, + frame: Frame, + stmt: stmts.IntensityStatement, + ): + if interp.curr_pos is None: + raise InterpreterError( + "Position of AOD not set before turning on/off tones" + ) + + x_tone_indices = frame.get(stmt.x_tones) + y_tone_indices = frame.get(stmt.y_tones) + + interp.trace.append( + self.intensity_actions[type(stmt)]( + x_tone_indices if isinstance(x_tone_indices, slice) else x_tone_indices, + y_tone_indices if isinstance(y_tone_indices, slice) else y_tone_indices, + ) + ) + interp.trace.append(taskgen.WayPointsAction(way_points=[interp.curr_pos])) + return () + + @impl(stmts.Move) + def move(self, interp: taskgen.TraceInterpreter, frame: Frame, stmt: stmts.Move): + if interp.curr_pos is None: + raise InterpreterError("Position of AOD not set before moving tones") + + assert isinstance(interp.trace[-1], taskgen.WayPointsAction) + + interp.trace[-1].add_waypoint(pos := frame.get_typed(stmt.grid, grid.Grid)) + if interp.curr_pos.shape != pos.shape: + raise InterpreterError( + f"Position of AOD {interp.curr_pos} and target position {pos} have different shapes" + ) + interp.curr_pos = pos + + return () + + @impl(stmts.Set) + def set(self, interp: taskgen.TraceInterpreter, frame: Frame, stmt: stmts.Set): + pos = frame.get_typed(stmt.grid, grid.Grid) + interp.trace.append(taskgen.WayPointsAction([pos])) + + interp.curr_pos = pos + + return () diff --git a/src/bloqade/shuttle/dialects/tracking/aod.py b/src/bloqade/shuttle/dialects/tracking/aod.py new file mode 100644 index 00000000..de554f56 --- /dev/null +++ b/src/bloqade/shuttle/dialects/tracking/aod.py @@ -0,0 +1,21 @@ +# from kirin.analysis import ForwardFrame +# from kirin import interp +# from bloqade.shuttle.analysis import aod +# from ._dialect import dialect +# from .stmts import Play, GlobalR, GlobalRz, LocalR, LocalRz, Measure, Fill + +# from ..path import Path + +# @dialect.register(key="aod.analysis") +# class TrackingMethods(interp.MethodTable): + +# @interp.impl(Play) +# def play(self, _interp: aod.AODAnalysis, frame: ForwardFrame[aod.AODState], stmt: Play): + +# state = frame.get(stmt.state) +# if not isinstance(state, aod.AOD): +# return (state,) + +# path = _interp.get_const_value(Path, stmt.path) +# for action in path.path: +# match From f7a5d08af088f13b85ec49dd87b3d9ae4a824883 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Thu, 25 Sep 2025 14:55:09 -0400 Subject: [PATCH 6/6] WIP: refactor rewrite to use runtime analysis --- src/bloqade/shuttle/analysis/runtime.py | 23 ++++---- src/bloqade/shuttle/dialects/gate/runtime.py | 2 +- src/bloqade/shuttle/dialects/init/runtime.py | 2 +- .../shuttle/dialects/measure/runtime.py | 2 +- src/bloqade/shuttle/dialects/path/runtime.py | 2 +- src/bloqade/shuttle/passes/path2tracking.py | 54 +++++++++++-------- 6 files changed, 49 insertions(+), 36 deletions(-) diff --git a/src/bloqade/shuttle/analysis/runtime.py b/src/bloqade/shuttle/analysis/runtime.py index 8adaac05..10a3d6d1 100644 --- a/src/bloqade/shuttle/analysis/runtime.py +++ b/src/bloqade/shuttle/analysis/runtime.py @@ -12,11 +12,17 @@ class RuntimeFrame(ForwardFrame[EmptyLattice]): This frame is used to track the state of quantum operations within a method. """ - quantum_stmts: set[ir.Statement] = field(default_factory=set) + quantum_call: set[ir.Statement] = field(default_factory=set) """Set of quantum statements in the frame.""" is_quantum: bool = False """Whether the frame contains quantum operations.""" + def merge_runtime(self, other: "RuntimeFrame", stmt: ir.Statement): + if other.is_quantum: + self.is_quantum = True + self.quantum_call.add(stmt) + self.quantum_call.update(other.quantum_call) + class RuntimeAnalysis(ForwardExtra[RuntimeFrame, EmptyLattice]): """Forward dataflow analysis to check if a method has quantum runtime. @@ -61,10 +67,8 @@ def ifelse(self, _interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: scf.IfElse else_frame, stmt.else_body, (_interp.lattice.top(),) ) - frame.is_quantum = ( - frame.is_quantum or then_frame.is_quantum or else_frame.is_quantum - ) - frame.quantum_stmts.update(then_frame.quantum_stmts, else_frame.quantum_stmts) + frame.merge_runtime(then_frame, stmt) + frame.merge_runtime(else_frame, stmt) match (then_result, else_result): case (interp.ReturnValue(), tuple()): return else_result @@ -86,8 +90,7 @@ def for_loop(self, _interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: scf.For) body_frame, stmt.body, (_interp.lattice.bottom(),) ) - frame.is_quantum = frame.is_quantum or body_frame.is_quantum - frame.quantum_stmts.update(body_frame.quantum_stmts) + frame.merge_runtime(body_frame, stmt) if isinstance(result, interp.ReturnValue) or result is None: return args[1:] else: @@ -107,7 +110,8 @@ class Func(interp.MethodTable): def invoke(self, _interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: func.Invoke): args = (_interp.lattice.top(),) * len(stmt.inputs) callee_frame, result = _interp.run_method(stmt.callee, args) - frame.is_quantum = frame.is_quantum or callee_frame.is_quantum + frame.merge_runtime(callee_frame, stmt) + return (result,) @interp.impl(func.Call) @@ -123,10 +127,11 @@ def call(self, _interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: func.Call): body = trait.get_callable_region(callee_result.code) with _interp.new_frame(stmt) as callee_frame: result = _interp.run_ssacfg_region(callee_frame, body, args) + else: raise InterruptedError("Dynamic method calls are not supported") - frame.is_quantum = frame.is_quantum or callee_frame.is_quantum + frame.merge_runtime(callee_frame, stmt) return (result,) @interp.impl(func.Return) diff --git a/src/bloqade/shuttle/dialects/gate/runtime.py b/src/bloqade/shuttle/dialects/gate/runtime.py index 374819e8..578fc1f3 100644 --- a/src/bloqade/shuttle/dialects/gate/runtime.py +++ b/src/bloqade/shuttle/dialects/gate/runtime.py @@ -25,5 +25,5 @@ def gate( ) -> interp.StatementResult[RuntimeFrame]: """Handle gate statements and mark the frame as quantum.""" frame.is_quantum = True - frame.quantum_stmts.add(stmt) + frame.quantum_call.add(stmt) return () diff --git a/src/bloqade/shuttle/dialects/init/runtime.py b/src/bloqade/shuttle/dialects/init/runtime.py index bfe9c84b..70706e32 100644 --- a/src/bloqade/shuttle/dialects/init/runtime.py +++ b/src/bloqade/shuttle/dialects/init/runtime.py @@ -16,5 +16,5 @@ class HasQuantumRuntimeMethodTable(interp.MethodTable): def gate(self, interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: Fill): """Handle gate statements and mark the frame as quantum.""" frame.is_quantum = True - frame.quantum_stmts.add(stmt) + frame.quantum_call.add(stmt) return () diff --git a/src/bloqade/shuttle/dialects/measure/runtime.py b/src/bloqade/shuttle/dialects/measure/runtime.py index e358ad5a..6eb64fa5 100644 --- a/src/bloqade/shuttle/dialects/measure/runtime.py +++ b/src/bloqade/shuttle/dialects/measure/runtime.py @@ -16,5 +16,5 @@ class HasQuantumRuntimeMethodTable(interp.MethodTable): def gate(self, _interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: Measure): """Handle gate statements and mark the frame as quantum.""" frame.is_quantum = True - frame.quantum_stmts.add(stmt) + frame.quantum_call.add(stmt) return (_interp.lattice.top(),) diff --git a/src/bloqade/shuttle/dialects/path/runtime.py b/src/bloqade/shuttle/dialects/path/runtime.py index aff47956..ff7a53db 100644 --- a/src/bloqade/shuttle/dialects/path/runtime.py +++ b/src/bloqade/shuttle/dialects/path/runtime.py @@ -18,5 +18,5 @@ def gate( ) -> interp.StatementResult[RuntimeFrame]: """Handle gate statements and mark the frame as quantum.""" frame.is_quantum = True - frame.quantum_stmts.add(stmt) + frame.quantum_call.add(stmt) return () diff --git a/src/bloqade/shuttle/passes/path2tracking.py b/src/bloqade/shuttle/passes/path2tracking.py index 82b5043e..e7a881b9 100644 --- a/src/bloqade/shuttle/passes/path2tracking.py +++ b/src/bloqade/shuttle/passes/path2tracking.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, field +from bloqade.shuttle.analysis.runtime import RuntimeFrame, RuntimeAnalysis from kirin import ir, rewrite, types from kirin.dialects import cf, func, scf from kirin.passes import Pass @@ -12,8 +13,9 @@ @dataclass class PathToTrackingRewrite(RewriteRule): entry_code: func.Function - curr_state: ir.SSAValue | None = None - call_graph: dict[ir.Method, ir.Method] = field(default_factory=dict) + runtime_frame: RuntimeFrame + curr_state: ir.SSAValue | None = field(default=None, init=False) + call_graph: dict[ir.Method, ir.Method] = field(default_factory=dict, init=False) stmt_types = ( path.Play, @@ -57,9 +59,8 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: return method(node) def rewrite_Fill(self, node: init.Fill) -> RewriteResult: - assert ( - self.curr_state is None - ), "curr_state should not be set before Fill is rewritten" + if self.curr_state is not None: + return RewriteResult() node.replace_by(new_node := tracking.Fill(node.locations)) @@ -68,9 +69,8 @@ def rewrite_Fill(self, node: init.Fill) -> RewriteResult: return RewriteResult(has_done_something=True) def rewrite_Play(self, node: path.Play) -> RewriteResult: - assert ( - self.curr_state is not None - ), "curr_state should be set before Play is rewritten" + if self.curr_state is None: + return RewriteResult() node.replace_by( new_node := tracking.Play(state=self.curr_state, path=node.path) @@ -79,9 +79,8 @@ def rewrite_Play(self, node: path.Play) -> RewriteResult: return RewriteResult(has_done_something=True) def rewrite_TopHatCZ(self, node: gate.TopHatCZ) -> RewriteResult: - assert ( - self.curr_state is not None - ), "curr_state should be set before TopHatCZ is rewritten" + if self.curr_state is None: + return RewriteResult() node.replace_by( new_node := tracking.TopHatCZ(state=self.curr_state, zone=node.zone) @@ -91,9 +90,8 @@ def rewrite_TopHatCZ(self, node: gate.TopHatCZ) -> RewriteResult: return RewriteResult(has_done_something=True) def rewrite_GlobalR(self, node: gate.GlobalR) -> RewriteResult: - assert ( - self.curr_state is not None - ), "curr_state should be set before GlobalR is rewritten" + if self.curr_state is None: + return RewriteResult() node.replace_by( new_node := tracking.GlobalR( @@ -107,9 +105,8 @@ def rewrite_GlobalR(self, node: gate.GlobalR) -> RewriteResult: return RewriteResult(has_done_something=True) def rewrite_LocalR(self, node: gate.LocalR) -> RewriteResult: - assert ( - self.curr_state is not None - ), "curr_state should be set before LocalR is rewritten" + if self.curr_state is None: + return RewriteResult() node.replace_by( new_node := tracking.LocalR( @@ -260,11 +257,14 @@ def rewrite_Invoke(self, node: func.Invoke) -> RewriteResult: callee = node.callee if callee not in self.call_graph: - new_callee = callee.similar() - self.call_graph[callee] = new_callee - new_callee.arg_names = ["system_state", *callee.arg_names] - - rewrite.Walk(self).rewrite(new_callee.code) + if node in self.runtime_frame.quantum_call: + new_callee = callee.similar() + self.call_graph[callee] = new_callee + new_callee.arg_names = ["system_state", *callee.arg_names] + rewrite.Walk(self).rewrite(new_callee.code) + else: + self.call_graph[callee] = callee + new_callee = callee else: new_callee = self.call_graph[callee] @@ -281,8 +281,16 @@ def rewrite_Invoke(self, node: func.Invoke) -> RewriteResult: @dataclass class PathToTracking(Pass): + runtime: RuntimeAnalysis = field(init=False) + + def __post_init__(self): + self.runtime = RuntimeAnalysis(self.dialects) + def unsafe_run(self, mt: ir.Method) -> RewriteResult: if not isinstance(mt.code, func.Function): return RewriteResult() - return rewrite.Walk(PathToTrackingRewrite(mt.code)).rewrite(mt.code) + runtime_frame, _ = self.runtime.run_analysis(mt) + return rewrite.Walk(PathToTrackingRewrite(mt.code, runtime_frame)).rewrite( + mt.code + )