-
Notifications
You must be signed in to change notification settings - Fork 153
Description
Description
from pytensor.graph import FunctionGraph
import pytensor.scalar as ps
from pytensor.link.c.basic import CLinker

# Ask the C linker for the module key of a minimal scalar graph: x + 1.
x = ps.float64("x")
graph = FunctionGraph([x], [x + 1])
linker = CLinker().accept(graph)
linker.cmodule_key()
from pytensor.graph import FunctionGraph
import pytensor.tensor as pt
from pytensor.link.c.basic import CLinker

# Reproduction: two graphs that differ only in the value of a vector constant
# produce identical C source, yet CLinker.cmodule_key() gives them distinct
# keys, so the compiled module cannot be reused between them.
x = pt.vector("x")
y = pt.vector("y")
z = pt.vector("z")
c1 = pt.constant([1, 1, 1, 1, 1])
c2 = pt.constant([1, 1, 1, 1, 2])  # differs from c1 only in the last element

# Same source code (the second input is a different variable: y vs z)
assert (
    CLinker().accept(FunctionGraph([x, y], [x + y])).get_src_code()
    == CLinker().accept(FunctionGraph([x, z], [x + z])).get_src_code()
)
# Same source code (the constant input differs only in its value)
assert (
    CLinker().accept(FunctionGraph([x], [x + c1])).get_src_code()
    == CLinker().accept(FunctionGraph([x], [x + c2])).get_src_code()
)
# Same hashing
assert (
    CLinker().accept(FunctionGraph([x, y], [x + y])).cmodule_key()
    == CLinker().accept(FunctionGraph([x, z], [x + z])).cmodule_key()
)
# Distinct hashing: only the constant's value changed, yet the keys differ.
assert (
    CLinker().accept(FunctionGraph([x], [x + c1])).cmodule_key()
    != CLinker().accept(FunctionGraph([x], [x + c2])).cmodule_key()
)
# This is wasteful. It's common to have Ops with the same code but distinct
# constant inputs. The only cases where constants matter are:
- Scalar constants which have a `c_literal` that is used directly in the generated code. It should be easy to add a workaround for this.
- Operations that use node constant info to specialize the C code. This is something that was never done before, but I started doing it for Join/Split (most of the time the axis is constant), and for AdvancedSubtensor1/AdvancedIncSubtensor1, to decide whether we need to check for negative / invalid indices.
This information could (and should) be encoded in `c_code_cache_version_apply`, which can use node information. If we handle those cases, we can reuse many more compiled functions.
The integration of constants into the `cmodule_key` happens here:
pytensor/pytensor/link/c/basic.py
Lines 1395 to 1412 in 27c21cd
| if isinstance(i, AtomicVariable): # orphans | |
| if id(i) not in constant_ids: | |
| isig = (i.signature(), topological_pos, i_idx) | |
| # If the PyTensor constant provides a strong hash | |
| # (no collision for transpose, 2, 1, 0, -1, -2, | |
| # 2 element swapped...) we put this hash in the signature | |
| # instead of the value. This makes the key file much | |
| # smaller for big constant arrays. Before this, we saw key | |
| # files up to 80M. | |
| if hasattr(isig[0], "pytensor_hash"): | |
| isig = (isig[0].pytensor_hash(), topological_pos, i_idx) | |
| try: | |
| hash(isig) | |
| except Exception: | |
| # generic constants don't have a hashable signature | |
| error_on_play[0] = True | |
| return None | |
| constant_ids[id(i)] = isig |
The introduction of scalar literals (only if they are not fgraph inputs):
pytensor/pytensor/link/c/basic.py
Lines 624 to 630 in 27c21cd
| if ( | |
| isinstance(variable, Constant) | |
| and isinstance(variable.type, CLinkerType) | |
| # This apparently checks--in a somewhat "dynamic" way--whether | |
| # or not a literal value is available (in C). | |
| and variable.type.c_literal(variable.data) | |
| ): |
pytensor/pytensor/link/c/basic.py
Lines 693 to 695 in 27c21cd
| if variable in self.consts: | |
| symbol[variable] = "(" + variable.type.c_literal(variable.data) + ")" | |
| continue |
Exploitation of constant inputs by an Op's C code:
pytensor/pytensor/tensor/basic.py
Lines 2331 to 2349 in 27c21cd
| # Most times axis is constant, inline it | |
| # This is safe to do because the hash of the c_code includes the constant signature | |
| if isinstance(node.inputs[1], Constant): | |
| static_axis = int(node.inputs[1].data) | |
| static_axis = normalize_axis_index(static_axis, ndim) | |
| axis_def = f"{static_axis};" | |
| axis_check = "" | |
| else: | |
| axis_dtype = node.inputs[1].type.dtype_specs()[1] | |
| axis_def = f"(({axis_dtype} *)PyArray_DATA({axis}))[0];" | |
| axis_check = f""" | |
| if (axis < 0){{ | |
| axis = ndim + axis; | |
| }} | |
| if (axis >= ndim || axis < 0) {{ | |
| PyErr_SetString(PyExc_ValueError, "Split axis is out of bounds"); | |
| {fail} | |
| }} | |
| """ |