Skip to content

Commit d67998a

Browse files
YuhanLiin and PABannier
authored
Split AsTargets into 1D and 2D variants (rust-ml#206)
* Finish refactoring most of linfa main crate
* Fix cross_validate and fix errors
* Make split_ratio panic when input is not row-major
* Fix external uses of single_targets
* Fix algorithm examples and dataset
* Fix cross_validation
* Fix SVM
* Add all-targets flag to checking step
* Fix doc tests
* Remove multitarget error variant
* Document breaking changes in changelog

Co-authored-by: Pierre-Antoine Bannier <[email protected]>
1 parent e406403 commit d67998a

File tree

27 files changed

+520
-610
lines changed

27 files changed

+520
-610
lines changed

.github/workflows/checking.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ jobs:
4040
uses: actions-rs/cargo@v1
4141
with:
4242
command: check
43-
args: --workspace
43+
args: --workspace --all-targets
4444

4545
- name: Run cargo check (with serde)
4646
if: ${{ matrix.toolchain != '1.54.0' }} # The following args don't work on older versions of Cargo
4747
uses: actions-rs/cargo@v1
4848
with:
4949
command: check
50-
args: --workspace --features "linfa-clustering/serde linfa-ica/serde linfa-kernel/serde linfa-reduction/serde linfa-svm/serde linfa-elasticnet/serde linfa-pls/serde linfa-trees/serde linfa-nn/serde"
50+
args: --workspace --all-targets --features "linfa-clustering/serde linfa-ica/serde linfa-kernel/serde linfa-reduction/serde linfa-svm/serde linfa-elasticnet/serde linfa-pls/serde linfa-trees/serde linfa-nn/serde"

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
Unreleased
2+
========================
3+
4+
Breaking Changes
5+
----------------------
6+
* parametrize `AsTargets` by the dimensionality of the targets and introduce `AsSingleTargets` and `AsMultiTargets`
7+
* 1D target arrays are no longer converted to 2D when constructing `Dataset`s
8+
* `Dataset` and `DatasetView` can now be parametrized by target dimensionality, with 2D being the default
9+
* single-target algorithms no longer accept 2D target arrays as input
10+
* `cross_validate_multi` has been merged with `cross_validate`, which is now generic across single and multi-targets
11+
112
Version 0.5.1 - 2022-02-28
213
========================
314

algorithms/linfa-bayes/src/gaussian_nb.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use linfa::dataset::{AsTargets, DatasetBase, Labels};
1+
use linfa::dataset::{AsSingleTargets, DatasetBase, Labels};
22
use linfa::traits::{Fit, FitWith, PredictInplace};
33
use linfa::{Float, Label};
44
use ndarray::{Array1, ArrayBase, ArrayView2, Axis, Data, Ix2};
@@ -14,7 +14,7 @@ where
1414
F: Float,
1515
L: Label + 'a,
1616
D: Data<Elem = F>,
17-
T: AsTargets<Elem = L> + Labels<Elem = L>,
17+
T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
1818
{
1919
}
2020

@@ -23,7 +23,7 @@ where
2323
F: Float,
2424
L: Label + Ord,
2525
D: Data<Elem = F>,
26-
T: AsTargets<Elem = L> + Labels<Elem = L>,
26+
T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
2727
{
2828
type Object = GaussianNb<F, L>;
2929

@@ -40,7 +40,7 @@ where
4040
F: Float,
4141
L: Label + 'a,
4242
D: Data<Elem = F>,
43-
T: AsTargets<Elem = L> + Labels<Elem = L>,
43+
T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
4444
{
4545
type ObjectIn = Option<GaussianNb<F, L>>;
4646
type ObjectOut = Option<GaussianNb<F, L>>;
@@ -51,7 +51,7 @@ where
5151
dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
5252
) -> Result<Self::ObjectOut> {
5353
let x = dataset.records();
54-
let y = dataset.try_single_target()?;
54+
let y = dataset.as_single_targets();
5555

5656
// If the ratio of the variance between dimensions is too small, it will cause
5757
// numerical errors. We address this by artificially boosting the variance

algorithms/linfa-bayes/src/multinomial_nb.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use linfa::dataset::{AsTargets, DatasetBase, Labels};
1+
use linfa::dataset::{AsSingleTargets, DatasetBase, Labels};
22
use linfa::traits::{Fit, FitWith, PredictInplace};
33
use linfa::{Float, Label};
44
use ndarray::{Array1, ArrayBase, ArrayView2, Axis, Data, Ix2};
@@ -13,7 +13,7 @@ where
1313
F: Float,
1414
L: Label + 'a,
1515
D: Data<Elem = F>,
16-
T: AsTargets<Elem = L> + Labels<Elem = L>,
16+
T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
1717
{
1818
}
1919

@@ -22,7 +22,7 @@ where
2222
F: Float,
2323
L: Label + Ord,
2424
D: Data<Elem = F>,
25-
T: AsTargets<Elem = L> + Labels<Elem = L>,
25+
T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
2626
{
2727
type Object = MultinomialNb<F, L>;
2828
// Thin wrapper around the corresponding method of NaiveBayesValidParams
@@ -38,7 +38,7 @@ where
3838
F: Float,
3939
L: Label + 'a,
4040
D: Data<Elem = F>,
41-
T: AsTargets<Elem = L> + Labels<Elem = L>,
41+
T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
4242
{
4343
type ObjectIn = Option<MultinomialNb<F, L>>;
4444
type ObjectOut = Option<MultinomialNb<F, L>>;
@@ -49,7 +49,7 @@ where
4949
dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
5050
) -> Result<Self::ObjectOut> {
5151
let x = dataset.records();
52-
let y = dataset.try_single_target()?;
52+
let y = dataset.as_single_targets();
5353

5454
let mut model = match model_in {
5555
Some(temp) => temp,

algorithms/linfa-elasticnet/examples/elasticnet_cv.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use linfa::prelude::*;
22
use linfa_elasticnet::{ElasticNet, Result};
3+
use ndarray::arr0;
34

45
fn main() -> Result<()> {
56
// load Diabetes dataset (mutable to allow fast k-folding)
@@ -15,8 +16,9 @@ fn main() -> Result<()> {
1516
.collect::<Vec<_>>();
1617

1718
// get the mean r2 validation score across all folds for each model
18-
let r2_values =
19-
dataset.cross_validate(5, &models, |prediction, truth| prediction.r2(&truth))?;
19+
let r2_values = dataset.cross_validate(5, &models, |prediction, truth| {
20+
prediction.r2(&truth).map(arr0)
21+
})?;
2022

2123
for (ratio, r2) in ratios.iter().zip(r2_values.iter()) {
2224
println!("L1 ratio: {}, r2 score: {}", ratio, r2);

algorithms/linfa-elasticnet/src/algorithm.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
11
use approx::{abs_diff_eq, abs_diff_ne};
2+
use linfa::dataset::AsSingleTargets;
23
use ndarray::{s, Array1, ArrayBase, ArrayView1, ArrayView2, Axis, CowArray, Data, Ix1, Ix2};
34
use ndarray_linalg::{Inverse, Lapack};
45

56
use linfa::traits::{Fit, PredictInplace};
6-
use linfa::{
7-
dataset::{AsTargets, Records},
8-
DatasetBase, Float,
9-
};
7+
use linfa::{dataset::Records, DatasetBase, Float};
108

119
use super::{hyperparams::ElasticNetValidParams, ElasticNet, ElasticNetError, Result};
1210

1311
impl<F, D, T> Fit<ArrayBase<D, Ix2>, T, ElasticNetError> for ElasticNetValidParams<F>
1412
where
1513
F: Float + Lapack,
1614
D: Data<Elem = F>,
17-
T: AsTargets<Elem = F>,
15+
T: AsSingleTargets<Elem = F>,
1816
{
1917
type Object = ElasticNet<F>;
2018

@@ -29,7 +27,7 @@ where
2927
/// parameters and can be used to `predict` values of the target variable
3028
/// for new feature values.
3129
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
32-
let target = dataset.try_single_target()?;
30+
let target = dataset.as_single_targets();
3331

3432
let (intercept, y) = compute_intercept(self.with_intercept(), target);
3533
let (hyperplane, duality_gap, n_steps) = coordinate_descent(
@@ -217,15 +215,15 @@ fn duality_gap<'a, F: Float>(
217215
gap
218216
}
219217

220-
fn variance_params<F: Float + Lapack, T: AsTargets<Elem = F>, D: Data<Elem = F>>(
218+
fn variance_params<F: Float + Lapack, T: AsSingleTargets<Elem = F>, D: Data<Elem = F>>(
221219
ds: &DatasetBase<ArrayBase<D, Ix2>, T>,
222220
y_est: Array1<F>,
223221
) -> Result<Array1<F>> {
224222
let nfeatures = ds.nfeatures();
225223
let nsamples = ds.nsamples();
226224

227225
// try to convert targets into a single target
228-
let target = ds.try_single_target()?;
226+
let target = ds.as_single_targets();
229227

230228
// check that we have enough samples
231229
if nsamples < nfeatures + 1 {

algorithms/linfa-kernel/src/lib.rs

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ impl<'a, F: Float, T: AsTargets, N: NearestNeighbour>
409409
}
410410
}
411411

412-
impl<'a, F: Float, L: 'a, T: AsTargets<Elem = L> + FromTargetArray<'a, L>, N: NearestNeighbour>
412+
impl<'a, F: Float, L: 'a, T: AsTargets<Elem = L> + FromTargetArray<'a>, N: NearestNeighbour>
413413
Transformer<&'a DatasetBase<Array2<F>, T>, DatasetBase<Kernel<F>, T::View>>
414414
for KernelParams<F, N>
415415
{
@@ -432,7 +432,7 @@ impl<'a, F: Float, L: 'a, T: AsTargets<Elem = L> + FromTargetArray<'a, L>, N: Ne
432432
/// not between 1 and #records-1
433433
fn transform(&self, x: &'a DatasetBase<Array2<F>, T>) -> DatasetBase<Kernel<F>, T::View> {
434434
let kernel = Kernel::new(x.records.view(), self);
435-
DatasetBase::new(kernel, T::new_targets_view(x.as_multi_targets()))
435+
DatasetBase::new(kernel, T::new_targets_view(x.as_targets()))
436436
}
437437
}
438438

@@ -443,7 +443,7 @@ impl<
443443
'b,
444444
F: Float,
445445
L: 'b,
446-
T: AsTargets<Elem = L> + FromTargetArray<'b, L>,
446+
T: AsTargets<Elem = L> + FromTargetArray<'b>,
447447
N: NearestNeighbour,
448448
> Transformer<&'b DatasetBase<ArrayView2<'a, F>, T>, DatasetBase<Kernel<F>, T::View>>
449449
for KernelParams<F, N>
@@ -471,7 +471,7 @@ impl<
471471
) -> DatasetBase<Kernel<F>, T::View> {
472472
let kernel = Kernel::new(x.records.view(), self);
473473

474-
DatasetBase::new(kernel, T::new_targets_view(x.as_multi_targets()))
474+
DatasetBase::new(kernel, T::new_targets_view(x.as_targets()))
475475
}
476476
}
477477

@@ -850,11 +850,7 @@ mod tests {
850850
check_kernel_from_dataset_view_type(&input.view(), KernelType::Sparse(3));
851851
}
852852

853-
fn check_kernel_from_dataset_type<
854-
'a,
855-
L: 'a,
856-
T: AsTargets<Elem = L> + FromTargetArray<'a, L>,
857-
>(
853+
fn check_kernel_from_dataset_type<'a, L: 'a, T: AsTargets<Elem = L> + FromTargetArray<'a>>(
858854
input: &'a DatasetBase<Array2<f64>, T>,
859855
k_type: KernelType,
860856
) {
@@ -889,7 +885,7 @@ mod tests {
889885
fn check_kernel_from_dataset_view_type<
890886
'a,
891887
L: 'a,
892-
T: AsTargets<Elem = L> + FromTargetArray<'a, L>,
888+
T: AsTargets<Elem = L> + FromTargetArray<'a>,
893889
>(
894890
input: &'a DatasetBase<ArrayView2<'a, f64>, T>,
895891
k_type: KernelType,

algorithms/linfa-linear/src/glm/mod.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::error::{LinearError, Result};
88
use crate::float::{ArgminParam, Float};
99
use distribution::TweedieDistribution;
1010
use hyperparams::TweedieRegressorValidParams;
11+
use linfa::dataset::AsSingleTargets;
1112
pub use link::Link;
1213

1314
use argmin::core::{ArgminOp, Executor};
@@ -18,15 +19,15 @@ use ndarray::{Array, Array1, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2}
1819
use serde::{Deserialize, Serialize};
1920

2021
use linfa::traits::*;
21-
use linfa::{dataset::AsTargets, DatasetBase};
22+
use linfa::DatasetBase;
2223

23-
impl<F: Float, D: Data<Elem = F>, T: AsTargets<Elem = F>> Fit<ArrayBase<D, Ix2>, T, LinearError<F>>
24-
for TweedieRegressorValidParams<F>
24+
impl<F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = F>>
25+
Fit<ArrayBase<D, Ix2>, T, LinearError<F>> for TweedieRegressorValidParams<F>
2526
{
2627
type Object = TweedieRegressor<F>;
2728

2829
fn fit(&self, ds: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object, F> {
29-
let (x, y) = (ds.records(), ds.try_single_target()?);
30+
let (x, y) = (ds.records(), ds.as_single_targets());
3031

3132
let dist = TweedieDistribution::new(self.power())?;
3233
let link = self.link();

algorithms/linfa-linear/src/ols.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use ndarray::{concatenate, s, Array, Array1, Array2, ArrayBase, Axis, Data, Ix1,
55
use ndarray_linalg::{Lapack, LeastSquaresSvdInto, Scalar};
66
use serde::{Deserialize, Serialize};
77

8-
use linfa::dataset::{AsTargets, DatasetBase};
8+
use linfa::dataset::{AsSingleTargets, DatasetBase};
99
use linfa::traits::{Fit, PredictInplace};
1010

1111
pub trait Float: linfa::Float + Lapack + Scalar {}
@@ -76,8 +76,8 @@ impl LinearRegression {
7676
}
7777
}
7878

79-
impl<F: Float, D: Data<Elem = F>, T: AsTargets<Elem = F>> Fit<ArrayBase<D, Ix2>, T, LinearError<F>>
80-
for LinearRegression
79+
impl<F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = F>>
80+
Fit<ArrayBase<D, Ix2>, T, LinearError<F>> for LinearRegression
8181
{
8282
type Object = FittedLinearRegression<F>;
8383

@@ -93,7 +93,7 @@ impl<F: Float, D: Data<Elem = F>, T: AsTargets<Elem = F>> Fit<ArrayBase<D, Ix2>,
9393
/// for new feature values.
9494
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object, F> {
9595
let X = dataset.records();
96-
let y = dataset.try_single_target()?;
96+
let y = dataset.as_single_targets();
9797

9898
let (n_samples, _) = X.dim();
9999

algorithms/linfa-logistic/examples/logistic_cv.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use linfa::prelude::*;
22
use linfa_logistic::error::Result;
33
use linfa_logistic::LogisticRegression;
4+
use ndarray::arr0;
45

56
fn main() -> Result<()> {
67
// Load dataset. Mutability is needed for fast cross validation
@@ -22,7 +23,7 @@ fn main() -> Result<()> {
2223
// use cross validation to compute the validation accuracy of each model. The
2324
// accuracy of each model will be averaged across the folds, 5 in this case
2425
let accuracies = dataset.cross_validate(5, &models, |prediction, truth| {
25-
Ok(prediction.confusion_matrix(truth)?.accuracy())
26+
Ok(arr0(prediction.confusion_matrix(truth)?.accuracy()))
2627
})?;
2728

2829
// display the accuracy of the models along with their regularization coefficient

0 commit comments

Comments
 (0)