Skip to content

Commit

Permalink
feat: Add mmap support for reading sparse vectors to avoid OOM error …
Browse files Browse the repository at this point in the history
…in CI (#129)

* fix: Manual benchmarks

* fix: Remove gcs secrets

* feat: Use mmap to read sparse vectors

* fix: Format

* fix: Make unused var private
  • Loading branch information
KShivendu authored Apr 17, 2024
1 parent 04bbb7c commit a8bcf78
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions dataset_reader/sparse_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,28 @@ def read_sparse_matrix_fields(
return values, columns, index_pointer


def mmap_sparse_matrix_fields(fname):
"""mmap the fields of a CSR matrix without instantiating it"""
with open(fname, "rb") as f:
sizes = np.fromfile(f, dtype="int64", count=3)
n_row, _n_col, n_non_zero = sizes
offset = sizes.nbytes
index_pointer = np.memmap(
fname, dtype="int64", mode="r", offset=offset, shape=n_row + 1
)
offset += index_pointer.nbytes
columns = np.memmap(fname, dtype="int32", mode="r", offset=offset, shape=n_non_zero)
offset += columns.nbytes
values = np.memmap(
fname, dtype="float32", mode="r", offset=offset, shape=n_non_zero
)
return values, columns, index_pointer


def csr_to_sparse_vectors(
values: List[float], columns: List[int], index_pointer: List[int]
) -> Iterator[SparseVector]:
"""Convert a CSR matrix to a list of SparseVectors"""
num_rows = len(index_pointer) - 1

for i in range(num_rows):
Expand All @@ -38,9 +57,12 @@ def csr_to_sparse_vectors(
yield SparseVector(indices=row_indices, values=row_values)


def read_csr_matrix(filename: Union[Path, str]) -> Iterator[SparseVector]:
def read_csr_matrix(filename: Union[Path, str], do_mmap=True) -> Iterator[SparseVector]:
"""Read a CSR matrix in spmat format"""
values, columns, index_pointer = read_sparse_matrix_fields(filename)
if do_mmap:
values, columns, index_pointer = mmap_sparse_matrix_fields(filename)
else:
values, columns, index_pointer = read_sparse_matrix_fields(filename)
values = values.tolist()
columns = columns.tolist()
index_pointer = index_pointer.tolist()
Expand Down

0 comments on commit a8bcf78

Please sign in to comment.