Skip to content

Commit 97d52e7

Browse files
committed
fix ownership
1 parent 5ec7b2f commit 97d52e7

File tree

4 files changed

+29
-5
lines changed

4 files changed

+29
-5
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ approx = "0.4"
4242

4343
ndarray = { version = "0.15", features = ["approx"] }
4444
ndarray-linalg = { version = "0.16", optional = true }
45-
sprs = { version = "0.11", default-features = false }
45+
sprs = { version = "=0.11.1", default-features = false }
4646

4747
thiserror = "1.0"
4848

algorithms/linfa-kernel/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ features = ["std", "derive"]
2626
[dependencies]
2727
ndarray = "0.15"
2828
num-traits = "0.2"
29-
sprs = { version="0.11", default-features = false }
29+
sprs = { version="0.11.1", default-features = false }
3030

3131
linfa = { version = "0.7.0", path = "../.." }
3232
linfa-nn = { version = "0.7.0", path = "../linfa-nn" }

algorithms/linfa-kernel/src/inner.rs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,19 @@ impl<F: Float> Inner for CsMat<F> {
6161
type Elem = F;
6262

6363
fn dot(&self, rhs: &ArrayView2<F>) -> Array2<F> {
64-
self.mul(&rhs.to_owned())
64+
let mut result = Array2::zeros((self.rows(), rhs.ncols()));
65+
66+
// Handle potential sparse matrices
67+
for j in 0..rhs.ncols() {
68+
let col = rhs.column(j);
69+
let col_result = self.mul(&col.to_owned());
70+
// Copy result into appropriate column of output
71+
for i in 0..self.rows() {
72+
result[[i, j]] = col_result[i];
73+
}
74+
}
75+
76+
result
6577
}
6678
fn sum(&self) -> Array1<F> {
6779
let mut sum = Array1::zeros(self.cols());
@@ -106,7 +118,19 @@ impl<'a, F: Float> Inner for CsMatView<'a, F> {
106118
type Elem = F;
107119

108120
fn dot(&self, rhs: &ArrayView2<F>) -> Array2<F> {
109-
self.mul(&rhs.to_owned())
121+
let mut result = Array2::zeros((self.rows(), rhs.ncols()));
122+
123+
// Handle potential sparse matrices
124+
for j in 0..rhs.ncols() {
125+
let col = rhs.column(j);
126+
let col_result = self.mul(&col.to_owned());
127+
// Copy result into appropriate column of output
128+
for i in 0..self.rows() {
129+
result[[i, j]] = col_result[i];
130+
}
131+
}
132+
133+
result
110134
}
111135
fn sum(&self) -> Array1<F> {
112136
let mut sum = Array1::zeros(self.cols());

algorithms/linfa-preprocessing/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ ndarray-rand = { version = "0.14" }
2929
unicode-normalization = "0.1.8"
3030
regex = "1.4.5"
3131
encoding = "0.2"
32-
sprs = { version = "0.11.0", default-features = false }
32+
sprs = { version = "0.11.1", default-features = false }
3333

3434
serde_regex = { version = "1.1", optional = true }
3535

0 commit comments

Comments
 (0)