Skip to content
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@
QNodes.
[(#8731)](https://github.com/PennyLaneAI/pennylane/pull/8731)

* :class:`~.transforms.core.TransformContainer` has been renamed to :class:`~.transforms.core.BoundTransform`.
The old name is still available in the same location.
[(#8753)](https://github.com/PennyLaneAI/pennylane/pull/8753)
* The :class:`~.CompilePipeline` (previously known as the `TransformProgram`) can now be constructed
more flexibility with a variable number of arguments that are of types `TransformDispatcher`,
`TransformContainer`, or other `CompilePipeline`s.
Expand Down
4 changes: 2 additions & 2 deletions pennylane/noise/add_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pennylane.operation import DecompositionUndefinedError, Operator
from pennylane.ops import Adjoint
from pennylane.tape import make_qscript
from pennylane.transforms.core import TransformContainer, transform
from pennylane.transforms.core import BoundTransform, transform
from pennylane.workflow import get_transform_program

from .conditionals import partial_wires
Expand Down Expand Up @@ -275,7 +275,7 @@ def custom_qnode_wrapper(self, qnode, targs, tkwargs):

cqnode._transform_program = compile_pipeline
cqnode.transform_program.push_back(
TransformContainer(
BoundTransform(
self,
targs,
{**tkwargs},
Expand Down
3 changes: 2 additions & 1 deletion pennylane/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@
Transforms developer functions
------------------------------

:class:`~.TransformContainer` and :class:`~.TransformDispatcher` are developer-facing objects that allow the
:class:`~.TransformDispatcher` is a
developer-facing objects that allow the
creation, dispatching, and composability of transforms. If you would like to make a custom transform, refer
instead to the documentation of :func:`qml.transform <pennylane.transform>`.

Expand Down
7 changes: 6 additions & 1 deletion pennylane/transforms/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
r"""This module contains the experimental transforms building blocks (core)."""

from .transform import transform
from .transform_dispatcher import TransformDispatcher, TransformContainer, TransformError
from .transform_dispatcher import (
TransformDispatcher,
TransformContainer,
TransformError,
BoundTransform,
)
from .compile_pipeline import CompilePipeline


Expand Down
72 changes: 36 additions & 36 deletions pennylane/transforms/core/compile_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from pennylane.typing import BatchPostprocessingFn, PostprocessingFn, ResultBatch

from .cotransform_cache import CotransformCache
from .transform_dispatcher import TransformContainer, TransformDispatcher
from .transform_dispatcher import BoundTransform, TransformDispatcher

if TYPE_CHECKING:
import jax
Expand Down Expand Up @@ -114,7 +114,7 @@ class CompilePipeline:
The order of execution is the order in the list containing the containers.

Args:
initial_program (Optional[Sequence[TransformContainer]]): A sequence of transforms with
initial_program (Optional[Sequence[BoundTransform]]): A sequence of transforms with
which to initialize the program.
cotransform_cache (Optional[CotransformCache]): A named tuple containing the ``qnode``,
``args``, and ``kwargs`` required to compute classical cotransforms.
Expand Down Expand Up @@ -167,23 +167,23 @@ class CompilePipeline:
@overload
def __init__(
self,
transforms: Sequence[TransformContainer],
transforms: Sequence[BoundTransform],
/,
*,
cotransform_cache: CotransformCache | None = None,
): ...
@overload
def __init__(
self,
*transforms: CompilePipeline | TransformContainer | TransformDispatcher,
*transforms: CompilePipeline | BoundTransform | TransformDispatcher,
cotransform_cache: CotransformCache | None = None,
): ...
def __init__(
self,
*transforms: CompilePipeline
| TransformContainer
| BoundTransform
| TransformDispatcher
| Sequence[TransformContainer],
| Sequence[BoundTransform],
cotransform_cache: CotransformCache | None = None,
):
if len(transforms) == 1 and isinstance(transforms[0], Sequence):
Expand All @@ -194,7 +194,7 @@ def __init__(
self._compile_pipeline = []
self.cotransform_cache = cotransform_cache
for obj in transforms:
if not isinstance(obj, (CompilePipeline, TransformContainer, TransformDispatcher)):
if not isinstance(obj, (CompilePipeline, BoundTransform, TransformDispatcher)):
raise TypeError(
"CompilePipeline can only be constructed with a series of transforms "
"or compile pipelines, or with a single list of transforms."
Expand All @@ -205,19 +205,19 @@ def __copy__(self):
return CompilePipeline(self._compile_pipeline, cotransform_cache=self.cotransform_cache)

def __iter__(self):
"""list[TransformContainer]: Return an iterator to the underlying compile pipeline."""
"""list[BoundTransform]: Return an iterator to the underlying compile pipeline."""
return self._compile_pipeline.__iter__()

def __len__(self) -> int:
"""int: Return the number transforms in the program."""
return len(self._compile_pipeline)

@overload
def __getitem__(self, idx: int) -> TransformContainer: ...
def __getitem__(self, idx: int) -> BoundTransform: ...
@overload
def __getitem__(self, idx: slice) -> CompilePipeline: ...
def __getitem__(self, idx):
"""(TransformContainer, List[TransformContainer]): Return the indexed transform container from underlying
"""(BoundTransform, List[BoundTransform]): Return the indexed transform container from underlying
compile pipeline"""
if isinstance(idx, slice):
return CompilePipeline(self._compile_pipeline[idx])
Expand All @@ -227,15 +227,15 @@ def __bool__(self) -> bool:
return bool(self._compile_pipeline)

def __add__(
self, other: CompilePipeline | TransformContainer | TransformDispatcher
self, other: CompilePipeline | BoundTransform | TransformDispatcher
) -> CompilePipeline:

# Convert dispatcher to container if needed
if isinstance(other, TransformDispatcher):
other = TransformContainer(other)
other = BoundTransform(other)

# Handle TransformContainer
if isinstance(other, TransformContainer):
# Handle BoundTransform
if isinstance(other, BoundTransform):
other = CompilePipeline([other])

# Handle CompilePipeline
Expand All @@ -258,16 +258,16 @@ def __add__(

return NotImplemented

def __radd__(self, other: TransformContainer | TransformDispatcher) -> CompilePipeline:
def __radd__(self, other: BoundTransform | TransformDispatcher) -> CompilePipeline:
"""Right addition to prepend a transform to the program.

Args:
other: A TransformContainer or TransformDispatcher to prepend.
other: A BoundTransform or TransformDispatcher to prepend.

Returns:
CompilePipeline: A new program with the transform prepended.
"""
if isinstance(other, TransformContainer):
if isinstance(other, BoundTransform):
if self.has_final_transform and other.final_transform:
raise TransformError("The compile pipeline already has a terminal transform.")

Expand All @@ -277,21 +277,21 @@ def __radd__(self, other: TransformContainer | TransformDispatcher) -> CompilePi
return NotImplemented

def __iadd__(
self, other: CompilePipeline | TransformContainer | TransformDispatcher
self, other: CompilePipeline | BoundTransform | TransformDispatcher
) -> CompilePipeline:
"""In-place addition to append a transform to the program.

Args:
other: A TransformContainer, TransformDispatcher, or CompilePipeline to append.
other: A BoundTransform, TransformDispatcher, or CompilePipeline to append.

Returns:
CompilePipeline: This program with the transform(s) appended.
"""
# Convert dispatcher to container if needed
if isinstance(other, TransformDispatcher):
other = TransformContainer(other)
other = BoundTransform(other)

if isinstance(other, TransformContainer):
if isinstance(other, BoundTransform):
other = CompilePipeline([other])

if isinstance(other, CompilePipeline):
Expand Down Expand Up @@ -353,19 +353,19 @@ def __eq__(self, other) -> bool:
return self._compile_pipeline == other._compile_pipeline

def __contains__(self, obj) -> bool:
if isinstance(obj, TransformContainer):
if isinstance(obj, BoundTransform):
return obj in self._compile_pipeline
if isinstance(obj, TransformDispatcher):
return any(obj.transform == t.transform for t in self)
return False

def push_back(self, transform_container: TransformContainer):
def push_back(self, transform_container: BoundTransform):
"""Add a transform (container) to the end of the program.

Args:
transform_container(TransformContainer): A transform represented by its container.
transform_container(BoundTransform): A transform represented by its container.
"""
if not isinstance(transform_container, TransformContainer):
if not isinstance(transform_container, BoundTransform):
raise TransformError("Only transform container can be added to the compile pipeline.")

# Program can only contain one informative transform and at the end of the program
Expand All @@ -376,11 +376,11 @@ def push_back(self, transform_container: TransformContainer):
return
self._compile_pipeline.append(transform_container)

def insert_front(self, transform_container: TransformContainer):
def insert_front(self, transform_container: BoundTransform):
"""Insert the transform container at the beginning of the program.

Args:
transform_container(TransformContainer): A transform represented by its container.
transform_container(BoundTransform): A transform represented by its container.
"""
if (transform_container.final_transform) and not self.is_empty():
raise TransformError(
Expand All @@ -392,7 +392,7 @@ def add_transform(self, transform: TransformDispatcher, *targs, **tkwargs):
"""Add a transform (dispatcher) to the end of the program.

Note that this should be a function decorated with/called by
``qml.transforms.transform``, and not a ``TransformContainer``.
``qml.transforms.transform``, and not a ``BoundTransform``.

Args:
transform (TransformDispatcher): The transform to add to the compile pipeline.
Expand All @@ -407,10 +407,10 @@ def add_transform(self, transform: TransformDispatcher, *targs, **tkwargs):

if transform.expand_transform:
self.push_back(
TransformContainer(TransformDispatcher(transform.expand_transform), targs, tkwargs)
BoundTransform(TransformDispatcher(transform.expand_transform), targs, tkwargs)
)
self.push_back(
TransformContainer(
BoundTransform(
transform,
args=targs,
kwargs=tkwargs,
Expand All @@ -434,7 +434,7 @@ def insert_front_transform(self, transform: TransformDispatcher, *targs, **tkwar
)

self.insert_front(
TransformContainer(
BoundTransform(
transform,
args=targs,
kwargs=tkwargs,
Expand All @@ -443,22 +443,22 @@ def insert_front_transform(self, transform: TransformDispatcher, *targs, **tkwar

if transform.expand_transform:
self.insert_front(
TransformContainer(TransformDispatcher(transform.expand_transform), targs, tkwargs)
BoundTransform(TransformDispatcher(transform.expand_transform), targs, tkwargs)
)

def pop_front(self):
"""Pop the transform container at the beginning of the program.

Returns:
TransformContainer: The transform container at the beginning of the program.
BoundTransform: The transform container at the beginning of the program.
"""
return self._compile_pipeline.pop(0)

def get_last(self):
"""Get the last transform container.

Returns:
TransformContainer: The last transform in the program.
BoundTransform: The last transform in the program.

Raises:
TransformError: It raises an error if the program is empty.
Expand Down Expand Up @@ -671,15 +671,15 @@ def _apply_to_program(obj: CompilePipeline, transform, *targs, **tkwargs):
if transform.expand_transform:
# pylint: disable=protected-access
program.push_back(
TransformContainer(
BoundTransform(
transform.expand_transform,
targs,
tkwargs,
use_argnum=transform._use_argnum_in_expand,
)
)
program.push_back(
TransformContainer(
BoundTransform(
transform,
args=targs,
kwargs=tkwargs,
Expand Down
6 changes: 3 additions & 3 deletions pennylane/transforms/core/cotransform_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pennylane.exceptions import QuantumFunctionError
from pennylane.typing import TensorLike

from .transform_dispatcher import TransformContainer
from .transform_dispatcher import BoundTransform


def _numpy_jac(*_, **__) -> TensorLike:
Expand Down Expand Up @@ -162,7 +162,7 @@ def _get_idx_for_transform(self, transform):
return i
raise ValueError(f"Could not find {transform} in qnode's transform program.")

def get_classical_jacobian(self, transform: TransformContainer, tape_idx: int):
def get_classical_jacobian(self, transform: BoundTransform, tape_idx: int):
"""Calculate the classical jacobian for a given transform.

Note that this function assumes that the transform exists at most one in the compile pipeline.
Expand Down Expand Up @@ -200,7 +200,7 @@ def c(x, y):
classical_jacobian = _jac_map[interface](f, argnums, *self.args, **self.kwargs)
return classical_jacobian

def get_argnums(self, transform: TransformContainer) -> list[set[int]] | None:
def get_argnums(self, transform: BoundTransform) -> list[set[int]] | None:
"""Calculate the trainable params from the argnums in the transform.

.. code-block:: python
Expand Down
Loading