-
Notifications
You must be signed in to change notification settings - Fork 137
Open
Description
Description
For Ops that define a `gufunc_signature`, we can automate the `infer_shape`
implementation instead of writing it by hand, as is currently done here:
pytensor/pytensor/tensor/slinalg.py
Lines 28 to 53 in e036caf
class Cholesky(Op): | |
# TODO: LAPACK wrapper with in-place behavior, for solve also | |
__props__ = ("lower", "check_finite", "on_error", "overwrite_a") | |
gufunc_signature = "(m,m)->(m,m)" | |
def __init__( | |
self, | |
*, | |
lower: bool = True, | |
check_finite: bool = True, | |
on_error: Literal["raise", "nan"] = "raise", | |
overwrite_a: bool = False, | |
): | |
self.lower = lower | |
self.check_finite = check_finite | |
if on_error not in ("raise", "nan"): | |
raise ValueError('on_error must be one of "raise" or ""nan"') | |
self.on_error = on_error | |
self.overwrite_a = overwrite_a | |
if self.overwrite_a: | |
self.destroy_map = {0: [0]} | |
def infer_shape(self, fgraph, node, shapes): | |
return [shapes[0]] |
We already do this in the Blockwise wrapper, where an output core dimension that shares its name with an input core dimension reuses that input's size:
pytensor/pytensor/tensor/blockwise.py
Lines 208 to 210 in 3cdcfde
# The output dim is the same as another input dim | |
if dim_name in core_dims: | |
core_out_shape.append(core_dims[dim_name]) |