Skip to content

Commit

Permalink
add further concurrency to CSR construction
Browse files Browse the repository at this point in the history
  • Loading branch information
bkmartinjr committed Sep 6, 2024
1 parent e5873a2 commit ce6426b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 35 deletions.
56 changes: 45 additions & 11 deletions other_packages/python/tiledbsoma_ml/src/tiledbsoma_ml/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
import torch
import torchdata
from somacore.query._eager_iter import EagerIterator as _EagerIterator
from typing_extensions import TypeAlias
from typing_extensions import Self, TypeAlias

import tiledbsoma as soma

Expand Down Expand Up @@ -1043,6 +1043,11 @@ def merge(mtxs: Sequence[_CSR_IO_Buffer]) -> _CSR_IO_Buffer:
)
return _CSR_IO_Buffer.from_pjd(indptr, indices, data, shape)

def sort_indices(self) -> Self:
"""Sort indices, IN PLACE."""
_csr_sort_indices(self.indptr, self.indices, self.data)
return self


def smallest_uint_dtype(max_val: int) -> npt.DTypeLike:
for dt in [np.uint16, np.uint32]:
Expand Down Expand Up @@ -1128,19 +1133,48 @@ def _coo_to_csr_inner(
cumsum += tmp
Bp[n_rows] = nnz

# reorganize all of the data. side-effect: pointers shifted.
for n in range(nnz):
row = Ai[n]
dst_row = Bp[row]

Bj[dst_row] = Aj[n]
Bd[dst_row] = Ad[n]

Bp[row] += 1
# Reorganize all of the data. Side-effect: pointers shifted (reversed in the
# subsequent section).
#
# Method is concurrent (partioned by rows) if number of rows is greater
# than 2**partition_bits. This partitioning scheme leverages the fact
# that reads are much cheaper than writes.
#
# The code is equivalent to:
# for n in range(nnz):
# row = Ai[n]
# dst_row = Bp[row]
# Bj[dst_row] = Aj[n]
# Bd[dst_row] = Ad[n]
# Bp[row] += 1

partition_bits = 13
n_partitions = (n_rows + 2**partition_bits - 1) >> partition_bits
for p in numba.prange(n_partitions):
for n in range(nnz):
row = Ai[n]
if (row >> partition_bits) != p:
continue
dst_row = Bp[row]
Bj[dst_row] = Aj[n]
Bd[dst_row] = Ad[n]
Bp[row] += 1

# and shift the pointers by one (ie., start at zero)
# Shift the pointers by one slot (ie., start at zero)
prev_ptr = 0
for n in range(n_rows + 1):
tmp = Bp[n]
Bp[n] = prev_ptr
prev_ptr = tmp


@numba.njit(nogil=True, parallel=True) # type:ignore[misc]
def _csr_sort_indices(Bp: NDArrayNumber, Bj: NDArrayNumber, Bd: NDArrayNumber) -> None:
"""In-place sort of minor axis indices"""
n_rows = len(Bp) - 1
for r in numba.prange(n_rows):
row_start = Bp[r]
row_end = Bp[r + 1]
order = np.argsort(Bj[row_start:row_end])
Bj[row_start:row_end] = Bj[row_start:row_end][order]
Bd[row_start:row_end] = Bd[row_start:row_end][order]
49 changes: 25 additions & 24 deletions other_packages/python/tiledbsoma_ml/tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,26 +825,6 @@ def test_splits() -> None:
_splits(10, -1)


# temp comment out while building _CSR tests
# def test_csr_to_dense() -> None:
# from tiledbsoma_ml.pytorch import _csr_to_dense

# coo = sparse.eye(1001, 77, format="coo", dtype=np.float32)

# assert np.array_equal(
# sparse.csr_array(coo).todense(), _csr_to_dense(sparse.csr_array(coo))
# )
# assert np.array_equal(
# sparse.csr_matrix(coo).todense(), _csr_to_dense(sparse.csr_matrix(coo))
# )

# csr = sparse.csr_array(coo)
# assert np.array_equal(csr.todense(), _csr_to_dense(csr))
# assert np.array_equal(csr[1:, :].todense(), _csr_to_dense(csr[1:, :]))
# assert np.array_equal(csr[:, 1:].todense(), _csr_to_dense(csr[:, 1:]))
# assert np.array_equal(csr[3:501, 1:22].todense(), _csr_to_dense(csr[3:501, 1:22]))


@pytest.mark.parametrize( # keep these small as we materialize as a dense ndarray
"shape",
[(100, 10), (10, 100), (1, 1), (1, 100), (100, 1), (0, 0), (10, 0), (0, 10)],
Expand All @@ -865,8 +845,8 @@ def test_csr__construct_from_ijd(shape: Tuple[int, int], dtype: npt.DTypeLike) -
_ncsr.data.nbytes + _ncsr.indices.nbytes + _ncsr.indptr.nbytes
)

# _CSR_IO_Buffer makes no guarantees about minor axis ordering (ie.., "canonical" form), so
# use the SciPy sparse csr package to validate by round-tripping.
# _CSR_IO_Buffer makes no guarantees about minor axis ordering (ie, "canonical" form) until
# sort_indices is called, so use the SciPy sparse csr package to validate by round-tripping.
assert (
sparse.csr_matrix((_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape)
!= sp_csr
Expand Down Expand Up @@ -896,8 +876,8 @@ def test_csr__construct_from_pjd(shape: Tuple[int, int], dtype: npt.DTypeLike) -
shape=sp_csr.shape,
)

# _CSR makes no guarantees about minor axis ordering (ie.., "canonical" form), so
# use the SciPy sparse csr package to validate by round-tripping.
# _CSR_IO_Buffer makes no guarantees about minor axis ordering (ie, "canonical" form) until
# sort_indices is called, so use the SciPy sparse csr package to validate by round-tripping.
assert (
sparse.csr_matrix((_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape)
!= sp_csr
Expand Down Expand Up @@ -939,3 +919,24 @@ def test_csr__merge(
(_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape
)
).nnz == 0


@pytest.mark.parametrize(
"shape",
[(100, 10), (10, 100), (1, 1), (1, 100), (100, 1), (0, 0), (10, 0), (0, 10)],
)
def test_csr__sort_indices(shape: Tuple[int, int]) -> None:
from tiledbsoma_ml.pytorch import _CSR_IO_Buffer

sp_coo = sparse.random(
shape[0], shape[1], dtype=np.float32, format="coo", density=0.05
)
sp_csr = sp_coo.tocsr()

_ncsr = _CSR_IO_Buffer.from_ijd(
sp_coo.row, sp_coo.col, sp_coo.data, shape=sp_coo.shape
).sort_indices()

assert np.array_equal(sp_csr.indptr, _ncsr.indptr)
assert np.array_equal(sp_csr.indices, _ncsr.indices)
assert np.array_equal(sp_csr.data, _ncsr.data)

0 comments on commit ce6426b

Please sign in to comment.