Implement BandedDot Op #1416

Status: Open. This pull request wants to merge 29 commits into base branch main.

Commits (29)
7904256
Naive implementation, do not merge
jessegrabowski May 23, 2025
db5b23c
Implement suggestions
jessegrabowski May 23, 2025
c687856
Simplify perf test
jessegrabowski May 23, 2025
4db2a33
float32 compat in tests
jessegrabowski May 23, 2025
3504f0b
Remove np.pad
jessegrabowski May 23, 2025
1bcf463
set dtype correctly
jessegrabowski May 23, 2025
0ce2cae
fix signature, add infer_shape
jessegrabowski May 23, 2025
161e172
micro-optimizations
jessegrabowski May 23, 2025
1ddd529
Rename b to x, matching BLAS docs
jessegrabowski May 24, 2025
b16189e
Add numba dispatch for banded_dot
jessegrabowski May 24, 2025
a902694
Eliminate extra copy in numba impl
jessegrabowski May 24, 2025
6becc7d
Create `A_banded` as F-contiguous array
jessegrabowski May 24, 2025
22578f3
Remove benchmark
jessegrabowski May 24, 2025
65c485e
Don't cache numba function
jessegrabowski May 24, 2025
905fc7c
all hail mypy
jessegrabowski May 24, 2025
687877c
set INCX by strides
jessegrabowski May 24, 2025
62ccf13
relax tolerance of float32 test
jessegrabowski May 24, 2025
8d30a29
Add suggestions
jessegrabowski May 25, 2025
e3d0b14
Test strides
jessegrabowski May 25, 2025
21873a9
Add L_op
jessegrabowski May 25, 2025
c1b6e01
*remove* type hints to make mypy happy
jessegrabowski May 25, 2025
e62b613
Remove order argument from numba A_to_banded
jessegrabowski May 25, 2025
025879a
Incorporate feedback
jessegrabowski May 25, 2025
beeec6a
Adjust numba test
jessegrabowski May 25, 2025
f467322
Remove more useful type information for mypy
jessegrabowski May 25, 2025
976422f
Fix negative strides
jessegrabowski Jun 10, 2025
72ba0dc
Rename `BandedDot` to `BandedGEMV` and move to `blas.py`
jessegrabowski Jun 24, 2025
eb50ca6
Add numba `gemv` overload
jessegrabowski Jun 24, 2025
976fd5b
All hail mypy
jessegrabowski Jun 26, 2025
2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/__init__.py
@@ -14,6 +14,6 @@
import pytensor.link.numba.dispatch.sparse
import pytensor.link.numba.dispatch.subtensor
import pytensor.link.numba.dispatch.tensor_basic

+import pytensor.link.numba.dispatch.blas

# isort: on
2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/basic.py
@@ -75,7 +75,7 @@ def numba_njit(*args, fastmath=None, **kwargs):
message=(
"(\x1b\\[1m)*" # ansi escape code for bold text
"Cannot cache compiled function "
'"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor)" '
'"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor|banded_dot)" '
"as it uses dynamic globals"
),
category=NumbaWarning,
59 changes: 59 additions & 0 deletions pytensor/link/numba/dispatch/blas.py
@@ -0,0 +1,59 @@
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.link.numba.dispatch.linalg.dot.banded import _gbmv
from pytensor.link.numba.dispatch.linalg.dot.general import _matrix_vector_product
from pytensor.link.numba.dispatch.slinalg import _COMPLEX_DTYPE_NOT_SUPPORTED_MSG
from pytensor.tensor.blas import BandedGEMV, Gemv
from pytensor.tensor.type import complex_dtypes


@numba_funcify.register(Gemv)
def numba_funcify_Gemv(op, node, **kwargs):
"""
Function to handle the Gemv operation in Numba.
"""
overwrite_y = op.inplace

@numba_njit()
def numba_gemv(y, alpha, A, x, beta):
"""
Numba implementation of the Gemv operation.
"""
return _matrix_vector_product(
alpha=alpha,
A=A,
x=x,
beta=beta,
y=y,
overwrite_y=overwrite_y,
)

return numba_gemv
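
For reference, the semantics the dispatch above relies on are the standard GEMV update. A minimal NumPy sketch of the same contract, including the overwrite_y behavior (illustrative only, not code from this PR):

import numpy as np

def gemv_reference(y, alpha, A, x, beta, overwrite_y=False):
    # y <- alpha * A @ x + beta * y, optionally reusing y's buffer
    out = y if overwrite_y else y.copy()
    out *= beta
    out += alpha * (A @ x)
    return out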


@numba_funcify.register(BandedGEMV)
def numba_funcify_BandedGEMV(op, node, **kwargs):
kl = op.lower_diags
ku = op.upper_diags
overwrite_y = op.overwrite_y
trans = int(op.transpose)
dtype = node.inputs[0].dtype

if dtype in complex_dtypes:
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))

@numba_njit(cache=False)
def banded_gemv(A, x, y, alpha, beta):
return _gbmv(
A=A,
x=x,
kl=kl,
ku=ku,
y=y,
alpha=alpha,
beta=beta,
overwrite_y=overwrite_y,
trans=trans,
)

return banded_gemv
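
A note on the band-width convention used by the op above (an illustrative sketch, not code from this PR): lower_diags (kl) counts sub-diagonals and upper_diags (ku) counts super-diagonals.

import numpy as np

# A 5x5 matrix whose nonzeros sit on the main diagonal, one sub-diagonal
# and two super-diagonals, i.e. kl=1 and ku=2:
A = np.triu(np.tril(np.ones((5, 5)), k=2), k=-1)
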
93 changes: 93 additions & 0 deletions pytensor/link/numba/dispatch/linalg/_BLAS.py
@@ -0,0 +1,93 @@
import ctypes

from numba.core.extending import get_cython_function_address
from numba.np.linalg import ensure_blas, ensure_lapack, get_blas_kind

from pytensor.link.numba.dispatch.linalg._LAPACK import (
_get_float_pointer_for_dtype,
_ptr_int,
)


def _get_blas_ptr_and_ptr_type(dtype, name):
d = get_blas_kind(dtype)
func_name = f"{d}{name}"
float_pointer = _get_float_pointer_for_dtype(d)
lapack_ptr = get_cython_function_address("scipy.linalg.cython_blas", func_name)

return lapack_ptr, float_pointer
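
The func_name construction follows the BLAS type-prefix convention; a quick illustrative check using the get_blas_kind helper imported above:

import numpy as np
from numba.np.linalg import get_blas_kind

# get_blas_kind maps a dtype to the BLAS prefix ('s', 'd', 'c' or 'z'),
# so float64 gbmv resolves to the cython_blas symbol "dgbmv"
assert get_blas_kind(np.dtype(np.float32)) + "gbmv" == "sgbmv"
assert get_blas_kind(np.dtype(np.float64)) + "gbmv" == "dgbmv"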


class _BLAS:
"""
Functions to return type signatures for wrapped BLAS functions.

Here we are specifically concerned with BLAS functions exposed by scipy, but not used by numpy.

Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74
"""

def __init__(self):
ensure_lapack()
ensure_blas()

@classmethod
def numba_xgemv(cls, dtype):
"""
xGEMV performs one of the following matrix operations:

y = alpha * A @ x + beta * y, or y = alpha * A.T @ x + beta * y

Where alpha and beta are scalars, x and y are vectors, and A is a general matrix.
"""

blas_ptr, float_pointer = _get_blas_ptr_and_ptr_type(dtype, "gemv")

functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # TRANS
_ptr_int, # M
_ptr_int, # N
float_pointer, # ALPHA
float_pointer, # A
_ptr_int, # LDA
float_pointer, # X
_ptr_int, # INCX
float_pointer, # BETA
float_pointer, # Y
_ptr_int, # INCY
)

return functype(blas_ptr)

@classmethod
def numba_xgbmv(cls, dtype):
"""
xGBMV performs one of the following matrix operations:

y = alpha * A @ x + beta * y, or y = alpha * A.T @ x + beta * y

Where alpha and beta are scalars, x and y are vectors, and A is a band matrix with kl sub-diagonals and ku
super-diagonals.
"""

blas_ptr, float_pointer = _get_blas_ptr_and_ptr_type(dtype, "gbmv")

functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # TRANS
_ptr_int, # M
_ptr_int, # N
_ptr_int, # KL
_ptr_int, # KU
float_pointer, # ALPHA
float_pointer, # A
_ptr_int, # LDA
float_pointer, # X
_ptr_int, # INCX
float_pointer, # BETA
float_pointer, # Y
_ptr_int, # INCY
)

return functype(blas_ptr)
Empty file.
179 changes: 179 additions & 0 deletions pytensor/link/numba/dispatch/linalg/dot/banded.py
@@ -0,0 +1,179 @@
from collections.abc import Callable
from typing import Any

import numpy as np
from numba import njit as numba_njit
from numba.core.extending import overload
from numba.np.linalg import ensure_blas, ensure_lapack
from scipy import linalg

from pytensor.link.numba.dispatch.linalg._BLAS import _BLAS
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_get_underlying_float,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_trans_char_to_int,
)


@numba_njit(inline="always")
def A_to_banded(A: np.ndarray, kl: int, ku: int) -> np.ndarray:
m, n = A.shape

# This matrix is built backwards then transposed to get it into Fortran order
# (order="F" is not allowed in Numba land)
A_banded = np.zeros((n, kl + ku + 1), dtype=A.dtype).T

for i, k in enumerate(range(ku, -kl - 1, -1)):
if k >= 0:
A_banded[i, k:] = np.diag(A, k=k)
else:
A_banded[i, : n + k] = np.diag(A, k=k)

return A_banded
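
The layout built here matches LAPACK band storage, AB[ku + i - j, j] = A[i, j] for the in-band entries. A small illustrative cross-check against that formula, assuming a 4x4 tridiagonal case:

import numpy as np

A = np.arange(1.0, 17.0).reshape(4, 4)
kl = ku = 1
AB = np.zeros((kl + ku + 1, 4))
for i in range(4):
    for j in range(max(0, i - kl), min(4, i + ku + 1)):
        AB[ku + i - j, j] = A[i, j]
# AB holds the same values as A_to_banded(A, kl=1, ku=1)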


def _gbmv(
alpha: np.ndarray,
A: np.ndarray,
x: np.ndarray,
kl: int,
ku: int,
beta: np.ndarray | None = None,
y: np.ndarray | None = None,
overwrite_y: bool = False,
trans: int = 1,
) -> Any:
"""
Thin wrapper around gbmv. This code will only be called if njit is disabled globally
(e.g. during testing)
"""
(fn,) = linalg.get_blas_funcs(("gbmv",), (A, x))
m, n = A.shape
A_banded = A_to_banded(A, kl=kl, ku=ku)

incx = x.strides[0] // x.itemsize
offx = 0 if incx >= 0 else -x.size + 1

if y is not None:
incy = y.strides[0] // y.itemsize
offy = 0 if incy >= 0 else -y.size + 1
else:
incy = 1
offy = 0

return fn(
m=m,
n=n,
kl=kl,
ku=ku,
a=A_banded,
alpha=alpha,
x=x,
incx=incx,
offx=offx,
beta=beta,
y=y,
overwrite_y=overwrite_y,
incy=incy,
offy=offy,
trans=trans,
)
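
An illustrative call of this eager path (it runs only when numba's JIT is disabled; beta, y and trans are passed explicitly here rather than relying on defaults):

import numpy as np

rng = np.random.default_rng(0)
# Tridiagonal test matrix: one sub- and one super-diagonal
A = (
    np.diag(rng.normal(size=5))
    + np.diag(rng.normal(size=4), 1)
    + np.diag(rng.normal(size=4), -1)
)
x = rng.normal(size=5)
y = _gbmv(
    alpha=np.array(1.0), A=A, x=x, kl=1, ku=1,
    beta=np.array(0.0), y=np.zeros(5), trans=0,
)
np.testing.assert_allclose(y, A @ x)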


@overload(_gbmv)
def gbmv_impl(
alpha: np.ndarray,
A: np.ndarray,
x: np.ndarray,
kl: int,
ku: int,
beta: np.ndarray | None = None,
y: np.ndarray | None = None,
overwrite_y: bool = False,
trans: int = 1,
) -> Callable[
[
np.ndarray,
np.ndarray,
np.ndarray,
int,
int,
np.ndarray | None,
np.ndarray | None,
bool,
int,
],
np.ndarray,
]:
ensure_lapack()
ensure_blas()
_check_scipy_linalg_matrix(A, "dot_banded")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_gbmv = _BLAS().numba_xgbmv(dtype)

def impl(
alpha: np.ndarray,
A: np.ndarray,
x: np.ndarray,
kl: int,
ku: int,
beta: np.ndarray | None = None,
y: np.ndarray | None = None,
overwrite_y: bool = False,
trans: int = 1,
) -> np.ndarray:
m, n = A.shape

A_banded = A_to_banded(A, kl=kl, ku=ku)
x_stride = x.strides[0] // x.itemsize

if beta is None:
beta = np.zeros((), dtype=dtype)

if y is None:
y_copy = np.empty(shape=(m,), dtype=dtype)
elif overwrite_y and y.flags.f_contiguous:
y_copy = y
else:
y_copy = _copy_to_fortran_order_even_if_1d(y)

y_stride = y_copy.strides[0] // y_copy.itemsize

TRANS = val_to_int_ptr(_trans_char_to_int(trans))
M = val_to_int_ptr(m)
N = val_to_int_ptr(n)
LDA = val_to_int_ptr(A_banded.shape[0])

KL = val_to_int_ptr(kl)
KU = val_to_int_ptr(ku)

INCX = val_to_int_ptr(x_stride)
INCY = val_to_int_ptr(y_stride)

numba_gbmv(
TRANS,
M,
N,
KL,
KU,
alpha.view(w_type).ctypes,
A_banded.view(w_type).ctypes,
LDA,
# x.view(...).ctypes gives a pointer to the first element of the view, which for a
# negatively-strided x is the highest address in the buffer. With a negative INCX,
# BLAS expects the lowest address instead, so we point at the last logical element.
# The [-1:] slice (rather than [-1]) keeps x an array; a scalar element has no .ctypes.
(x if x_stride >= 0 else x[-1:]).view(w_type).ctypes,
INCX,
beta.view(w_type).ctypes,
y_copy.view(w_type).ctypes,
INCY,
)

return y_copy

return impl
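
The negative-stride workaround in the comment above can be checked in plain NumPy (illustrative sketch): for a reversed view, x[-1:] starts at the lowest address of the underlying buffer, which is the pointer BLAS expects when INCX is negative.

import numpy as np

base = np.arange(5.0)
x = base[::-1]
assert x.strides[0] < 0
# x[-1:] points at the buffer's lowest address, i.e. the same address as base
assert x[-1:].__array_interface__["data"][0] == base.__array_interface__["data"][0]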