From a8bcf7829f6d7fd245e2f47072591c1aa7c8be28 Mon Sep 17 00:00:00 2001
From: Kumar Shivendu
Date: Wed, 17 Apr 2024 17:32:35 +0530
Subject: [PATCH] feat: Add mmap support for reading sparse vectors to avoid OOM error in CI (#129)

* fix: Manual benchmarks

* fix: Remove gcs secrets

* feat: Use mmap to read sparse vectors

* fix: Format

* fix: Make unused var private
---
 dataset_reader/sparse_reader.py | 26 ++++++++++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/dataset_reader/sparse_reader.py b/dataset_reader/sparse_reader.py
index fb2af5d9..94ee4167 100644
--- a/dataset_reader/sparse_reader.py
+++ b/dataset_reader/sparse_reader.py
@@ -23,9 +23,28 @@ def read_sparse_matrix_fields(
     return values, columns, index_pointer
 
 
+def mmap_sparse_matrix_fields(fname):
+    """mmap the fields of a CSR matrix without instantiating it"""
+    with open(fname, "rb") as f:
+        sizes = np.fromfile(f, dtype="int64", count=3)
+        n_row, _n_col, n_non_zero = sizes
+    offset = sizes.nbytes
+    index_pointer = np.memmap(
+        fname, dtype="int64", mode="r", offset=offset, shape=n_row + 1
+    )
+    offset += index_pointer.nbytes
+    columns = np.memmap(fname, dtype="int32", mode="r", offset=offset, shape=n_non_zero)
+    offset += columns.nbytes
+    values = np.memmap(
+        fname, dtype="float32", mode="r", offset=offset, shape=n_non_zero
+    )
+    return values, columns, index_pointer
+
+
 def csr_to_sparse_vectors(
     values: List[float], columns: List[int], index_pointer: List[int]
 ) -> Iterator[SparseVector]:
+    """Convert a CSR matrix to a list of SparseVectors"""
     num_rows = len(index_pointer) - 1
 
     for i in range(num_rows):
@@ -38,9 +57,12 @@ def csr_to_sparse_vectors(
         yield SparseVector(indices=row_indices, values=row_values)
 
 
-def read_csr_matrix(filename: Union[Path, str]) -> Iterator[SparseVector]:
+def read_csr_matrix(filename: Union[Path, str], do_mmap=True) -> Iterator[SparseVector]:
     """Read a CSR matrix in spmat format"""
-    values, columns, index_pointer = read_sparse_matrix_fields(filename)
+    if do_mmap:
+        values, columns, index_pointer = mmap_sparse_matrix_fields(filename)
+    else:
+        values, columns, index_pointer = read_sparse_matrix_fields(filename)
     values = values.tolist()
     columns = columns.tolist()
     index_pointer = index_pointer.tolist()