Skip to content

Implement predict_proba for RandomForestClassifier #288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: development
Choose a base branch
from
Draft
12 changes: 12 additions & 0 deletions .github/CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -70,3 +70,15 @@ $ rust-code-analysis-cli -p src/algorithm/neighbour/fastpair.rs --ls 22 --le 213
* **PRs on develop**: any change should be PRed first in `development`

* **testing**: everything should work and be tested as defined in the workflow. If any test is failing for unrelated reasons, note the failure in a PR comment.


## Suggestions for debugging
1. Install `lldb` for your platform
2. Run `rust-lldb target/debug/libsmartcore.rlib` in your command-line
3. In lldb, set up some breakpoints using `b func_name` or `b src/path/to/file.rs:linenumber`
4. In lldb, run a single test with `r the_name_of_your_test`

Display variables in scope: `frame variable <name>`

Execute expression: `p <expr>`

1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@ jobs:
]
env:
TZ: "/usr/share/zoneinfo/your/location"
RUST_BACKTRACE: "1"
steps:
- uses: actions/checkout@v3
- name: Cache .cargo and target
162 changes: 162 additions & 0 deletions src/ensemble/random_forest_classifier.rs
Original file line number Diff line number Diff line change
@@ -55,7 +55,9 @@ use serde::{Deserialize, Serialize};

use crate::api::{Predictor, SupervisedEstimator};
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::MutArray;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::linalg::basic::matrix::DenseMatrix;
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;

@@ -602,11 +604,76 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
}
samples
}

/// Predict class probabilities for X.
///
/// The predicted class probabilities of an input sample are computed as
/// the mean predicted class probabilities of the trees in the forest.
/// The class probability of a single tree is the fraction of samples of
/// the same class in a leaf.
///
/// # Arguments
///
/// * `x` - The input samples. A matrix of shape (n_samples, n_features).
///
/// # Returns
///
/// * `Result<DenseMatrix<f64>, Failed>` - The class probabilities of the input samples.
/// The order of the classes corresponds to that in the attribute `classes_`.
/// The matrix has shape (n_samples, n_classes).
///
/// # Errors
///
/// Returns a `Failed` error if:
/// * The model has not been fitted yet.
/// * The input `x` is not compatible with the model's expected input.
/// * Any of the tree predictions fail.
///
/// # Examples
///
/// ```
/// use smartcore::ensemble::random_forest_classifier::RandomForestClassifier;
/// use smartcore::linalg::basic::matrix::DenseMatrix;
/// use smartcore::linalg::basic::arrays::Array;
///
/// let x = DenseMatrix::from_2d_array(&[
/// &[5.1, 3.5, 1.4, 0.2],
/// &[4.9, 3.0, 1.4, 0.2],
/// &[7.0, 3.2, 4.7, 1.4],
/// ]).unwrap();
/// let y = vec![0, 0, 1];
///
/// let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
/// let probas = forest.predict_proba(&x).unwrap();
///
/// assert_eq!(probas.shape(), (3, 2));
/// ```
pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> {
let (n_samples, _) = x.shape();
let n_classes = self.classes.as_ref().unwrap().len();
let mut probas = DenseMatrix::<f64>::zeros(n_samples, n_classes);

for tree in self.trees.as_ref().unwrap().iter() {
let tree_predictions: Y = tree.predict(x).unwrap();

for (i, &class_idx) in tree_predictions.iterator(0).enumerate() {
let class_ = class_idx.to_usize().unwrap();
probas.add_element_mut((i, class_), 1.0);
}
}

let n_trees: f64 = self.trees.as_ref().unwrap().len() as f64;
probas.mul_scalar_mut(1.0 / n_trees);

Ok(probas)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::ensemble::random_forest_classifier::RandomForestClassifier;
use crate::linalg::basic::arrays::Array;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::metrics::*;

@@ -760,6 +827,101 @@ mod tests {
);
}

#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_random_forest_predict_proba() {
use num_traits::FromPrimitive;
// Iris-like dataset (subset)
let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
])
.unwrap();
let y: Vec<u32> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];

let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
let probas = forest.predict_proba(&x).unwrap();

// Test shape
assert_eq!(probas.shape(), (10, 2));

let (pro_n_rows, _) = probas.shape();

// Test probability sum
for i in 0..pro_n_rows {
let row_sum: f64 = probas.get_row(i).sum();
assert!(
(row_sum - 1.0).abs() < 1e-6,
"Row probabilities should sum to 1"
);
}

// Test class prediction
let predictions: Vec<u32> = (0..pro_n_rows)
.map(|i| {
if probas.get((i, 0)) > probas.get((i, 1)) {
0
} else {
1
}
})
.collect();
let acc = accuracy(&y, &predictions);
assert!(acc > 0.8, "Accuracy should be high for the training set");

// Test probability values
// These values are approximate and based on typical random forest behavior
for i in 0..(pro_n_rows / 2) {
assert!(
f64::from_f32(0.6).unwrap().lt(probas.get((i, 0))),
"Class 0 samples should have high probability for class 0"
);
assert!(
f64::from_f32(0.4).unwrap().gt(probas.get((i, 1))),
"Class 0 samples should have low probability for class 1"
);
}

for i in (pro_n_rows / 2)..pro_n_rows {
assert!(
f64::from_f32(0.6).unwrap().lt(probas.get((i, 1))),
"Class 1 samples should have high probability for class 1"
);
assert!(
f64::from_f32(0.4).unwrap().gt(probas.get((i, 0))),
"Class 1 samples should have low probability for class 0"
);
}

// Test with new data
let x_new = DenseMatrix::from_2d_array(&[
&[5.0, 3.4, 1.5, 0.2], // Should be close to class 0
&[6.3, 3.3, 4.7, 1.6], // Should be close to class 1
])
.unwrap();
let probas_new = forest.predict_proba(&x_new).unwrap();
assert_eq!(probas_new.shape(), (2, 2));
assert!(
probas_new.get((0, 0)) > probas_new.get((0, 1)),
"First sample should be predicted as class 0"
);
assert!(
probas_new.get((1, 1)) > probas_new.get((1, 0)),
"Second sample should be predicted as class 1"
);
}

#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
Loading