Merge pull request #7 from retraigo/ndarray
feat: migrate to ndarray + improve performance
retraigo committed Apr 1, 2024
2 parents 0efa1a1 + ef6b74a commit 497c669
Showing 36 changed files with 419 additions and 423 deletions.
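
For orientation (not part of the commit): a minimal side-by-side sketch of the two matrix APIs this change migrates between, assuming nalgebra 0.32 and ndarray 0.15 as pinned in the Cargo.toml diff below.

// Hypothetical comparison snippet, not code from this repository.
use nalgebra::DMatrix;
use ndarray::Array2;

fn main() {
    // nalgebra (before): dynamically sized f64 matrix, row-major constructor.
    let a: DMatrix<f64> = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);

    // ndarray (after): 2-D f32 array built from a shape and a flat Vec.
    let b: Array2<f32> = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();

    assert_eq!(a.nrows(), b.nrows());
    assert_eq!(a.ncols(), b.ncols());
}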
7 changes: 6 additions & 1 deletion .gitignore
@@ -2,4 +2,9 @@
 /ignore
 /test
 /rust/target
-.vscode/
+.vscode/
+test.ts
+iris.csv
+bench/nalgebra.ts
+bench/ndarray.ts
+bench/netsaur.ts
148 changes: 29 additions & 119 deletions Cargo.lock

Some generated files are not rendered by default.

6 changes: 3 additions & 3 deletions README.md
@@ -63,16 +63,16 @@ const solver = new GradientDescentSolver({
 solver.train(
   new Matrix(
     x.map((n) => [n]),
-    "f64"
+    "f32"
   ),
-  new Matrix(y, "f64"),
+  new Matrix(y, "f32"),
   { silent: false, fit_intercept: true, epochs: 700, n_batches: 2 }
 );
 
 const res = solver.predict(
   new Matrix(
     x.map((n) => [n]),
-    "f64"
+    "f32"
   )
 );
 
5 changes: 3 additions & 2 deletions crates/classy/Cargo.toml
@@ -9,5 +9,6 @@ edition = "2021"
 crate-type = ["cdylib"]
 
 [dependencies]
-nalgebra = "0.32.3"
-rand = "0.8.5"
+rand = "0.8.5"
+ndarray = "0.15"
+ndarray-rand = "0.14"
16 changes: 8 additions & 8 deletions crates/classy/src/core/activation/mod.rs
@@ -1,4 +1,4 @@
-use nalgebra::DMatrix;
+use ndarray::{Array2, Axis};
 
 pub enum Activation {
     Linear,
@@ -8,24 +8,24 @@ pub enum Activation {
 }
 
 impl Activation {
-    pub fn call_on_all(&self, h: DMatrix<f64>) -> DMatrix<f64> {
+    pub fn call_on_all(&self, h: Array2<f32>) -> Array2<f32> {
         match self {
             Self::Linear => h,
-            Self::Sigmoid => h.map(|x| sigmoid(x)),
+            Self::Sigmoid => h.map(|x| sigmoid(*x)),
             Self::Tanh => h.map(|x| x.tanh()),
             Self::Softmax => {
-                let mut res = DMatrix::zeros(h.nrows(), h.ncols());
-                for (mut res_row, h_row) in res.row_iter_mut().zip(h.row_iter()) {
+                let mut res = Array2::zeros((h.nrows(), h.ncols()));
+                for (mut res_row, h_row) in res.axis_iter_mut(Axis(0)).zip(h.axis_iter(Axis(0))) {
                     let exp_values = h_row.map(|v| v.exp());
-                    let sum_exp: f64 = exp_values.iter().sum();
-                    res_row.copy_from(&(exp_values / sum_exp));
+                    let sum_exp: f32 = exp_values.sum();
+                    res_row.assign(&(exp_values / sum_exp));
                 }
                 res
             }
         }
     }
 }
 
-fn sigmoid(x: f64) -> f64 {
+fn sigmoid(x: f32) -> f32 {
     1.0 / (1.0 + (-x).exp())
 }
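
As a quick sanity check of the ndarray-based softmax above, here is a standalone sketch (not part of the commit) that applies the same per-row pattern and asserts each row sums to one; it assumes only ndarray 0.15.

use ndarray::{array, Array2, Axis};

// Row-wise softmax, mirroring the Softmax arm of Activation::call_on_all.
fn softmax_rows(h: Array2<f32>) -> Array2<f32> {
    let mut res = Array2::zeros((h.nrows(), h.ncols()));
    for (mut res_row, h_row) in res.axis_iter_mut(Axis(0)).zip(h.axis_iter(Axis(0))) {
        let exp_values = h_row.map(|v| v.exp());
        let sum_exp: f32 = exp_values.sum();
        res_row.assign(&(exp_values / sum_exp));
    }
    res
}

fn main() {
    let logits = array![[1.0_f32, 2.0, 3.0], [0.0, 0.0, 0.0]];
    let probs = softmax_rows(logits);
    // Each row of a softmax output should sum to ~1.
    for row in probs.axis_iter(Axis(0)) {
        assert!((row.sum() - 1.0).abs() < 1e-5);
    }
}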
4 changes: 2 additions & 2 deletions crates/classy/src/core/ffi/loss.rs
@@ -16,7 +16,7 @@ pub unsafe extern "C" fn hinge_loss() -> isize {
 }
 
 #[no_mangle]
-pub unsafe extern "C" fn huber_loss(delta: f64) -> isize {
+pub unsafe extern "C" fn huber_loss(delta: f32) -> isize {
     std::mem::transmute::<Box<LossFunction>, isize>(std::boxed::Box::new(LossFunction::Huber(delta)))
 }
 
@@ -36,6 +36,6 @@ pub unsafe extern "C" fn smooth_hinge_loss() -> isize {
 }
 
 #[no_mangle]
-pub unsafe extern "C" fn tukey_loss(c: f64) -> isize {
+pub unsafe extern "C" fn tukey_loss(c: f32) -> isize {
     std::mem::transmute::<Box<LossFunction>, isize>(std::boxed::Box::new(LossFunction::Tukey(c)))
 }
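
All of the ffi functions in this commit hand boxed values to the caller as isize handles. A minimal sketch of that round trip is below; it uses Box::into_raw/Box::from_raw (equivalent in effect to the transmute in the diff) and a hypothetical DemoLoss enum standing in for the crate's LossFunction, since the handle-reclaiming side is not shown in this diff.

// Hypothetical sketch, not code from this repository.
enum DemoLoss {
    Huber(f32),
}

fn to_handle(l: DemoLoss) -> isize {
    // Leak the Box and return its address; the caller now owns the handle.
    Box::into_raw(Box::new(l)) as isize
}

unsafe fn from_handle(handle: isize) -> Box<DemoLoss> {
    // Reclaim ownership; dropping this Box frees the value.
    Box::from_raw(handle as *mut DemoLoss)
}

fn main() {
    let h = to_handle(DemoLoss::Huber(1.35));
    let loss = unsafe { from_handle(h) };
    match *loss {
        DemoLoss::Huber(delta) => assert!((delta - 1.35).abs() < f32::EPSILON),
    }
}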
10 changes: 5 additions & 5 deletions crates/classy/src/core/ffi/optimizers.rs
@@ -2,9 +2,9 @@ use crate::core::optimizers::{OptimizerConfig, Optimizer};
 
 #[no_mangle]
 pub unsafe extern "C" fn adam_optimizer(
-    beta1: f64,
-    beta2: f64,
-    epsilon: f64,
+    beta1: f32,
+    beta2: f32,
+    epsilon: f32,
     input_size: usize,
     output_size: usize,
 ) -> isize {
@@ -15,8 +15,8 @@ pub unsafe extern "C" fn adam_optimizer(
 
 #[no_mangle]
 pub unsafe extern "C" fn rmsprop_optimizer(
-    decay_rate: f64,
-    epsilon: f64,
+    decay_rate: f32,
+    epsilon: f32,
     input_size: usize,
     output_size: usize,
 ) -> isize {
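
adam_optimizer and rmsprop_optimizer only register the hyperparameters; the update rules live elsewhere in the crate and are not part of this diff. As a reference point, a generic textbook Adam step in f32 (not this crate's implementation) uses beta1, beta2, and epsilon roughly like this:

// Hypothetical sketch, not code from this repository.
fn adam_step(
    w: &mut [f32],      // parameters, updated in place
    grad: &[f32],       // gradient for this step
    m: &mut [f32],      // first-moment running average
    v: &mut [f32],      // second-moment running average
    t: i32,             // 1-based step count, used for bias correction
    lr: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
) {
    for i in 0..w.len() {
        m[i] = beta1 * m[i] + (1.0 - beta1) * grad[i];
        v[i] = beta2 * v[i] + (1.0 - beta2) * grad[i] * grad[i];
        let m_hat = m[i] / (1.0 - beta1.powi(t));
        let v_hat = v[i] / (1.0 - beta2.powi(t));
        w[i] -= lr * m_hat / (v_hat.sqrt() + epsilon);
    }
}

fn main() {
    let (mut w, g) = (vec![0.5_f32, -0.5], vec![0.1_f32, -0.2]);
    let (mut m, mut v) = (vec![0.0_f32; 2], vec![0.0_f32; 2]);
    adam_step(&mut w, &g, &mut m, &mut v, 1, 0.001, 0.9, 0.999, 1e-8);
    println!("{:?}", w);
}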
2 changes: 1 addition & 1 deletion crates/classy/src/core/ffi/regularizer.rs
@@ -1,7 +1,7 @@
 use crate::core::regularization::Regularization;
 
 #[no_mangle]
-pub unsafe extern "C" fn regularizer(c: f64, l1_ratio: f64) -> isize {
+pub unsafe extern "C" fn regularizer(c: f32, l1_ratio: f32) -> isize {
     let reg = Regularization::from(c, l1_ratio);
     std::mem::transmute::<Box<Regularization>, isize>(std::boxed::Box::new(reg))
 }
6 changes: 3 additions & 3 deletions crates/classy/src/core/ffi/scheduler.rs
@@ -1,22 +1,22 @@
 use crate::core::scheduler::Scheduler;
 
 #[no_mangle]
-pub unsafe extern "C" fn linear_decay_scheduler(rate: f64, step_size: usize) -> isize {
+pub unsafe extern "C" fn linear_decay_scheduler(rate: f32, step_size: usize) -> isize {
     std::mem::transmute::<Box<Scheduler>, isize>(std::boxed::Box::new(Scheduler::LinearDecay {
         rate,
         step_size,
     }))
 }
 
 #[no_mangle]
-pub unsafe extern "C" fn exponential_decay_scheduler(rate: f64, step_size: usize) -> isize {
+pub unsafe extern "C" fn exponential_decay_scheduler(rate: f32, step_size: usize) -> isize {
     std::mem::transmute::<Box<Scheduler>, isize>(std::boxed::Box::new(
         Scheduler::ExponentialDecay { rate, step_size },
     ))
 }
 
 #[no_mangle]
-pub unsafe extern "C" fn one_cycle_scheduler(max_rate: f64, step_size: usize) -> isize {
+pub unsafe extern "C" fn one_cycle_scheduler(max_rate: f32, step_size: usize) -> isize {
     std::mem::transmute::<Box<Scheduler>, isize>(std::boxed::Box::new(Scheduler::OneCycle {
         max_rate,
         step_size,
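
The scheduler constructors expose only a rate (or max_rate) and a step_size; the decay math is outside this diff. A generic guess at a linear step decay with those two knobs (not this crate's implementation) might look like:

// Hypothetical sketch, not code from this repository.
fn linear_decay(base_lr: f32, rate: f32, step_size: usize, epoch: usize) -> f32 {
    // Lower the learning rate by `rate` once every `step_size` epochs,
    // clamping at zero.
    let steps = (epoch / step_size) as f32;
    (base_lr - rate * steps).max(0.0)
}

fn main() {
    let lrs: Vec<f32> = (0..5).map(|e| linear_decay(0.1, 0.01, 1, e)).collect();
    assert!((lrs[0] - 0.1).abs() < 1e-6);
    assert!(lrs[4] < lrs[0]);
}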