From 4d96c47cad499ba5dbf7bcdbc7196cb7e73712d6 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Date: Thu, 24 Apr 2025 19:43:29 +0200
Subject: [PATCH 1/2] Remove strict=False in hot loops

This is actually slower than just not specifying it
---
 pyproject.toml                                |  6 +-----
 pytensor/compile/builders.py                  |  5 ++---
 pytensor/link/basic.py                        | 19 ++++++++++++-------
 pytensor/link/numba/dispatch/basic.py         |  4 ++--
 .../link/numba/dispatch/cython_support.py     |  5 +----
 pytensor/link/numba/dispatch/extra_ops.py     |  2 +-
 pytensor/link/numba/dispatch/slinalg.py       |  2 +-
 pytensor/link/numba/dispatch/subtensor.py     | 10 +++++-----
 pytensor/link/utils.py                        |  8 ++++----
 pytensor/scalar/basic.py                      |  4 ++--
 pytensor/scalar/loop.py                       |  4 ++--
 pytensor/tensor/random/basic.py               |  4 ++--
 pytensor/tensor/random/utils.py               | 11 +++++------
 pytensor/tensor/shape.py                      |  6 ++----
 pytensor/tensor/type.py                       |  4 ++--
 15 files changed, 44 insertions(+), 50 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index bbb64549e5..41169f38c8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -130,12 +130,8 @@ exclude = ["doc/", "pytensor/_version.py"]
 docstring-code-format = true
 
 [tool.ruff.lint]
-select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"]
+select = ["C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"]
 ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"]
-unfixable = [
-    # zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead
-    "B905",
-]
 
 
 [tool.ruff.lint.isort]
diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py
index a4a3d1840a..b669780607 100644
--- a/pytensor/compile/builders.py
+++ b/pytensor/compile/builders.py
@@ -873,7 +873,6 @@ def clone(self):
 
     def perform(self, node, inputs, outputs):
         variables = self.fn(*inputs)
-        assert len(variables) == len(outputs)
-        # strict=False because asserted above
-        for output, variable in zip(outputs, variables, strict=False):
+        # strict=None because we are in a hot loop
+        for output, variable in zip(outputs, variables):
             output[0] = variable
diff --git a/pytensor/link/basic.py b/pytensor/link/basic.py
index 9cf34983f2..5597ddddfb 100644
--- a/pytensor/link/basic.py
+++ b/pytensor/link/basic.py
@@ -373,7 +373,12 @@ def make_all(
 
         # The function that actually runs your program is one of the f's in streamline.
         f = streamline(
-            fgraph, thunks, order, post_thunk_old_storage, no_recycling=no_recycling
+            fgraph,
+            thunks,
+            order,
+            post_thunk_old_storage=post_thunk_old_storage,
+            no_recycling=no_recycling,
+            output_storage=output_storage,
         )
 
         f.allow_gc = (
@@ -539,14 +544,14 @@ def make_thunk(self, **kwargs):
 
         def f():
             for inputs in input_lists[1:]:
-                # strict=False because we are in a hot loop
-                for input1, input2 in zip(inputs0, inputs, strict=False):
+                # strict=None because we are in a hot loop
+                for input1, input2 in zip(inputs0, inputs):
                     input2.storage[0] = copy(input1.storage[0])
             for x in to_reset:
                 x[0] = None
             pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
-            # strict=False because we are in a hot loop
-            for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)):
+            # strict=None because we are in a hot loop
+            for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
                 try:
                     wrapper(self.fgraph, i, node, *thunks)
                 except Exception:
@@ -668,8 +673,8 @@ def thunk(
                 #  since the error may come from any of them?
                 raise_with_op(self.fgraph, output_nodes[0], thunk)
 
-            # strict=False because we are in a hot loop
-            for o_storage, o_val in zip(thunk_outputs, outputs, strict=False):
+            # strict=None because we are in a hot loop
+            for o_storage, o_val in zip(thunk_outputs, outputs):
                 o_storage[0] = o_val
 
         thunk.inputs = thunk_inputs
diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py
index a6a82ceebe..6a85bb77f1 100644
--- a/pytensor/link/numba/dispatch/basic.py
+++ b/pytensor/link/numba/dispatch/basic.py
@@ -312,10 +312,10 @@ def py_perform_return(inputs):
     else:
 
         def py_perform_return(inputs):
-            # strict=False because we are in a hot loop
+            # strict=None because we are in a hot loop
             return tuple(
                 out_type.filter(out[0])
-                for out_type, out in zip(output_types, py_perform(inputs), strict=False)
+                for out_type, out in zip(output_types, py_perform(inputs))
             )
 
     @numba_njit
diff --git a/pytensor/link/numba/dispatch/cython_support.py b/pytensor/link/numba/dispatch/cython_support.py
index 8dccf98836..422e4be406 100644
--- a/pytensor/link/numba/dispatch/cython_support.py
+++ b/pytensor/link/numba/dispatch/cython_support.py
@@ -166,10 +166,7 @@ def __wrapper_address__(self):
     def __call__(self, *args, **kwargs):
         # no strict argument because of the JIT
         # TODO: check
-        args = [
-            dtype(arg)
-            for arg, dtype in zip(args, self._signature.arg_dtypes)  # noqa: B905
-        ]
+        args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)]
         if self.has_pyx_skip_dispatch():
             output = self._pyfunc(*args[:-1], **kwargs)
         else:
diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py
index 1f0a33e595..f7700acf47 100644
--- a/pytensor/link/numba/dispatch/extra_ops.py
+++ b/pytensor/link/numba/dispatch/extra_ops.py
@@ -186,7 +186,7 @@ def ravelmultiindex(*inp):
             new_arr = arr.T.astype(np.float64).copy()
             for i, b in enumerate(new_arr):
                 # no strict argument to this zip because numba doesn't support it
-                for j, (d, v) in enumerate(zip(shape, b)):  # noqa: B905
+                for j, (d, v) in enumerate(zip(shape, b)):
                     if v < 0 or v >= d:
                         mode_fn(new_arr, i, j, v, d)
 
diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py
index 7e1f6ded56..92f8a254f8 100644
--- a/pytensor/link/numba/dispatch/slinalg.py
+++ b/pytensor/link/numba/dispatch/slinalg.py
@@ -183,7 +183,7 @@ def block_diag(*arrs):
 
         r, c = 0, 0
         # no strict argument because it is incompatible with numba
-        for arr, shape in zip(arrs, shapes):  # noqa: B905
+        for arr, shape in zip(arrs, shapes):
             rr, cc = shape
             out[r : r + rr, c : c + cc] = arr
             r += rr
diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py
index ee9e183d16..5f471707a5 100644
--- a/pytensor/link/numba/dispatch/subtensor.py
+++ b/pytensor/link/numba/dispatch/subtensor.py
@@ -219,7 +219,7 @@ def advanced_subtensor_multiple_vector(x, *idxs):
             shape_aft = x_shape[after_last_axis:]
             out_shape = (*shape_bef, *idx_shape, *shape_aft)
             out_buffer = np.empty(out_shape, dtype=x.dtype)
-            for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+            for i, scalar_idxs in enumerate(zip(*vec_idxs)):
                 out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
             return out_buffer
 
@@ -253,7 +253,7 @@ def advanced_set_subtensor_multiple_vector(x, y, *idxs):
                     y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
 
                 for outer in np.ndindex(x_shape[:first_axis]):
-                    for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                    for i, scalar_idxs in enumerate(zip(*vec_idxs)):
                         out[(*outer, *scalar_idxs)] = y[(*outer, i)]
                 return out
 
@@ -275,7 +275,7 @@ def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
                     y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
 
                 for outer in np.ndindex(x_shape[:first_axis]):
-                    for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                    for i, scalar_idxs in enumerate(zip(*vec_idxs)):
                         out[(*outer, *scalar_idxs)] += y[(*outer, i)]
                 return out
 
@@ -314,7 +314,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs):
                 if not len(idxs) == len(vals):
                     raise ValueError("The number of indices and values must match.")
                 # no strict argument because incompatible with numba
-                for idx, val in zip(idxs, vals):  # noqa: B905
+                for idx, val in zip(idxs, vals):
                     x[idx] = val
                 return x
     else:
@@ -342,7 +342,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs):
                     raise ValueError("The number of indices and values must match.")
                 # no strict argument because unsupported by numba
                 # TODO: this doesn't come up in tests
-                for idx, val in zip(idxs, vals):  # noqa: B905
+                for idx, val in zip(idxs, vals):
                     x[idx] += val
                 return x
 
diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py
index 9cbc3838dd..019acdd0ca 100644
--- a/pytensor/link/utils.py
+++ b/pytensor/link/utils.py
@@ -190,9 +190,9 @@ def streamline_default_f():
             for x in no_recycling:
                 x[0] = None
             try:
-                # strict=False because we are in a hot loop
+                # strict=None because we are in a hot loop
                 for thunk, node, old_storage in zip(
-                    thunks, order, post_thunk_old_storage, strict=False
+                    thunks, order, post_thunk_old_storage
                 ):
                     thunk()
                     for old_s in old_storage:
@@ -207,8 +207,8 @@ def streamline_nice_errors_f():
             for x in no_recycling:
                 x[0] = None
             try:
-                # strict=False because we are in a hot loop
-                for thunk, node in zip(thunks, order, strict=False):
+                # strict=None because we are in a hot loop
+                for thunk, node in zip(thunks, order):
                     thunk()
             except Exception:
                 raise_with_op(fgraph, node, thunk)
diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py
index 909fc47c27..5349b7d2cf 100644
--- a/pytensor/scalar/basic.py
+++ b/pytensor/scalar/basic.py
@@ -4416,8 +4416,8 @@ def make_node(self, *inputs):
 
     def perform(self, node, inputs, output_storage):
         outputs = self.py_perform_fn(*inputs)
-        # strict=False because we are in a hot loop
-        for storage, out_val in zip(output_storage, outputs, strict=False):
+        # strict=None because we are in a hot loop
+        for storage, out_val in zip(output_storage, outputs):
             storage[0] = out_val
 
     def grad(self, inputs, output_grads):
diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py
index 1023e6a127..98b715cc0e 100644
--- a/pytensor/scalar/loop.py
+++ b/pytensor/scalar/loop.py
@@ -196,8 +196,8 @@ def perform(self, node, inputs, output_storage):
             for i in range(n_steps):
                 carry = inner_fn(*carry, *constant)
 
-        # strict=False because we are in a hot loop
-        for storage, out_val in zip(output_storage, carry, strict=False):
+        # strict=None because we are in a hot loop
+        for storage, out_val in zip(output_storage, carry):
             storage[0] = out_val
 
     @property
diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py
index 214a7bdd3d..6939e6b155 100644
--- a/pytensor/tensor/random/basic.py
+++ b/pytensor/tensor/random/basic.py
@@ -1865,8 +1865,8 @@ def rng_fn(cls, rng, p, size):
             # to `p.shape[:-1]` in the call to `vsearchsorted` below.
             if len(size) < (p.ndim - 1):
                 raise ValueError("`size` is incompatible with the shape of `p`")
-            # strict=False because we are in a hot loop
-            for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=False):
+            # strict=None because we are in a hot loop
+            for s, ps in zip(reversed(size), reversed(p.shape[:-1])):
                 if s == 1 and ps != 1:
                     raise ValueError("`size` is incompatible with the shape of `p`")
 
diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py
index 23b4b50265..c91745b60b 100644
--- a/pytensor/tensor/random/utils.py
+++ b/pytensor/tensor/random/utils.py
@@ -44,8 +44,8 @@ def params_broadcast_shapes(
     max_fn = maximum if use_pytensor else max
 
     rev_extra_dims: list[int] = []
-    # strict=False because we are in a hot loop
-    for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False):
+    # strict=None because we are in a hot loop
+    for ndim_param, param_shape in zip(ndims_params, param_shapes):
         # We need this in order to use `len`
         param_shape = tuple(param_shape)
         extras = tuple(param_shape[: (len(param_shape) - ndim_param)])
@@ -69,7 +69,7 @@ def max_bcast(x, y):
         (extra_dims + tuple(param_shape)[-ndim_param:])
         if ndim_param > 0
         else extra_dims
-        for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False)
+        for ndim_param, param_shape in zip(ndims_params, param_shapes)
     ]
 
     return bcast_shapes
@@ -127,10 +127,9 @@ def broadcast_params(
     )
     broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to
 
-    # strict=False because we are in a hot loop
+    # strict=None because we are in a hot loop
     bcast_params = [
-        broadcast_to_fn(param, shape)
-        for shape, param in zip(shapes, params, strict=False)
+        broadcast_to_fn(param, shape) for shape, param in zip(shapes, params)
     ]
 
     return bcast_params
diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py
index 5a4cfdc52a..f82bc7a39b 100644
--- a/pytensor/tensor/shape.py
+++ b/pytensor/tensor/shape.py
@@ -447,10 +447,8 @@ def perform(self, node, inp, out_):
             raise AssertionError(
                 f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
             )
-        # strict=False because we are in a hot loop
-        if not all(
-            xs == s for xs, s in zip(x.shape, shape, strict=False) if s is not None
-        ):
+        # strict=None because we are in a hot loop
+        if not all(xs == s for xs, s in zip(x.shape, shape) if s is not None):
             raise AssertionError(
                 f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}."
             )
diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py
index d0b6b5fe0a..9bb2f0a731 100644
--- a/pytensor/tensor/type.py
+++ b/pytensor/tensor/type.py
@@ -261,10 +261,10 @@ def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray:
                 " PyTensor C code does not support that.",
             )
 
-        # strict=False because we are in a hot loop
+        # strict=None because we are in a hot loop
         if not all(
             ds == ts if ts is not None else True
-            for ds, ts in zip(data.shape, self.shape, strict=False)
+            for ds, ts in zip(data.shape, self.shape)
         ):
             raise TypeError(
                 f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})"

From 482cded01221d6ba0c496e2cb3d61b19cf7c7a98 Mon Sep 17 00:00:00 2001
From: ricardoV94 <ricardo.vieira1994@gmail.com>
Date: Tue, 8 Apr 2025 10:04:22 +0200
Subject: [PATCH 2/2] Remove deprecated PyTensor function functionality and
 reduce overhead

---
 doc/library/compile/io.rst              |  17 +-
 pytensor/compile/function/types.py      | 608 ++++++++-------------
 pytensor/compile/io.py                  |  10 +-
 pytensor/link/basic.py                  |   3 +
 pytensor/link/c/basic.py                |   4 +-
 pytensor/link/utils.py                  |   5 +
 pytensor/scan/op.py                     |  12 -
 pytensor/tensor/type.py                 |   8 +-
 tests/compile/function/test_function.py |   8 +-
 tests/compile/function/test_pfunc.py    |   2 +-
 tests/compile/function/test_types.py    | 675 ++++++++----------------
 tests/compile/test_debugmode.py         |  73 +--
 tests/link/test_vm.py                   |  69 +--
 13 files changed, 482 insertions(+), 1012 deletions(-)

diff --git a/doc/library/compile/io.rst b/doc/library/compile/io.rst
index 272d4754db..962d3ea7a5 100644
--- a/doc/library/compile/io.rst
+++ b/doc/library/compile/io.rst
@@ -35,12 +35,11 @@ The ``inputs`` argument to ``pytensor.function`` is a list, containing the ``Var
       can be set by ``kwarg``, and its value can be accessed by
       ``self.<name>``. The default value is ``None``.
 
-      ``value``: literal or ``Container``. The initial/default value for this
+      ``value``: ``Container``. The initial value for this
         input. If update is ``None``, this input acts just like
         an argument with a default value in Python. If update is not ``None``,
-        changes to this
-        value will "stick around", whether due to an update or a user's
-        explicit action.
+        changes to this value will "stick around", whether due to an update
+        or a user's explicit action.
 
       ``update``: Variable instance. This expression Variable will
       replace ``value`` after each function call. The default value is
@@ -73,18 +72,16 @@ The ``inputs`` argument to ``pytensor.function`` is a list, containing the ``Var
             overwriting its content without being aware of it).
 
 
-Value: initial and default values
----------------------------------
+Update
+------
 
-A non-None `value` argument makes an In() instance an optional parameter
-of the compiled function.  For example, in the following code we are
-defining an arity-2 function ``inc``.
+We can define an update to modify the value
 
 >>> import pytensor.tensor as pt
 >>> from pytensor import function
 >>> from pytensor.compile.io import In
 >>> u, x, s = pt.scalars('u', 'x', 's')
->>> inc = function([u, In(x, value=3), In(s, update=(s+x*u), value=10.0)], [])
+>>> inc = function([u, In(x), In(s, update=(s+x*u)], [])
 
 Since we provided a ``value`` for ``s`` and ``x``, we can call it with just a value for ``u`` like this:
 
diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py
index 9cc85f3d24..050c84f97c 100644
--- a/pytensor/compile/function/types.py
+++ b/pytensor/compile/function/types.py
@@ -331,23 +331,35 @@ class Function:
     ``False``. When ``True``, the `Function` will skip all checks on the
     inputs.
 
-    Attributes
-    ----------
-    finder
-        Dictionary mapping several kinds of things to containers.
-
-        We set an entry in finder for:
-        - the index of the input
-        - the variable instance the input is based on
-        - the name of the input
-
-        All entries map to the container or to DUPLICATE if an ambiguity
-        is detected.
-    inv_finder
-        Reverse lookup of `finder`.  It maps containers to `SymbolicInput`\s.
-
     """
 
+    __slots__ = (
+        "vm",
+        "input_storage",
+        "output_storage",
+        "indices",
+        "outputs",
+        "unpack_single",
+        "return_none",
+        "maker",
+        "profile",
+        "trust_input",
+        "name",
+        # Created inside __init__
+        "_potential_aliased_input_groups",
+        "_named_inputs",
+        "_n_unnamed_inputs",
+        "_finder",
+        "_inv_finder",
+        "_has_updates",
+        "_n_returned_outputs",
+        "_input_storage_data",
+        "_update_input_storage",
+        "_clear_input_storage_data",
+        "_clear_output_storage_data",
+        "_nodes_with_inner_function",
+    )
+
     pickle_aliased_memory_strategy = "warn"
     """
     How to deal with pickling finding aliased storage.
@@ -368,10 +380,8 @@ def __init__(
         output_storage: list[Container],
         indices,
         outputs,
-        defaults,
         unpack_single: bool,
         return_none: bool,
-        output_keys,
         maker: "FunctionMaker",
         trust_input: bool = False,
         name: str | None = None,
@@ -392,20 +402,11 @@ def __init__(
             tuple elements are used only by Kits, which are deprecated.
         outputs
             TODO
-        defaults
-            List of 3-tuples, one 3-tuple for each input.
-            Tuple element 0: ``bool``.  Is this input required at each function
-            call?
-            Tuple element 1: ``bool``.  Should this inputs value be reverted
-            after each call?
-            Tuple element 2: ``Any``.  The value associated with this input.
         unpack_single
             For outputs lists of length 1, should the 0'th element be
             returned directly?
         return_none
             Whether the function should return ``None`` or not.
-        output_keys
-            TODO
         maker
             The `FunctionMaker` that created this instance.
         trust_input : bool, default False
@@ -422,24 +423,16 @@ def __init__(
         self.output_storage = output_storage
         self.indices = indices
         self.outputs = outputs
-        self.defaults = defaults
         self.unpack_single = unpack_single
         self.return_none = return_none
         self.maker = maker
         self.profile = None  # reassigned in FunctionMaker.create
         self.trust_input = trust_input  # If True, we don't check the input parameter
         self.name = name
-        self.nodes_with_inner_function = []
-        self.output_keys = output_keys
-
-        if self.output_keys is not None:
-            warnings.warn("output_keys is deprecated.", FutureWarning)
 
         assert len(self.input_storage) == len(self.maker.fgraph.inputs)
         assert len(self.output_storage) == len(self.maker.fgraph.outputs)
 
-        self.has_defaults = any(refeed for _, refeed, _ in self.defaults)
-
         # Group indexes of inputs that are potentially aliased to each other
         # Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
         #  even though there could be two distinct types that use the same kinds of underlying objects.
@@ -475,41 +468,20 @@ def __init__(
             if len(group) > 1
         )
 
-        # We will be popping stuff off this `containers` object.  It is a copy.
-        containers = list(self.input_storage)
-        finder = {}
-        inv_finder = {}
-
-        # Store the list of names of named inputs.
-        named_inputs = []
-        # Count the number of un-named inputs.
-        n_unnamed_inputs = 0
-
+        self._finder = finder = {}
+        self._inv_finder = inv_finder = {}
+        self._named_inputs = named_inputs = []
+        self._n_unnamed_inputs = 0
         # Initialize the storage
         # this loop works by modifying the elements (as variable c) of
         # self.input_storage inplace.
-        for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(
-            zip(self.indices, defaults, strict=True)
-        ):
+        remaining_containers = self.input_storage.copy()
+        for i, (input, indices, sinputs) in enumerate(self.indices):
             if indices is None:
-                # containers is being used as a stack. Here we pop off
-                # the next one.
-                c = containers[0]
+                c = remaining_containers.pop(0)
                 c.strict = getattr(input, "strict", False)
                 c.allow_downcast = getattr(input, "allow_downcast", None)
-
-                if value is not None:
-                    # Always initialize the storage.
-                    if isinstance(value, Container):
-                        # There is no point in obtaining the current value
-                        # stored in the container, since the container is
-                        # shared.
-                        # For safety, we make sure 'refeed' is False, since
-                        # there is no need to refeed the default value.
-                        assert not refeed
-                    else:
-                        c.value = value
-                c.required = required
+                c.required = input.value is None
                 c.implicit = input.implicit
                 # this is a count of how many times the input has been
                 # provided (reinitialized to 0 on __call__)
@@ -521,71 +493,10 @@ def __init__(
                 else:
                     finder[input.name] = DUPLICATE
                 if input.name is None:
-                    n_unnamed_inputs += 1
+                    self._n_unnamed_inputs += 1
                 else:
                     named_inputs.append(input.name)
                 inv_finder[c] = input
-                containers[:1] = []
-
-        self.finder = finder
-        self.inv_finder = inv_finder
-
-        # this class is important in overriding the square-bracket notation:
-        #     fn.value[x]
-        # self reference is available via the closure on the class
-        class ValueAttribute:
-            def __getitem__(self, item):
-                try:
-                    s = finder[item]
-                except KeyError:
-                    raise TypeError(f"Unknown input or state: {item}")
-                if s is DUPLICATE:
-                    raise TypeError(
-                        f"Ambiguous name: {item} - please check the "
-                        "names of the inputs of your function "
-                        "for duplicates."
-                    )
-                if isinstance(s, Container):
-                    return s.value
-                else:
-                    raise NotImplementedError
-
-            def __setitem__(self, item, value):
-                try:
-                    s = finder[item]
-                except KeyError:
-                    # Print informative error message.
-                    msg = get_info_on_inputs(named_inputs, n_unnamed_inputs)
-                    raise TypeError(f"Unknown input or state: {item}. {msg}")
-                if s is DUPLICATE:
-                    raise TypeError(
-                        f"Ambiguous name: {item} - please check the "
-                        "names of the inputs of your function "
-                        "for duplicates."
-                    )
-                if isinstance(s, Container):
-                    s.value = value
-                    s.provided += 1
-                else:
-                    s(value)
-
-            def __contains__(self, item):
-                return finder.__contains__(item)
-
-        # this class is important in overriding the square-bracket notation:
-        #     fn.container[x]
-        # self reference is available via the closure on the class
-        class ContainerAttribute:
-            def __getitem__(self, item):
-                return finder[item]
-
-            def __contains__(self, item):
-                return finder.__contains__(item)
-
-            # You cannot set the container
-
-        self._value = ValueAttribute()
-        self._container = ContainerAttribute()
 
         update_storage = [
             container
@@ -595,27 +506,31 @@ def __contains__(self, item):
             if inp.update is not None
         ]
         # Updates are the last inner outputs that are not returned by Function.__call__
-        self.n_returned_outputs = len(self.output_storage) - len(update_storage)
+        self._has_updates = len(update_storage) > 0
+        self._n_returned_outputs = len(self.output_storage) - len(update_storage)
 
         # Function.__call__ is responsible for updating the inputs, unless the vm promises to do it itself
-        self.update_input_storage: tuple[int, Container] = ()
+        self._update_input_storage: tuple[int, Container] = ()
         if getattr(vm, "need_update_inputs", True):
-            self.update_input_storage = tuple(
+            self._update_input_storage = tuple(
                 zip(
-                    range(self.n_returned_outputs, len(output_storage)),
+                    range(self._n_returned_outputs, len(output_storage)),
                     update_storage,
                     strict=True,
                 )
             )
 
+        self._input_storage_data = tuple(
+            container.storage for container in input_storage
+        )
+
         # In every function call we place inputs in the input_storage, and the vm places outputs in the output_storage
         # After the call, we want to erase (some of) these references, to allow Python to GC them if unused
-        # Required input containers are the non-default inputs, must always be provided again, so we GC them
-        self.clear_input_storage_data = tuple(
+        self._clear_input_storage_data = tuple(
             container.storage for container in input_storage if container.required
         )
         # This is only done when `vm.allow_gc` is True, which can change at runtime.
-        self.clear_output_storage_data = tuple(
+        self._clear_output_storage_data = tuple(
             container.storage
             for container, variable in zip(
                 self.output_storage, self.maker.fgraph.outputs, strict=True
@@ -623,18 +538,11 @@ def __contains__(self, item):
             if variable.owner is not None  # Not a constant output
         )
 
-        for node in self.maker.fgraph.apply_nodes:
-            if isinstance(node.op, HasInnerGraph):
-                self.nodes_with_inner_function.append(node.op)
-
-    def __contains__(self, item):
-        return self.value.__contains__(item)
-
-    def __getitem__(self, item):
-        return self.value[item]
-
-    def __setitem__(self, item, value):
-        self.value[item] = value
+        self._nodes_with_inner_function = [
+            node
+            for node in self.maker.fgraph.apply_nodes
+            if isinstance(node.op, HasInnerGraph)
+        ]
 
     def __copy__(self):
         """
@@ -838,7 +746,6 @@ def checkSV(sv_ori, sv_rpl):
             # check that.
             accept_inplace=True,
             no_fgraph_prep=True,
-            output_keys=maker.output_keys,
             name=name,
         ).create(input_storage, storage_map=new_storage_map)
 
@@ -862,26 +769,126 @@ def checkSV(sv_ori, sv_rpl):
             # to container, to make Function.value and Function.data work well.
             # Replace variable in new maker.inputs by the original ones.
             # So that user can swap SharedVariable in a swapped function
-            container = f_cpy.finder.pop(in_cpy.variable)
+            container = f_cpy._finder.pop(in_cpy.variable)
             if not swapped:
-                f_cpy.finder[in_ori.variable] = container
+                f_cpy._finder[in_ori.variable] = container
                 in_cpy.variable = in_ori.variable
             else:
-                f_cpy.finder[swap[in_ori.variable]] = container
+                f_cpy._finder[swap[in_ori.variable]] = container
                 in_cpy.variable = swap[in_ori.variable]
 
         f_cpy.trust_input = self.trust_input
         f_cpy.unpack_single = self.unpack_single
         return f_cpy
 
-    def _restore_defaults(self):
-        for i, (required, refeed, value) in enumerate(self.defaults):
-            if refeed:
-                if isinstance(value, Container):
-                    value = value.storage[0]
-                self[i] = value
+    def _validate_inputs(self, args, kwargs):
+        input_storage = self.input_storage
+
+        if len(args) + len(kwargs) > len(input_storage):
+            raise TypeError("Too many parameter passed to pytensor function")
+
+        for arg_container in input_storage:
+            arg_container.provided = 0
+
+        # Set positional arguments
+        for arg_container, arg in zip(input_storage, args):
+            try:
+                arg_container.storage[0] = arg_container.type.filter(
+                    arg,
+                    strict=arg_container.strict,
+                    allow_downcast=arg_container.allow_downcast,
+                )
+
+            except Exception as e:
+                i = input_storage.index(arg_container)
+                function_name = "pytensor function"
+                argument_name = "argument"
+                if self.name:
+                    function_name += ' with name "' + self.name + '"'
+                if hasattr(arg, "name") and arg.name:
+                    argument_name += ' with name "' + arg.name + '"'
+                where = get_variable_trace_string(self.maker.inputs[i].variable)
+                if len(e.args) == 1:
+                    e.args = (
+                        "Bad input "
+                        + argument_name
+                        + " to "
+                        + function_name
+                        + f" at index {int(i)} (0-based). {where}"
+                        + e.args[0],
+                    )
+                else:
+                    e.args = (
+                        "Bad input "
+                        + argument_name
+                        + " to "
+                        + function_name
+                        + f" at index {int(i)} (0-based). {where}"
+                    ) + e.args
+                raise
+            arg_container.provided += 1
 
-    def __call__(self, *args, output_subset=None, **kwargs):
+        # Set keyword arguments
+        if kwargs:  # for speed, skip the items for empty kwargs
+            for key, arg in kwargs.items():
+                try:
+                    kwarg_container = self._finder[key]
+                except KeyError:
+                    # Print informative error message.
+                    msg = get_info_on_inputs(self._named_inputs, self._n_unnamed_inputs)
+                    raise TypeError(f"Unknown input: {key}. {msg}")
+                if kwarg_container is DUPLICATE:
+                    raise TypeError(
+                        f"Ambiguous name: {key} - please check the names of the inputs of your function for duplicates."
+                    )
+                kwarg_container.value = arg
+                kwarg_container.provided += 1
+
+        # Collect aliased inputs among the storage space
+        for potential_group in self._potential_aliased_input_groups:
+            args_share_memory: list[list[int]] = []
+            for i in potential_group:
+                i_type = self.maker.inputs[i].variable.type
+                i_val = input_storage[i].storage[0]
+
+                # Check if value is aliased with any of the values in one of the groups
+                for j_group in args_share_memory:
+                    if any(
+                        i_type.may_share_memory(input_storage[j].storage[0], i_val)
+                        for j in j_group
+                    ):
+                        j_group.append(i)
+                        break
+                else:  # no break
+                    # Create a new group
+                    args_share_memory.append([i])
+
+            # Check for groups of more than one argument that share memory
+            for group in args_share_memory:
+                if len(group) > 1:
+                    # copy all but the first
+                    for i in group[1:]:
+                        input_storage[i].storage[0] = copy.copy(
+                            input_storage[i].storage[0]
+                        )
+
+        # Check if inputs are missing, or if inputs were set more than once, or
+        # if we tried to provide inputs that are supposed to be implicit.
+        for arg_container in input_storage:
+            if arg_container.required and not arg_container.provided:
+                raise TypeError(
+                    f"Missing input: {getattr(self._inv_finder[arg_container], 'variable', self._inv_finder[arg_container])}"
+                )
+            if arg_container.provided > 1:
+                raise TypeError(
+                    f"Multiple values for input: {getattr(self._inv_finder[arg_container], 'variable', self._inv_finder[arg_container])}"
+                )
+            if arg_container.implicit and arg_container.provided > 0:
+                raise TypeError(
+                    f"Tried to provide value for implicit input: {getattr(self._inv_finder[arg_container], 'variable', self._inv_finder[arg_container])}"
+                )
+
+    def __call__(self, *args, **kwargs):
         """
         Evaluates value of a function on given arguments.
 
@@ -909,134 +916,30 @@ def __call__(self, *args, output_subset=None, **kwargs):
             List of outputs on indices/keys from ``output_subset`` or all of them,
             if ``output_subset`` is not passed.
         """
-        trust_input = self.trust_input
-        input_storage = self.input_storage
-        vm = self.vm
-        profile = self.profile
-
-        if profile:
+        if self.profile:
             t0 = time.perf_counter()
 
-        if output_subset is not None:
-            warnings.warn("output_subset is deprecated.", FutureWarning)
-            if self.output_keys is not None:
-                output_subset = [self.output_keys.index(key) for key in output_subset]
-
         # Reinitialize each container's 'provided' counter
-        if trust_input:
-            for arg_container, arg in zip(input_storage, args, strict=False):
-                arg_container.storage[0] = arg
+        if self.trust_input:
+            for storage_data, arg in zip(self._input_storage_data, args):
+                storage_data[0] = arg
+            if kwargs:  # for speed, skip the items for empty kwargs
+                for k, arg in kwargs.items():
+                    self._finder[k].storage[0] = arg
         else:
-            for arg_container in input_storage:
-                arg_container.provided = 0
-
-            if len(args) + len(kwargs) > len(input_storage):
-                raise TypeError("Too many parameter passed to pytensor function")
-
-            # Set positional arguments
-            for arg_container, arg in zip(input_storage, args, strict=False):
-                # See discussion about None as input
-                # https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
-                if arg is None:
-                    arg_container.storage[0] = arg
-                else:
-                    try:
-                        arg_container.storage[0] = arg_container.type.filter(
-                            arg,
-                            strict=arg_container.strict,
-                            allow_downcast=arg_container.allow_downcast,
-                        )
-
-                    except Exception as e:
-                        i = input_storage.index(arg_container)
-                        function_name = "pytensor function"
-                        argument_name = "argument"
-                        if self.name:
-                            function_name += ' with name "' + self.name + '"'
-                        if hasattr(arg, "name") and arg.name:
-                            argument_name += ' with name "' + arg.name + '"'
-                        where = get_variable_trace_string(self.maker.inputs[i].variable)
-                        if len(e.args) == 1:
-                            e.args = (
-                                "Bad input "
-                                + argument_name
-                                + " to "
-                                + function_name
-                                + f" at index {int(i)} (0-based). {where}"
-                                + e.args[0],
-                            )
-                        else:
-                            e.args = (
-                                "Bad input "
-                                + argument_name
-                                + " to "
-                                + function_name
-                                + f" at index {int(i)} (0-based). {where}"
-                            ) + e.args
-                        self._restore_defaults()
-                        raise
-                arg_container.provided += 1
-
-        # Set keyword arguments
-        if kwargs:  # for speed, skip the items for empty kwargs
-            for k, arg in kwargs.items():
-                self[k] = arg
-
-        if not trust_input:
-            # Collect aliased inputs among the storage space
-            for potential_group in self._potential_aliased_input_groups:
-                args_share_memory: list[list[int]] = []
-                for i in potential_group:
-                    i_type = self.maker.inputs[i].variable.type
-                    i_val = input_storage[i].storage[0]
-
-                    # Check if value is aliased with any of the values in one of the groups
-                    for j_group in args_share_memory:
-                        if any(
-                            i_type.may_share_memory(input_storage[j].storage[0], i_val)
-                            for j in j_group
-                        ):
-                            j_group.append(i)
-                            break
-                    else:  # no break
-                        # Create a new group
-                        args_share_memory.append([i])
-
-                # Check for groups of more than one argument that share memory
-                for group in args_share_memory:
-                    if len(group) > 1:
-                        # copy all but the first
-                        for i in group[1:]:
-                            input_storage[i].storage[0] = copy.copy(
-                                input_storage[i].storage[0]
-                            )
-
-            # Check if inputs are missing, or if inputs were set more than once, or
-            # if we tried to provide inputs that are supposed to be implicit.
-            for arg_container in input_storage:
-                if arg_container.required and not arg_container.provided:
-                    self._restore_defaults()
-                    raise TypeError(
-                        f"Missing required input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
-                    )
-                if arg_container.provided > 1:
-                    self._restore_defaults()
-                    raise TypeError(
-                        f"Multiple values for input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
-                    )
-                if arg_container.implicit and arg_container.provided > 0:
-                    self._restore_defaults()
-                    raise TypeError(
-                        f"Tried to provide value for implicit input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
-                    )
+            self._validate_inputs(args, kwargs)
 
         # Do the actual work
-        if profile:
-            t0_fn = time.perf_counter()
         try:
-            outputs = vm() if output_subset is None else vm(output_subset=output_subset)
+            if self.profile:
+                t0_fn = time.perf_counter()
+                outputs = self.vm()
+                dt_fn = time.perf_counter() - t0_fn
+                self.maker.mode.fn_time += dt_fn
+                self.profile.vm_call_time += dt_fn
+            else:
+                outputs = self.vm()
         except Exception:
-            self._restore_defaults()
             if hasattr(self.vm, "position_of_error"):
                 # this is a new vm-provided function or c linker
                 # they need this because the exception manipulation
@@ -1054,71 +957,39 @@ def __call__(self, *args, output_subset=None, **kwargs):
                 # old-style linkers raise their own exceptions
                 raise
 
-        if profile:
-            dt_fn = time.perf_counter() - t0_fn
-            self.maker.mode.fn_time += dt_fn
-            profile.vm_call_time += dt_fn
-
-        # Retrieve the values that were computed
         if outputs is None:
+            # Not all VMs can return outputs directly (mainly CLinker?)
             outputs = [x.storage[0] for x in self.output_storage]
 
         # Set updates and filter them out from the returned outputs
-        for i, input_storage in self.update_input_storage:
-            input_storage.storage[0] = outputs[i]
-        outputs = outputs[: self.n_returned_outputs]
+        if self._has_updates:
+            for i, input_storage in self._update_input_storage:
+                input_storage.storage[0] = outputs[i]
+            outputs = outputs[: self._n_returned_outputs]
 
         # Remove input and output values from storage data
-        for storage_data in self.clear_input_storage_data:
-            storage_data[0] = None
-        if getattr(vm, "allow_gc", False):
-            for storage_data in self.clear_output_storage_data:
+        if self.vm.allow_gc:
+            for storage_data in self._clear_input_storage_data:
+                storage_data[0] = None
+            for storage_data in self._clear_output_storage_data:
                 storage_data[0] = None
 
-        # Put default values back in the storage
-        if self.has_defaults:
-            self._restore_defaults()
-
-        if profile:
+        if self.profile:
+            profile = self.profile
             dt_call = time.perf_counter() - t0
             pytensor.compile.profiling.total_fct_exec_time += dt_call
             self.maker.mode.call_time += dt_call
             profile.fct_callcount += 1
             profile.fct_call_time += dt_call
-            if hasattr(vm, "update_profile"):
-                vm.update_profile(profile)
+            if hasattr(self.vm, "update_profile"):
+                self.vm.update_profile(profile)
             if profile.ignore_first_call:
                 profile.reset()
                 profile.ignore_first_call = False
 
-        if self.return_none:
-            return None
-
-        if output_subset is not None:
-            outputs = [outputs[i] for i in output_subset]
-
-        if self.output_keys is None:
-            if self.unpack_single:
-                [out] = outputs
-                return out
-            else:
-                return outputs
-        else:
-            output_keys = self.output_keys
-            if output_subset is not None:
-                output_keys = [output_keys[i] for i in output_subset]
-            return dict(zip(output_keys, outputs, strict=True))
-
-    value = property(
-        lambda self: self._value,
-        None,  # this property itself is not settable
-        doc="dictionary-like access to the values associated with Variables",
-    )
-    container = property(
-        lambda self: self._container,
-        None,  # this property itself is not settable
-        doc=("dictionary-like access to the containers associated with Variables"),
-    )
+        return (
+            outputs[0] if self.unpack_single else None if self.return_none else outputs
+        )
 
     def free(self):
         """
@@ -1126,13 +997,16 @@ def free(self):
         """
         # 1.no allow_gc return False
         # 2.has allow_gc, if allow_gc is False, return True
-        if not getattr(self.vm, "allow_gc", True):
+        if not self.vm.allow_gc:
+            for inp_storage in self._clear_input_storage_data:
+                inp_storage[0] = None
+
             storage_map = self.vm.storage_map
             for key, value in storage_map.items():
                 if key.owner is not None:  # Not a constant
                     value[0] = None
 
-            for node in self.nodes_with_inner_function:
+            for node in self._nodes_with_inner_function:
                 if hasattr(node.fn, "free"):
                     node.fn.free()
 
@@ -1157,17 +1031,7 @@ def dprint(self, **kwargs):
 
 # pickling/deepcopy support for Function
 def _pickle_Function(f):
-    # copy of the input storage list
-    ins = list(f.input_storage)
-    input_storage = []
-
-    # strict=False because we are in a hot loop
-    for (input, indices, inputs), (required, refeed, default) in zip(
-        f.indices, f.defaults, strict=False
-    ):
-        input_storage.append(ins[0])
-        del ins[0]
-
+    input_storage = f.input_storage.copy()
     inputs_data = [x.data for x in f.input_storage]
 
     # HACK to detect aliased storage.
@@ -1521,6 +1385,9 @@ def __init__(
         no_fgraph_prep=False,
         trust_input=False,
     ):
+        if output_keys is not None:
+            raise ValueError("output_keys was deprecated")
+
         # Save the provided mode, not the instantiated mode.
         # The instantiated mode don't pickle and if we unpickle an PyTensor
         # function and it get re-compiled, we want the current rewriter to be
@@ -1561,6 +1428,20 @@ def __init__(
 
         # Wrap them in In or Out instances if needed.
         inputs = [self.wrap_in(i) for i in inputs]
+
+        # Remove this after a while
+        if any(
+            (
+                i.value is not None
+                and not isinstance(i.value, Container)
+                and i.update is None
+            )
+            for i in inputs
+        ):
+            raise ValueError(
+                "Inputs with default values were deprecated. Use `functools.partial` instead."
+            )
+
         outputs = [self.wrap_out(o) for o in outputs]
 
         # Check if some input variables are unused
@@ -1620,21 +1501,9 @@ def __init__(
         self.accept_inplace = accept_inplace
         self.function_builder = function_builder
         self.on_unused_input = on_unused_input  # Used for the pickling/copy
-        self.output_keys = output_keys
         self.name = name
         self.trust_input = trust_input
-
         self.required = [(i.value is None) for i in self.inputs]
-        self.refeed = [
-            (
-                i.value is not None
-                and not isinstance(i.value, Container)
-                and i.update is None
-            )
-            for i in self.inputs
-        ]
-        if any(self.refeed):
-            warnings.warn("Inputs with default values are deprecated.", FutureWarning)
 
     def create(self, input_storage=None, storage_map=None):
         """
@@ -1652,7 +1521,6 @@ def create(self, input_storage=None, storage_map=None):
             input_storage = [None] * len(self.inputs)
         # list of independent one-element lists, will be passed to the linker
         input_storage_lists = []
-        defaults = []
 
         # The following loop is to fill in the input_storage_lists and
         # defaults lists.
@@ -1679,35 +1547,15 @@ def create(self, input_storage=None, storage_map=None):
                     )
                 input_storage_lists.append(input_storage_i.storage)
 
-                storage = input_storage[i].storage[0]
-
             else:
                 # Normal case: one new, independent storage unit
                 input_storage_lists.append([input_storage_i])
 
-                storage = input_storage_i
-
             required = self.required[i]
-            refeed = self.refeed[i]
-            # sanity check-- if an input is required it should not
-            # need to be refed
-            assert not (required and refeed)
 
             # shared variables need neither be input by the user nor refed
             if input.shared:
                 assert not required
-                assert not refeed
-                storage = None
-
-            # if an input is required, it never need be refed
-            if required:
-                storage = None
-
-            # make sure that we only store a value if we actually need it
-            if storage is not None:
-                assert refeed or not required
-
-            defaults.append((required, refeed, storage))
 
         # Get a function instance
         start_linker = time.perf_counter()
@@ -1730,16 +1578,14 @@ def create(self, input_storage=None, storage_map=None):
             self.profile.import_time += import_time
 
         fn = self.function_builder(
-            _fn,
-            _i,
-            _o,
-            self.indices,
-            self.outputs,
-            defaults,
-            self.unpack_single,
-            self.return_none,
-            self.output_keys,
-            self,
+            vm=_fn,
+            input_storage=_i,
+            output_storage=_o,
+            indices=self.indices,
+            outputs=self.outputs,
+            unpack_single=self.unpack_single,
+            return_none=self.return_none,
+            maker=self,
             trust_input=self.trust_input,
             name=self.name,
         )
@@ -1809,7 +1655,7 @@ def orig_function(
         else:
             outputs = FunctionMaker.wrap_out(outputs)
 
-    defaults = [getattr(input, "value", None) for input in inputs]
+    shared_variable_containers = [getattr(input, "value", None) for input in inputs]
 
     if isinstance(mode, list | tuple):
         raise ValueError("We do not support the passing of multiple modes")
@@ -1830,7 +1676,7 @@ def orig_function(
             trust_input=trust_input,
         )
         with config.change_flags(compute_test_value="off"):
-            fn = m.create(defaults)
+            fn = m.create(shared_variable_containers)
     finally:
         if profile and fn:
             t2 = time.perf_counter()
diff --git a/pytensor/compile/io.py b/pytensor/compile/io.py
index 9ce0421235..07554929ff 100644
--- a/pytensor/compile/io.py
+++ b/pytensor/compile/io.py
@@ -182,9 +182,6 @@ def __init__(
         borrow=None,
         shared=False,
     ):
-        # if shared, an input's value comes from its persistent
-        # storage, not from a default stored in the function or from
-        # the caller
         self.shared = shared
 
         if borrow is None:
@@ -204,6 +201,13 @@ def __init__(
                 "overwritten.",
             )
 
+        if value is not None and not isinstance(value, Container):
+            from pytensor.compile.sharedvalue import SharedVariable
+
+            if not isinstance(value, SharedVariable):
+                # This is to catch use of old API to pass default values
+                raise ValueError("Inputs with default values are deprecated")
+
         if implicit is None:
             from pytensor.compile.sharedvalue import SharedVariable
 
diff --git a/pytensor/link/basic.py b/pytensor/link/basic.py
index 5597ddddfb..2ab9146576 100644
--- a/pytensor/link/basic.py
+++ b/pytensor/link/basic.py
@@ -558,6 +558,7 @@ def f():
                     raise_with_op(self.fgraph, node, *thunks)
 
         f.thunk_groups = thunk_groups
+        f.allow_gc = len(self.linkers) == 1
 
         return f, inputs0, outputs0
 
@@ -677,6 +678,8 @@ def thunk(
             for o_storage, o_val in zip(thunk_outputs, outputs):
                 o_storage[0] = o_val
 
+            return outputs
+
         thunk.inputs = thunk_inputs
         thunk.outputs = thunk_outputs
         thunk.lazy = False
diff --git a/pytensor/link/c/basic.py b/pytensor/link/c/basic.py
index d509bd1d76..1c08a21f04 100644
--- a/pytensor/link/c/basic.py
+++ b/pytensor/link/c/basic.py
@@ -1188,6 +1188,7 @@ def make_thunk(
 
         res = _CThunk(cthunk, init_tasks, tasks, error_storage, module)
         res.nodes = self.node_order
+        res.allow_gc = False
         return res, in_storage, out_storage
 
     def cmodule_key(self):
@@ -1875,9 +1876,10 @@ def make_all(
             fgraph,
             thunks,
             order,
-            post_thunk_old_storage,
+            post_thunk_old_storage=post_thunk_old_storage,
             no_recycling=no_recycling,
             nice_errors=self.nice_errors,
+            output_storage=output_storage,
         )
 
         f.allow_gc = self.allow_gc
diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py
index 019acdd0ca..468422d01f 100644
--- a/pytensor/link/utils.py
+++ b/pytensor/link/utils.py
@@ -145,9 +145,11 @@ def streamline(
     fgraph: FunctionGraph,
     thunks: Sequence[Callable[[], None]],
     order: Sequence[Apply],
+    *,
     post_thunk_old_storage: list["StorageCellType"] | None = None,
     no_recycling: list["StorageCellType"] | None = None,
     nice_errors: bool = True,
+    output_storage: list["StorageCellType"],
 ) -> "BasicThunkType":
     """Construct a single thunk that runs a list of thunks.
 
@@ -197,6 +199,7 @@ def streamline_default_f():
                     thunk()
                     for old_s in old_storage:
                         old_s[0] = None
+                return [out[0] for out in output_storage]
             except Exception:
                 raise_with_op(fgraph, node, thunk)
 
@@ -212,6 +215,7 @@ def streamline_nice_errors_f():
                     thunk()
             except Exception:
                 raise_with_op(fgraph, node, thunk)
+            return [out[0] for out in output_storage]
 
         f = streamline_nice_errors_f
     else:
@@ -222,6 +226,7 @@ def streamline_fast_f():
                 x[0] = None
             for thunk in thunks:
                 thunk()
+            return [out[0] for out in output_storage]
 
         f = streamline_fast_f
     return f
diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py
index 2c3f404449..0a757da575 100644
--- a/pytensor/scan/op.py
+++ b/pytensor/scan/op.py
@@ -1353,20 +1353,8 @@ def prepare_fgraph(self, fgraph):
 
                         preallocated_mitmot_outs.append(output_idx)
 
-                        # Make it so that the input is automatically updated to
-                        # the output value, possibly inplace, at the end of the
-                        # function execution. Also, since an update is defined,
-                        # a default value must also be (this is verified by
-                        # DebugMode).
-                        # TODO FIXME: Why do we need a "default value" here?
-                        # This sounds like a serious design issue.
-                        default_shape = tuple(
-                            s if s is not None else 0 for s in inp.type.shape
-                        )
-                        default_val = np.empty(default_shape, dtype=inp.type.dtype)
                         wrapped_inp = In(
                             variable=inp,
-                            value=default_val,
                             update=fgraph.outputs[output_idx],
                         )
                         update_mapping[output_idx] = input_idx
diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py
index 9bb2f0a731..0591634b05 100644
--- a/pytensor/tensor/type.py
+++ b/pytensor/tensor/type.py
@@ -199,11 +199,7 @@ def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray:
                     # (do not try to convert the data)
                     up_dtype = ps.upcast(self.dtype, data.dtype)
                     if up_dtype == self.dtype:
-                        # Bug in the following line when data is a
-                        # scalar array, see
-                        # http://projects.scipy.org/numpy/ticket/1611
-                        # data = data.astype(self.dtype)
-                        data = np.asarray(data, dtype=self.dtype)
+                        data = data.astype(self.dtype)
                     if up_dtype != self.dtype:
                         err_msg = (
                             f"{self} cannot store a value of dtype {data.dtype} without "
@@ -270,8 +266,6 @@ def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray:
                 f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})"
             )
 
-        if self.filter_checks_isfinite and not np.all(np.isfinite(data)):
-            raise ValueError("Non-finite elements not allowed")
         return data
 
     def filter_variable(self, other, allow_convert=True):
diff --git a/tests/compile/function/test_function.py b/tests/compile/function/test_function.py
index d1f94dd689..018f9a40f0 100644
--- a/tests/compile/function/test_function.py
+++ b/tests/compile/function/test_function.py
@@ -11,6 +11,7 @@
 from pytensor.compile.function import function, function_dump
 from pytensor.compile.io import In
 from pytensor.configdefaults import config
+from pytensor.link.basic import Container
 from pytensor.npy_2_compat import UintOverflowError
 from pytensor.tensor.type import (
     bscalar,
@@ -119,7 +120,9 @@ def test_in_mutable(self):
 
     def test_in_update(self):
         a = dscalar("a")
-        f = function([In(a, value=0.0, update=a + 1)], a, mode="FAST_RUN")
+        # A shared variable by any other name
+        c = Container(a, storage=[np.array(0.0)])
+        f = function([In(a, value=c, implicit=True, update=a + 1)], a, mode="FAST_RUN")
 
         # Ensure that, through the executions of the function, the state of the
         # input is persistent and is updated as it should
@@ -140,7 +143,8 @@ def test_in_update_shared(self):
         # updates in the same function behaves as expected
         shared_var = shared(1.0)
         a = dscalar("a")
-        a_wrapped = In(a, value=0.0, update=shared_var)
+        container = Container(a, storage=[np.array(0.0)])
+        a_wrapped = In(a, value=container, update=shared_var)
         f = function([a_wrapped], [], updates={shared_var: a}, mode="FAST_RUN")
 
         # Ensure that, through the executions of the function, the state of
diff --git a/tests/compile/function/test_pfunc.py b/tests/compile/function/test_pfunc.py
index 3e23b12f74..616237d5a7 100644
--- a/tests/compile/function/test_pfunc.py
+++ b/tests/compile/function/test_pfunc.py
@@ -59,7 +59,7 @@ def test_doc(self):
         a = lscalar()
         b = shared(1)
         f1 = pfunc([a], (a + b))
-        f2 = pfunc([In(a, value=44)], a + b, updates={b: b + 1})
+        f2 = pfunc([In(a)], a + b, updates={b: b + 1})
         assert b.get_value() == 1
         assert f1(3) == 4
         assert f2(3) == 4
diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py
index 0990dbeca0..7e2110a26e 100644
--- a/tests/compile/function/test_types.py
+++ b/tests/compile/function/test_types.py
@@ -15,7 +15,7 @@
 from pytensor.graph.basic import Constant
 from pytensor.graph.rewriting.basic import OpKeyGraphRewriter, PatternNodeRewriter
 from pytensor.graph.utils import MissingInputError
-from pytensor.link.vm import VMLinker
+from pytensor.link.basic import Container
 from pytensor.printing import debugprint
 from pytensor.tensor.math import dot, tanh
 from pytensor.tensor.math import sum as pt_sum
@@ -193,96 +193,6 @@ def test_naming_rule2(self):
             # got unexpected keyword argument 'x'
             f(5.0, x=9)
 
-    def test_naming_rule3(self):
-        a = scalar()  # the a is for 'anonymous' (un-named).
-        x, s = scalars("xs")
-
-        # x's name is not ignored (as in test_naming_rule2) because a has a default value.
-        with pytest.warns(
-            FutureWarning, match="Inputs with default values are deprecated."
-        ):
-            f = function([x, In(a, value=1.0), s], a / s + x)
-        assert f(9, 2, 4) == 9.5  # can specify all args in order
-        assert f(9, 2, s=4) == 9.5  # can give s as kwarg
-        assert f(9, s=4) == 9.25  # can give s as kwarg, get default a
-        assert f(x=9, s=4) == 9.25  # can give s as kwarg, omit a, x as kw
-        with pytest.raises(TypeError):
-            # got unexpected keyword argument 'a'
-            f(x=9, a=2, s=4)
-        with pytest.raises(TypeError):
-            # takes exactly 3 non-keyword arguments (0 given)
-            f()
-        with pytest.raises(TypeError):
-            # takes exactly 3 non-keyword arguments (1 given)
-            f(x=9)
-
-    def test_naming_rule4(self):
-        a = scalar()  # the a is for 'anonymous' (un-named).
-        x, s = scalars("xs")
-
-        with pytest.warns(
-            FutureWarning, match="Inputs with default values are deprecated."
-        ):
-            f = function([x, In(a, value=1.0, name="a"), s], a / s + x)
-
-        assert f(9, 2, 4) == 9.5  # can specify all args in order
-        assert f(9, 2, s=4) == 9.5  # can give s as kwarg
-        assert f(9, s=4) == 9.25  # can give s as kwarg, get default a
-        assert f(9, a=2, s=4) == 9.5  # can give s as kwarg, a as kwarg
-        assert f(x=9, a=2, s=4) == 9.5  # can give all kwargs
-        assert f(x=9, s=4) == 9.25  # can give all kwargs
-        with pytest.raises(TypeError):
-            # takes exactly 3 non-keyword arguments (0 given)
-            f()
-        with pytest.raises(TypeError):
-            # got multiple values for keyword argument 'x'
-            f(5.0, x=9)
-
-    @pytest.mark.parametrize(
-        "mode",
-        [
-            Mode(
-                linker=VMLinker(allow_gc=True, use_cloop=False, c_thunks=False),
-                optimizer="fast_compile",
-            ),
-            Mode(
-                linker=VMLinker(allow_gc=True, use_cloop=False, c_thunks=False),
-                optimizer="fast_run",
-            ),
-            Mode(linker="cvm", optimizer="fast_compile"),
-            Mode(linker="cvm", optimizer="fast_run"),
-        ],
-    )
-    def test_state_access(self, mode):
-        a = scalar()
-        x, s = scalars("xs")
-
-        with pytest.warns(
-            FutureWarning, match="Inputs with default values are deprecated."
-        ):
-            f = function(
-                [x, In(a, value=1.0, name="a"), In(s, value=0.0, update=s + a * x)],
-                s + a * x,
-                mode=mode,
-            )
-
-        assert f[a] == 1.0
-        assert f[s] == 0.0
-
-        assert f(3.0) == 3.0
-        assert f[s] == 3.0
-        assert f(3.0, a=2.0) == 9.0  # 3.0 + 2*3.0
-
-        assert (
-            f[a] == 1.0
-        )  # state hasn't changed permanently, we just overrode it last line
-        assert f[s] == 9.0
-
-        f[a] = 5.0
-        assert f[a] == 5.0
-        assert f(3.0) == 24.0  # 9 + 3*5
-        assert f[s] == 24.0
-
     def test_same_names(self):
         a, x, s = scalars("xxx")
         # implicit names would cause error.  What do we do?
@@ -300,9 +210,9 @@ def test_weird_names(self):
         def t():
             f = function(
                 [
-                    In(a, name={"adsf", ()}, value=1.0),
-                    In(x, name=(), value=2.0),
-                    In(s, name=scalar(), value=3.0),
+                    In(a, name={"adsf", ()}),
+                    In(x, name=()),
+                    In(s, name=scalar()),
                 ],
                 a + x + s,
             )
@@ -315,47 +225,29 @@ def test_copy(self):
         a = scalar()
         x, s = scalars("xs")
 
-        with pytest.warns(
-            FutureWarning, match="Inputs with default values are deprecated."
-        ):
-            f = function(
-                [
-                    x,
-                    In(a, value=1.0, name="a"),
-                    In(s, value=0.0, update=s + a * x, mutable=True),
-                ],
-                s + a * x,
-            )
+        f = function(
+            [
+                x,
+                In(a, name="a"),
+                In(s, name="s"),
+            ],
+            s + a * x,
+        )
 
-            g = copy.copy(f)
+        g = f.copy()
 
         assert f.unpack_single == g.unpack_single
         assert f.trust_input == g.trust_input
 
-        assert g.container[x].storage is not f.container[x].storage
-        assert g.container[a].storage is not f.container[a].storage
-        assert g.container[s].storage is not f.container[s].storage
-
-        # Should not have been copied
-        assert g.value[a] is f.value[a]
-
-        # Should have been copied because it is mutable
-        assert g.value[s] is not f.value[s]
-
-        # Their contents should be equal, though
-        assert np.array_equal(g.value[s], f.value[s])
+        assert g._finder[x].storage is not f._finder[x].storage
+        assert g._finder[a].storage is not f._finder[a].storage
+        assert g._finder[s].storage is not f._finder[s].storage
 
-        # They should be in sync, default value should be copied
-        assert np.array_equal(f(2, 1), g(2))
+        assert g._finder[a].value is None and f._finder[a].value is None
+        assert g._finder[s].value is None and f._finder[s].value is None
 
-        # They should be in sync, default value should be copied
-        assert np.array_equal(f(2, 1), g(2))
-
-        # Put them out of sync
-        f(1, 2)
-
-        # They should not be equal anymore
-        assert not np.array_equal(f(1, 2), g(1, 2))
+        assert np.array_equal(f(2, 1, 0), g(2, 1, 0))
+        assert np.array_equal(f(2, 1, 0), g(2, 1, 0))
 
     def test_copy_share_memory(self):
         x = fscalar("x")
@@ -519,88 +411,90 @@ def test_shared_state0(self):
         a = scalar()  # the a is for 'anonymous' (un-named).
         x, s = scalars("xs")
 
-        with pytest.warns(
-            FutureWarning, match="Inputs with default values are deprecated."
-        ):
-            f = function(
-                [
-                    x,
-                    In(a, value=1.0, name="a"),
-                    In(s, value=0.0, update=s + a * x, mutable=True),
-                ],
-                s + a * x,
-            )
-            g = function(
-                [
-                    x,
-                    In(a, value=1.0, name="a"),
-                    In(s, value=f.container[s], update=s - a * x, mutable=True),
-                ],
-                s + a * x,
-            )
+        f = function(
+            [
+                x,
+                In(a, name="a"),
+                In(
+                    s,
+                    value=Container(s, storage=[np.array(0.0)]),
+                    update=s + a * x,
+                    mutable=True,
+                ),
+            ],
+            s + a * x,
+        )
+        g = function(
+            [
+                x,
+                In(a, name="a"),
+                In(s, value=f._finder[s], update=s - a * x, mutable=True),
+            ],
+            s + a * x,
+        )
 
         f(1, 2)
-        assert f[s] == 2
-        assert g[s] == 2
+        assert f._finder[s].value == 2
+        assert g._finder[s].value == 2
         g(1, 2)
-        assert f[s] == 0
-        assert g[s] == 0
+        assert f._finder[s].value == 0
+        assert g._finder[s].value == 0
 
     def test_shared_state1(self):
         a = scalar()  # the a is for 'anonymous' (un-named).
         x, s = scalars("xs")
 
-        with pytest.warns(
-            FutureWarning, match="Inputs with default values are deprecated."
-        ):
-            f = function(
-                [
-                    x,
-                    In(a, value=1.0, name="a"),
-                    In(s, value=0.0, update=s + a * x, mutable=True),
-                ],
-                s + a * x,
-            )
-            g = function(
-                [x, In(a, value=1.0, name="a"), In(s, value=f.container[s])], s + a * x
-            )
+        f = function(
+            [
+                x,
+                In(a, name="a"),
+                In(
+                    s,
+                    value=Container(s, storage=[np.array(0.0)]),
+                    update=s + a * x,
+                    mutable=True,
+                ),
+            ],
+            s + a * x,
+        )
+        g = function([x, In(a, name="a"), In(s, value=f._finder[s])], s + a * x)
 
         f(1, 2)
-        assert f[s] == 2
-        assert g[s] == 2
+        assert f._finder[s].value == 2
+        assert g._finder[s].value == 2
         f(1, 2)
         g(1, 2)
-        assert f[s] == 4
-        assert g[s] == 4
+        assert f._finder[s].value == 4
+        assert g._finder[s].value == 4
 
     def test_shared_state2(self):
         a = scalar()  # the a is for 'anonymous' (un-named).
         x, s = scalars("xs")
 
-        with pytest.warns(
-            FutureWarning, match="Inputs with default values are deprecated."
-        ):
-            f = function(
-                [
-                    x,
-                    In(a, value=1.0, name="a"),
-                    In(s, value=0.0, update=s + a * x, mutable=False),
-                ],
-                s + a * x,
-            )
-            g = function(
-                [x, In(a, value=1.0, name="a"), In(s, value=f.container[s])], s + a * x
-            )
+        f = function(
+            [
+                x,
+                In(a, name="a"),
+                In(
+                    s,
+                    value=Container(s, storage=[np.array(0.0)]),
+                    update=s + a * x,
+                    mutable=False,
+                ),
+            ],
+            s + a * x,
+        )
+        g = function([x, In(a, name="a"), In(s, value=f._finder[s])], s + a * x)
 
         f(1, 2)
-        assert f[s] == 2
-        assert g[s] == 2
+        assert f._finder[s].value == 2
+        assert g._finder[s].value == 2
         f(1, 2)
-        assert f[s] == 4
-        assert g[s] == 4
+        assert f._finder[s].value == 4
+        assert g._finder[s].value == 4
         g(1, 2)  # has no effect on state
-        assert f[s] == 4
-        assert g[s] == 4
+        assert f._finder[s].value == 4
+        assert g._finder[s].value == 4
 
     def test_shared_state_not_implicit(self):
         # This test is taken from the documentation in
@@ -608,18 +502,20 @@ def test_shared_state_not_implicit(self):
         # behavior is still intended the doc and the test should both be
         # updated accordingly.
         x, s = scalars("xs")
-        inc = function([x, In(s, update=(s + x), value=10.0)], [])
+        inc = function(
+            [x, In(s, update=(s + x), value=Container(s, storage=[np.array(10.0)]))], []
+        )
         dec = function(
-            [x, In(s, update=(s - x), value=inc.container[s], implicit=False)], []
+            [x, In(s, update=(s - x), value=inc._finder[s], implicit=False)], []
         )
-        assert dec[s] is inc[s]
-        inc[s] = 2
-        assert dec[s] == 2
+        assert dec._finder[s].value is inc._finder[s].value
+        inc._finder[s].value = 2
+        assert dec._finder[s].value == 2
         dec(1)
-        assert inc[s] == 1
+        assert inc._finder[s].value == 1
         dec(1, 0)
-        assert inc[s] == -1
-        assert dec[s] == -1
+        assert inc._finder[s].value == -1
+        assert dec._finder[s].value == -1
 
     def test_constant_output(self):
         # Test that if the output is a constant, we respect the pytensor memory interface
@@ -736,22 +632,6 @@ def test_free(self):
             if not isinstance(key, Constant):
                 assert val[0] is None
 
-    def test_default_values(self):
-        # Check that default values are restored
-        # when an exception occurs in interactive mode.
-
-        a, b = dscalars("a", "b")
-        c = a + b
-        with pytest.warns(
-            FutureWarning, match="Inputs with default values are deprecated."
-        ):
-            funct = function([In(a, name="first"), In(b, value=1, name="second")], c)
-        x = funct(first=1)
-        try:
-            funct(second=2)
-        except TypeError:
-            assert funct(first=1) == x
-
     def test_check_for_aliased_inputs(self):
         b = np.random.random((5, 4))
         s1 = shared(b)
@@ -802,78 +682,11 @@ def test_output_dictionary(self):
         # Tests that function works when outputs is a dictionary
 
         x = scalar()
-        with pytest.warns(FutureWarning, match="output_keys is deprecated."):
-            f = function([x], outputs={"a": x, "c": x * 2, "b": x * 3, "1": x * 4})
-
-        outputs = f(10.0)
-
-        assert outputs["a"] == 10.0
-        assert outputs["b"] == 30.0
-        assert outputs["1"] == 40.0
-        assert outputs["c"] == 20.0
-
-    def test_input_named_variables(self):
-        # Tests that named variables work when outputs is a dictionary
-
-        x = scalar("x")
-        y = scalar("y")
-
-        with pytest.warns(FutureWarning, match="output_keys is deprecated."):
-            f = function([x, y], outputs={"a": x + y, "b": x * y})
-
-        assert f(2, 4) == {"a": 6, "b": 8}
-        assert f(2, y=4) == f(2, 4)
-        assert f(x=2, y=4) == f(2, 4)
-
-    def test_output_order_sorted(self):
-        # Tests that the output keys are sorted correctly.
-
-        x = scalar("x")
-        y = scalar("y")
-        z = scalar("z")
-        e1 = scalar("1")
-        e2 = scalar("2")
-
-        with pytest.warns(FutureWarning, match="output_keys is deprecated."):
-            f = function(
-                [x, y, z, e1, e2], outputs={"x": x, "y": y, "z": z, "1": e1, "2": e2}
-            )
-
-        assert "1" in str(f.outputs[0])
-        assert "2" in str(f.outputs[1])
-        assert "x" in str(f.outputs[2])
-        assert "y" in str(f.outputs[3])
-        assert "z" in str(f.outputs[4])
-
-    def test_composing_function(self):
-        # Tests that one can compose two pytensor functions when the outputs are
-        # provided in a dictionary.
-
-        x = scalar("x")
-        y = scalar("y")
-
-        a = x + y
-        b = x * y
-
-        with pytest.warns(FutureWarning, match="output_keys is deprecated."):
-            f = function([x, y], outputs={"a": a, "b": b})
-
-        a = scalar("a")
-        b = scalar("b")
-
-        l = a + b
-        r = a * b
-
-        g = function([a, b], outputs=[l, r])
-
-        result = g(**f(5, 7))
-
-        assert result[0] == 47.0
-        assert result[1] == 420.0
+        with pytest.raises(ValueError, match="output_keys was deprecated"):
+            function([x], outputs={"a": x, "c": x * 2, "b": x * 3, "1": x * 4})
 
     def test_output_list_still_works(self):
         # Test that function works if outputs is a list.
-
         x = scalar("x")
 
         f = function([x], outputs=[x * 3, x * 2, x * 4, x])
@@ -911,17 +724,14 @@ def test_deepcopy(self):
         a = scalar()  # the a is for 'anonymous' (un-named).
         x, s = scalars("xs")
 
-        with pytest.warns(
-            FutureWarning, match="Inputs with default values are deprecated."
-        ):
-            f = function(
-                [
-                    x,
-                    In(a, value=1.0, name="a", mutable=True),
-                    In(s, value=0.0, update=s + a * x, mutable=True),
-                ],
-                s + a * x,
-            )
+        f = function(
+            [
+                x,
+                In(a, name="a", mutable=True),
+                In(s, update=s + a * x, mutable=True),
+            ],
+            s + a * x,
+        )
         try:
             g = copy.deepcopy(f)
         except NotImplementedError as e:
@@ -933,12 +743,10 @@ def test_deepcopy(self):
         # print [(k, id(k)) for k in f.finder]
         # print [(k, id(k)) for k in g.finder]
 
-        assert g.container[0].storage is not f.container[0].storage
-        assert g.container[1].storage is not f.container[1].storage
-        assert g.container[2].storage is not f.container[2].storage
-        assert x not in g.container
-        assert x not in g.value
-        assert len(f.defaults) == len(g.defaults)
+        assert g._finder[0].storage is not f._finder[0].storage
+        assert g._finder[1].storage is not f._finder[1].storage
+        assert g._finder[2].storage is not f._finder[2].storage
+        assert x not in g._finder
         # Shared variable is the first input
         assert (
             f._potential_aliased_input_groups
@@ -947,45 +755,24 @@ def test_deepcopy(self):
         )
         assert f.name == g.name
         assert f.maker.fgraph.name == g.maker.fgraph.name
-        # print(f"{f.defaults = }")
-        # print(f"{g.defaults = }")
-        for (f_req, f_feed, f_val), (g_req, g_feed, g_val) in zip(
-            f.defaults, g.defaults, strict=True
-        ):
-            assert f_req == g_req and f_feed == g_feed and f_val == g_val
-
-        assert g.value[1] is not f.value[1]  # should not have been copied
-        assert (
-            g.value[2] is not f.value[2]
-        )  # should have been copied because it is mutable.
-        assert not (g.value[2] != f.value[2]).any()  # its contents should be identical
-
-        assert f(2, 1) == g(
-            2
-        )  # they should be in sync, default value should be copied.
-        assert f(2, 1) == g(
-            2
-        )  # they should be in sync, default value should be copied.
-        f(1, 2)  # put them out of sync
-        assert f(1, 2) != g(1, 2)  # they should not be equal anymore.
-        g(1, 2)  # put them back in sync
-        assert f(3) == g(3)  # They should be in sync again.
+
+        assert g._finder[1].value is None and f._finder[1].value is None
+        assert g._finder[2].value is None and f._finder[2].value is None
+
+        assert f(2, 1, 0) == g(2, 1, 0)
 
     def test_deepcopy_trust_input(self):
         a = dscalar()  # the a is for 'anonymous' (un-named).
         x, s = dscalars("xs")
 
-        with pytest.warns(
-            FutureWarning, match="Inputs with default values are deprecated."
-        ):
-            f = function(
-                [
-                    x,
-                    In(a, value=1.0, name="a"),
-                    In(s, value=0.0, update=s + a * x, mutable=True),
-                ],
-                s + a * x,
-            )
+        f = function(
+            [
+                x,
+                In(a, name="a"),
+                In(s, update=s + a * x, mutable=True),
+            ],
+            s + a * x,
+        )
         f.trust_input = True
         try:
             g = copy.deepcopy(f)
@@ -995,35 +782,19 @@ def test_deepcopy_trust_input(self):
             else:
                 raise
         assert f.trust_input is g.trust_input
-        f(np.asarray(2.0))
+        f(np.array(2.0), np.array(1.0), np.array(0.0))
         with pytest.raises((ValueError, AttributeError, InvalidValueError)):
-            f(2.0)
-        g(np.asarray(2.0))
+            f(2.0, np.array(1.0), np.array(0.0))
+        g(np.array(2.0), np.array(1.0), np.array(0.0))
         with pytest.raises((ValueError, AttributeError, InvalidValueError)):
-            g(2.0)
-
-    def test_output_keys(self):
-        x = vector()
-        with pytest.warns(FutureWarning, match="output_keys is deprecated."):
-            f = function([x], {"vec": x**2})
-        o = f([2, 3, 4])
-        assert isinstance(o, dict)
-        assert np.allclose(o["vec"], [4, 9, 16])
-        with pytest.warns(FutureWarning, match="output_keys is deprecated."):
-            g = copy.deepcopy(f)
-        o = g([2, 3, 4])
-        assert isinstance(o, dict)
-        assert np.allclose(o["vec"], [4, 9, 16])
+            g(2.0, np.array(1.0), np.array(0.0))
 
     def test_deepcopy_shared_container(self):
         # Ensure that shared containers remain shared after a deep copy.
         a, x = scalars("ax")
 
-        with pytest.warns(
-            FutureWarning, match="Inputs with default values are deprecated."
-        ):
-            h = function([In(a, value=0.0)], a)
-        f = function([x, In(a, value=h.container[a], implicit=True)], x + a)
+        h = function([In(a, value=Container(a, storage=[np.array(0.0)]))], a)
+        f = function([x, In(a, value=h._finder[a], implicit=True)], x + a)
 
         try:
             memo = {}
@@ -1037,26 +808,23 @@ def test_deepcopy_shared_container(self):
                 return
             else:
                 raise
-        h[a] = 1
-        hc[ac] = 2
-        assert f[a] == 1
-        assert fc[ac] == 2
+        h._finder[a].value = 1
+        hc._finder[ac].value = 2
+        assert f._finder[a].value == 1
+        assert fc._finder[ac].value == 2
 
     def test_pickle(self):
         a = scalar()  # the a is for 'anonymous' (un-named).
         x, s = scalars("xs")
 
-        with pytest.warns(
-            FutureWarning, match="Inputs with default values are deprecated."
-        ):
-            f = function(
-                [
-                    x,
-                    In(a, value=1.0, name="a"),
-                    In(s, value=0.0, update=s + a * x, mutable=True),
-                ],
-                s + a * x,
-            )
+        f = function(
+            [
+                x,
+                In(a, name="a"),
+                In(s, update=s + a * x, mutable=True),
+            ],
+            s + a * x,
+        )
 
         try:
             # Note that here we also test protocol 0 on purpose, since it
@@ -1072,26 +840,14 @@ def test_pickle(self):
         # print [(k, id(k)) for k in f.finder]
         # print [(k, id(k)) for k in g.finder]
 
-        assert g.container[0].storage is not f.container[0].storage
-        assert g.container[1].storage is not f.container[1].storage
-        assert g.container[2].storage is not f.container[2].storage
-        assert x not in g.container
-        assert x not in g.value
+        assert g._finder[0].storage is not f._finder[0].storage
+        assert g._finder[1].storage is not f._finder[1].storage
+        assert g._finder[2].storage is not f._finder[2].storage
+        assert x not in g._finder
 
-        assert g.value[1] is not f.value[1]  # should not have been copied
-        assert (
-            g.value[2] is not f.value[2]
-        )  # should have been copied because it is mutable.
-        assert not (g.value[2] != f.value[2]).any()  # its contents should be identical
-
-        assert f(2, 1) == g(
-            2
-        )  # they should be in sync, default value should be copied.
-        assert f(2, 1) == g(
-            2
-        )  # they should be in sync, default value should be copied.
-        f(1, 2)  # put them out of sync
-        assert f(1, 2) != g(1, 2)  # they should not be equal anymore.
+        assert g._finder[1].value is None and f._finder[1].value is None
+        assert g._finder[2].value is None and f._finder[2].value is None
+        assert f(2, 1, 0) == g(2, 1, 0)
 
     def test_optimizations_preserved(self):
         a = dvector()  # the a is for 'anonymous' (un-named).
@@ -1144,42 +900,38 @@ def test_multiple_functions(self):
         # some derived thing, whose inputs aren't all in the list
         list_of_things.append(a * x + s)
 
-        with pytest.warns(
-            FutureWarning, match="Inputs with default values are deprecated."
-        ):
-            f1 = function(
-                [
-                    x,
-                    In(a, value=1.0, name="a"),
-                    In(s, value=0.0, update=s + a * x, mutable=True),
-                ],
-                s + a * x,
-            )
+        f1 = function(
+            [
+                x,
+                In(a, name="a"),
+                In(
+                    s,
+                    value=Container(s, storage=[np.array(0.0)]),
+                    update=s + a * x,
+                    mutable=True,
+                ),
+            ],
+            s + a * x,
+        )
         list_of_things.append(f1)
 
         # now put in a function sharing container with the previous one
-        with pytest.warns(
-            FutureWarning, match="Inputs with default values are deprecated."
-        ):
-            f2 = function(
-                [
-                    x,
-                    In(a, value=1.0, name="a"),
-                    In(s, value=f1.container[s], update=s + a * x, mutable=True),
-                ],
-                s + a * x,
-            )
+        f2 = function(
+            [
+                x,
+                In(a, name="a"),
+                In(s, value=f1._finder[s], update=s + a * x, mutable=True),
+            ],
+            s + a * x,
+        )
         list_of_things.append(f2)
 
-        assert isinstance(f2.container[s].storage, list)
-        assert f2.container[s].storage is f1.container[s].storage
+        assert isinstance(f2._finder[s].storage, list)
+        assert f2._finder[s].storage is f1._finder[s].storage
 
         # now put in a function with non-scalar
-        v_value = np.asarray([2, 3, 4.0], dtype=config.floatX)
-        with pytest.warns(
-            FutureWarning, match="Inputs with default values are deprecated."
-        ):
-            f3 = function([x, In(v, value=v_value)], x + v)
+        value = Container(v, storage=[np.asarray([2, 3, 4.0], dtype=config.floatX)])
+        f3 = function([x, In(v, value=value)], x + v)
         list_of_things.append(f3)
 
         # try to pickle the entire things
@@ -1214,18 +966,18 @@ def test_multiple_functions(self):
             assert nl[i] != ol[i]
 
         # looking at function number 1, input 's'
-        assert nl[4][nl[0]] is not ol[4][ol[0]]
-        assert nl[4][nl[0]] == ol[4][ol[0]]
-        assert nl[4](3) == ol[4](3)
+        assert nl[4]._finder[nl[0]].value is not ol[4]._finder[ol[0]].value
+        assert nl[4]._finder[nl[0]].value == ol[4]._finder[ol[0]].value
+        assert nl[4](3, 1) == ol[4](3, 1)
 
         # looking at function number 2, input 's'
         # make sure it's shared with the first function
-        assert ol[4].container[ol[0]].storage is ol[5].container[ol[0]].storage
-        assert nl[4].container[nl[0]].storage is nl[5].container[nl[0]].storage
-        assert nl[5](3) == ol[5](3)
-        assert nl[4].value[nl[0]] == 6
+        assert ol[4]._finder[ol[0]].storage is ol[5]._finder[ol[0]].storage
+        assert nl[4]._finder[nl[0]].storage is nl[5]._finder[nl[0]].storage
+        assert nl[5](3, 1) == ol[5](3, 1)
+        assert nl[4]._finder[nl[0]].value == 6
 
-        assert np.all(nl[6][nl[2]] == np.asarray([2, 3.0, 4]))
+        assert np.all(nl[6]._finder[nl[2]].value == np.array([2, 3.0, 4]))
 
     def test_broken_pickle_with_shared(self):
         saves = []
@@ -1279,7 +1031,7 @@ def exc_message(e):
 
     def test_pickle_class_with_functions(self):
         blah = SomethingToPickle()
-        assert blah.f2.container[blah.s].storage is blah.f1.container[blah.s].storage
+        assert blah.f2._finder[blah.s].storage is blah.f1._finder[blah.s].storage
 
         try:
             blah2 = copy.deepcopy(blah)
@@ -1289,14 +1041,12 @@ def test_pickle_class_with_functions(self):
             else:
                 raise
 
-        assert (
-            blah2.f2.container[blah2.s].storage is blah2.f1.container[blah2.s].storage
-        )
+        assert blah2.f2._finder[blah2.s].storage is blah2.f1._finder[blah2.s].storage
 
-        assert blah.f1[blah.s] == blah2.f1[blah2.s]
+        assert blah.f1._finder[blah.s].value == blah2.f1._finder[blah2.s].value
 
-        blah.f2(5)
-        assert blah.f1[blah.s] != blah2.f1[blah2.s]
+        blah.f2(5, 1)
+        assert blah.f1._finder[blah.s].value != blah2.f1._finder[blah2.s].value
 
 
 class SomethingToPickle:
@@ -1311,29 +1061,28 @@ def __init__(self):
 
         self.e = a * x + s
 
-        with pytest.warns(
-            FutureWarning, match="Inputs with default values are deprecated."
-        ):
-            self.f1 = function(
-                [
-                    x,
-                    In(a, value=1.0, name="a"),
-                    In(s, value=0.0, update=s + a * x, mutable=True),
-                ],
-                s + a * x,
-            )
+        self.f1 = function(
+            [
+                x,
+                In(a, name="a"),
+                In(
+                    s,
+                    value=Container(s, storage=[np.array(0.0)]),
+                    update=s + a * x,
+                    mutable=True,
+                ),
+            ],
+            s + a * x,
+        )
 
-        with pytest.warns(
-            FutureWarning, match="Inputs with default values are deprecated."
-        ):
-            self.f2 = function(
-                [
-                    x,
-                    In(a, value=1.0, name="a"),
-                    In(s, value=self.f1.container[s], update=s + a * x, mutable=True),
-                ],
-                s + a * x,
-            )
+        self.f2 = function(
+            [
+                x,
+                In(a, name="a"),
+                In(s, value=self.f1._finder[s], update=s + a * x, mutable=True),
+            ],
+            s + a * x,
+        )
 
 
 def test_empty_givens_updates():
@@ -1347,7 +1096,7 @@ def test_empty_givens_updates():
     function([In(x)], y, updates={})
 
 
-@pytest.mark.parametrize("trust_input", [True, False])
+@pytest.mark.parametrize("trust_input", [True, False], ids=lambda x: f"trust_input={x}")
 def test_minimal_random_function_call_benchmark(trust_input, benchmark):
     rng = random_generator_type()
     x = normal(rng=rng, size=(100,))
@@ -1357,3 +1106,17 @@ def test_minimal_random_function_call_benchmark(trust_input, benchmark):
 
     rng_val = np.random.default_rng()
     benchmark(f, rng_val)
+
+
+@pytest.mark.parametrize("trust_input", [True, False], ids=lambda x: f"trust_input={x}")
+@pytest.mark.parametrize("linker", ["c", "cvm", "cvm_nogc"])
+def test_overhead_benchmark(trust_input, linker, benchmark):
+    x = pt.vector("x")
+    fn = function(
+        [In(x, borrow=True)],
+        Out(x, borrow=True),
+        trust_input=trust_input,
+        mode=Mode(linker=linker, optimizer=None),
+    )
+    x_test = np.zeros(10)
+    benchmark(fn, x_test)
diff --git a/tests/compile/test_debugmode.py b/tests/compile/test_debugmode.py
index fae76fab0d..845eb44e52 100644
--- a/tests/compile/test_debugmode.py
+++ b/tests/compile/test_debugmode.py
@@ -9,11 +9,9 @@
     BadThunkOutput,
     BadViewMap,
     DebugMode,
-    InvalidValueError,
     StochasticOrder,
 )
 from pytensor.compile.function import function
-from pytensor.compile.mode import predefined_modes
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.features import BadOptimization
@@ -21,8 +19,8 @@
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.graph.rewriting.db import EquilibriumDB
 from pytensor.link.c.op import COp
-from pytensor.tensor.math import add, dot, log
-from pytensor.tensor.type import TensorType, dvector, fmatrix, fvector, scalar, vector
+from pytensor.tensor.math import add, dot
+from pytensor.tensor.type import dvector, fmatrix, fvector, scalar
 from tests import unittest_tools as utt
 
 
@@ -553,59 +551,6 @@ def perform(self, node, inp, out):
         # f([1,2,3,4],[5,6,7,8])
 
 
-class TestCheckIsfinite:
-    def setup_method(self):
-        self.old_ts = TensorType.filter_checks_isfinite
-        self.old_dm = predefined_modes["DEBUG_MODE"].check_isfinite
-
-    def teardown_method(self):
-        TensorType.filter_checks_isfinite = self.old_ts
-        predefined_modes["DEBUG_MODE"].check_isfinite = self.old_dm
-
-    def test_check_isfinite(self):
-        x = vector()
-        f = function([x], (x + 2) * 5, mode="DEBUG_MODE")
-        g = function([x], log(x), mode="DEBUG_MODE")
-
-        # this should work
-        f(np.log([3, 4, 5]).astype(config.floatX))
-
-        # if TensorType.filter_checks_isfinite were true, these would raise
-        # ValueError
-        # if not, DebugMode will check internally, and raise InvalidValueError
-        # passing an invalid value as an input should trigger ValueError
-        with pytest.raises(InvalidValueError):
-            f(np.log([3, -4, 5]).astype(config.floatX))
-        with pytest.raises(InvalidValueError):
-            f((np.asarray([0, 1.0, 0]) / 0).astype(config.floatX))
-        with pytest.raises(InvalidValueError):
-            f((np.asarray([1.0, 1.0, 1.0]) / 0).astype(config.floatX))
-
-        # generating an invalid value internally should trigger
-        # InvalidValueError
-        with pytest.raises(InvalidValueError):
-            g(np.asarray([3, -4, 5], dtype=config.floatX))
-
-        # this should disable the exception
-        TensorType.filter_checks_isfinite = False
-        predefined_modes["DEBUG_MODE"].check_isfinite = False
-        # insert several Inf
-        f(np.asarray(np.asarray([1.0, 1.0, 1.0]) / 0, dtype=config.floatX))
-
-    def test_check_isfinite_disabled(self):
-        x = dvector()
-        f = function([x], (x + 2) * 5, mode=DebugMode(check_isfinite=False))
-
-        # nan should go through
-        f(np.log([3, -4, 5]))
-
-        # inf should go through
-        infs = np.asarray([1.0, 1.0, 1.0]) / 0
-        # print infs
-        f(infs)
-        return
-
-
 class BrokenCImplementationAdd(COp):
     __props__ = ()
 
@@ -804,20 +749,6 @@ def test_output_broadcast_tensor(self):
         f(v_val)
 
 
-def test_function_dict():
-    """Tests that debug mode works where outputs is a dictionary."""
-
-    x = scalar("x")
-
-    f = function([x], outputs={"1": x, "2": 2 * x, "3": 3 * x}, mode="DEBUG_MODE")
-
-    result = f(3.0)
-
-    assert result["1"] == 3.0
-    assert result["2"] == 6.0
-    assert result["3"] == 9.0
-
-
 def test_function_list():
     """Tests that debug mode works where the outputs argument is a list."""
 
diff --git a/tests/link/test_vm.py b/tests/link/test_vm.py
index dad7ed4fdd..37aada17c4 100644
--- a/tests/link/test_vm.py
+++ b/tests/link/test_vm.py
@@ -6,7 +6,6 @@
 from pytensor.compile.function import function
 from pytensor.compile.io import In
 from pytensor.compile.mode import Mode, get_mode
-from pytensor.compile.sharedvalue import shared
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply
 from pytensor.graph.fg import FunctionGraph
@@ -17,9 +16,8 @@
 from pytensor.link.utils import map_storage
 from pytensor.link.vm import VM, Loop, Stack, VMLinker
 from pytensor.tensor.math import cosh, tanh
-from pytensor.tensor.type import lscalar, scalar, scalars, vector, vectors
+from pytensor.tensor.type import scalar, scalars, vector, vectors
 from pytensor.tensor.variable import TensorConstant
-from tests import unittest_tools as utt
 
 
 class SomeOp(Op):
@@ -202,71 +200,6 @@ def build_graph(x, depth=5):
     # print(f"{linker} takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop")
 
 
-@pytest.mark.parametrize(
-    "linker", [VMLinker(allow_partial_eval=True, use_cloop=False), "cvm"]
-)
-def test_partial_function(linker):
-    x = scalar("input")
-    y = x**2
-    f = function(
-        [x], [y + 7, y - 9, y / 14.0], mode=Mode(optimizer=None, linker=linker)
-    )
-
-    if linker == "cvm":
-        from pytensor.link.c.cvm import CVM
-
-        assert isinstance(f.vm, CVM)
-    else:
-        assert isinstance(f.vm, Stack)
-
-    assert f(3, output_subset=[0, 1, 2]) == f(3)
-    assert f(4, output_subset=[0, 2]) == [f(4)[0], f(4)[2]]
-
-    utt.assert_allclose(f(5), np.array([32.0, 16.0, 1.7857142857142858]))
-
-
-@pytest.mark.parametrize(
-    "linker", [VMLinker(allow_partial_eval=True, use_cloop=False), "cvm"]
-)
-def test_partial_function_with_output_keys(linker):
-    x = scalar("input")
-    y = 3 * x
-    f = function(
-        [x], {"a": y * 5, "b": y - 7}, mode=Mode(optimizer=None, linker=linker)
-    )
-
-    assert f(5, output_subset=["a"])["a"] == f(5)["a"]
-
-
-@pytest.mark.parametrize(
-    "linker", [VMLinker(allow_partial_eval=True, use_cloop=False), "cvm"]
-)
-def test_partial_function_with_updates(linker):
-    x = lscalar("input")
-    y = shared(np.asarray(1, "int64"), name="global")
-
-    mode = Mode(optimizer=None, linker=linker)
-
-    f = function(
-        [x],
-        [x, x + 34],
-        updates=[(y, x + 1)],
-        mode=mode,
-    )
-    g = function(
-        [x],
-        [x - 6],
-        updates=[(y, y + 3)],
-        mode=mode,
-    )
-
-    assert f(3, output_subset=[]) == []
-    assert y.get_value() == 4
-    assert g(30, output_subset=[0]) == [24]
-    assert g(40, output_subset=[]) == []
-    assert y.get_value() == 10
-
-
 def test_allow_gc_cvm():
     mode = config.mode
     if mode in ["DEBUG_MODE", "DebugMode"]: