Skip to content

Canonicalize Dot as a matrix-matrix operation #1538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1801,8 +1801,7 @@ def do_constant_folding(self, fgraph, node):
| pytensor.tensor.blas.Gemv
| pytensor.tensor.blas_c.CGemv
| pytensor.tensor.blas.Ger
| pytensor.tensor.blas_c.CGer
| pytensor.tensor.blas_scipy.ScipyGer,
| pytensor.tensor.blas_c.CGer,
)
):
# Ops that will work inplace on the Alloc. So if they
Expand Down
32 changes: 11 additions & 21 deletions pytensor/tensor/blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
from pathlib import Path

import numpy as np
from scipy.linalg import get_blas_funcs

from pytensor.graph import vectorize_graph
from pytensor.npy_2_compat import normalize_axis_tuple
Expand Down Expand Up @@ -288,18 +289,15 @@ def make_node(self, A, alpha, x, y):

return Apply(self, inputs, [A.type()])

def perform(self, node, inp, out):
cA, calpha, cx, cy = inp
(cZ,) = out
if self.destructive:
A = cA
else:
A = cA.copy()
if calpha != 1:
A += calpha * np.outer(cx, cy)
def perform(self, node, inputs, output_storage):
A, alpha, x, y = inputs
ger_func = get_blas_funcs("ger", dtype=A.dtype)
if A.flags["C_CONTIGUOUS"]:
# Work on transposed system to avoid copying
A = ger_func(alpha, y, x, a=A.T, overwrite_a=self.destructive).T
else:
A += np.outer(cx, cy)
cZ[0] = A
A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive)
output_storage[0][0] = A

def infer_shape(self, fgraph, node, input_shapes):
return [input_shapes[0]]
Expand Down Expand Up @@ -1128,16 +1126,8 @@ def make_node(self, x, y):
outputs = [tensor(dtype=x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))]
return Apply(self, [x, y], outputs)

def perform(self, node, inp, out):
x, y = inp
(z,) = out
try:
z[0] = np.asarray(np.dot(x, y))
except ValueError as e:
# The error raised by numpy has no shape information, we mean to
# add that
e.args = (*e.args, x.shape, y.shape)
raise
def perform(self, node, inputs, output_storage):
output_storage[0][0] = np.dot(*inputs)

def infer_shape(self, fgraph, node, input_shapes):
return [[input_shapes[0][0], input_shapes[1][1]]]
Expand Down
23 changes: 7 additions & 16 deletions pytensor/tensor/blas_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,22 @@
Implementations of BLAS Ops based on scipy's BLAS bindings.
"""

from scipy.linalg.blas import get_blas_funcs

from pytensor.tensor.blas import Ger


class ScipyGer(Ger):
def perform(self, node, inputs, output_storage):
from scipy.linalg.blas import get_blas_funcs

cA, calpha, cx, cy = inputs
(cZ,) = output_storage
# N.B. some versions of scipy (e.g. mine) don't actually work
# in-place on a, even when I tell it to.
A = cA
local_ger = get_blas_funcs("ger", dtype=cA.dtype)
if A.size == 0:
# We don't have to compute anything, A is empty.
# We need this special case because Numpy considers it
# C-contiguous, which is confusing.
if not self.destructive:
# Sometimes numpy thinks empty matrices can share memory,
# so here to stop DebugMode from complaining.
A = A.copy()
elif A.flags["C_CONTIGUOUS"]:
A = local_ger(calpha, cy, cx, a=A.T, overwrite_a=int(self.destructive)).T
ger_func = get_blas_funcs("ger", dtype=cA.dtype)
if A.flags["C_CONTIGUOUS"]:
# Work on transposed system to avoid copying
A = ger_func(calpha, cy, cx, a=A.T, overwrite_a=self.destructive).T
else:
A = local_ger(calpha, cx, cy, a=A, overwrite_a=int(self.destructive))
A = ger_func(calpha, cx, cy, a=A, overwrite_a=self.destructive)
cZ[0] = A


Expand Down
Loading
Loading