Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 21357e2

Browse files
YuhanLiin and PABannier authored Nov 12, 2022
Adding Multi-Task ElasticNet support (rust-ml#238)
* added block coordinate descent function
* added duality_gap_mtl computation
* ENH cd pass to be consistent with bcd
* added prox operator for MTL Enet
* added helper functions for tests
* working ent mtl penalties
* bcd lower objective test pass
* added MultiTaskEnet struct
* added MTENET documentation
* added API MTENET
* added variance, z-score, conf interval for multitask ENET
* added multi-task estimators
* added tests for MTL
* added tests for Enet and MTL
* WIP: made variance params generic over the number of tasks
* added z_score and confidence_95th for MTL
* WIP make compute_variance generic over the dimension
* Replace for loops in block_coordinate_descent with general_mat_mul calls
* Bring back generic compute_intercept
* Replace manual norm calculations with norm trait calls
* Add docs and derives to multi task types
* Add example for multitask_elasticnet
* Rename shape() calls to nrows and ncols

Co-authored-by: Pierre-Antoine Bannier <[email protected]>
1 parent 44b244c commit 21357e2

File tree

7 files changed

+754
-67
lines changed

7 files changed

+754
-67
lines changed
 

‎algorithms/linfa-elasticnet/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,6 @@ thiserror = "1.0"
4040
linfa = { version = "0.6.0", path = "../.." }
4141

4242
[dev-dependencies]
43-
linfa-datasets = { version = "0.6.0", path = "../../datasets", features = ["diabetes"] }
43+
linfa-datasets = { version = "0.6.0", path = "../../datasets", features = ["diabetes", "linnerud"] }
4444
ndarray-rand = "0.14"
4545
rand_xoshiro = "0.6"

‎algorithms/linfa-elasticnet/examples/elasticnet.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ fn main() -> Result<()> {
55
// load Diabetes dataset
66
let (train, valid) = linfa_datasets::diabetes().split_with_ratio(0.90);
77

8-
// train pure LASSO model with 0.1 penalty
8+
// train pure LASSO model with 0.3 penalty
99
let model = ElasticNet::params()
1010
.penalty(0.3)
1111
.l1_ratio(1.0)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
use linfa::prelude::*;
2+
use linfa_elasticnet::{MultiTaskElasticNet, Result};
3+
4+
fn main() -> Result<()> {
5+
// load Diabetes dataset
6+
let (train, valid) = linfa_datasets::linnerud().split_with_ratio(0.80);
7+
8+
// train pure LASSO model with 0.1 penalty
9+
let model = MultiTaskElasticNet::params()
10+
.penalty(0.1)
11+
.l1_ratio(1.0)
12+
.fit(&train)?;
13+
14+
println!("intercept: {}", model.intercept());
15+
println!("params: {}", model.hyperplane());
16+
17+
println!("z score: {:?}", model.z_score());
18+
19+
// validate
20+
let y_est = model.predict(&valid);
21+
println!("predicted variance: {}", y_est.r2(&valid)?);
22+
23+
Ok(())
24+
}

‎algorithms/linfa-elasticnet/src/algorithm.rs

Lines changed: 607 additions & 41 deletions
Large diffs are not rendered by default.

‎algorithms/linfa-elasticnet/src/error.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ pub enum ElasticNetError {
2525
InvalidPenalty(f32),
2626
#[error("invalid tolerance {0}")]
2727
InvalidTolerance(f32),
28+
#[error("the target can either be a vector (ndim=1) or a matrix (ndim=2)")]
29+
IncorrectTargetShape,
2830
#[error(transparent)]
2931
BaseCrate(#[from] linfa::Error),
3032
}

‎algorithms/linfa-elasticnet/src/hyperparams.rs

Lines changed: 72 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,26 @@ use super::Result;
1212
derive(Serialize, Deserialize),
1313
serde(crate = "serde_crate")
1414
)]
15-
/// A verified hyper-parameter set ready for the estimation of a ElasticNet regression model
16-
///
17-
/// See [`ElasticNetParams`](crate::ElasticNetParams) for more informations.
18-
#[derive(Clone, Debug, PartialEq)]
19-
pub struct ElasticNetValidParams<F> {
15+
#[derive(Clone, Debug, PartialEq, Eq)]
16+
pub struct ElasticNetValidParamsBase<F, const MULTI_TASK: bool> {
2017
penalty: F,
2118
l1_ratio: F,
2219
with_intercept: bool,
2320
max_iterations: u32,
2421
tolerance: F,
2522
}
2623

27-
impl<F: Float> ElasticNetValidParams<F> {
24+
/// A verified hyper-parameter set ready for the estimation of a ElasticNet regression model
25+
///
26+
/// See [`ElasticNetParams`](crate::ElasticNetParams) for more information.
27+
pub type ElasticNetValidParams<F> = ElasticNetValidParamsBase<F, false>;
28+
29+
/// A verified hyper-parameter set ready for the estimation of a multi-task ElasticNet regression model
30+
///
31+
/// See [`MultiTaskElasticNetParams`](crate::MultiTaskElasticNetParams) for more information.
32+
pub type MultiTaskElasticNetValidParams<F> = ElasticNetValidParamsBase<F, true>;
33+
34+
impl<F: Float, const MULTI_TASK: bool> ElasticNetValidParamsBase<F, MULTI_TASK> {
2835
pub fn penalty(&self) -> F {
2936
self.penalty
3037
}
@@ -46,7 +53,12 @@ impl<F: Float> ElasticNetValidParams<F> {
4653
}
4754
}
4855

49-
/// A hyper-parameter set during construction
56+
#[derive(Clone, Debug, PartialEq, Eq)]
57+
pub struct ElasticNetParamsBase<F, const MULTI_TASK: bool>(
58+
ElasticNetValidParamsBase<F, MULTI_TASK>,
59+
);
60+
61+
/// A hyper-parameter set for Elastic-Net
5062
///
5163
/// Configures and minimizes the following objective function:
5264
/// ```ignore
@@ -57,18 +69,18 @@ impl<F: Float> ElasticNetValidParams<F> {
5769
///
5870
/// The parameter set can be verified into a
5971
/// [`ElasticNetValidParams`](crate::hyperparams::ElasticNetValidParams) by calling
60-
/// [ParamGuard::check](Self::check). It is also possible to directly fit a model with
72+
/// [ParamGuard::check](Self::check()). It is also possible to directly fit a model with
6173
/// [Fit::fit](linfa::traits::Fit::fit) which implicitely verifies the parameter set prior to the
6274
/// model estimation and forwards any error.
6375
///
6476
/// # Parameters
6577
/// | Name | Default | Purpose | Range |
6678
/// | :--- | :--- | :---| :--- |
67-
/// | [penalty](Self::penalty) | `1.0` | Overall parameter penalty | `[0, inf)` |
68-
/// | [l1_ratio](Self::l1_ratio) | `0.5` | Distribution of penalty to L1 and L2 regularizations | `[0.0, 1.0]` |
69-
/// | [with_intercept](Self::with_intercept) | `true` | Enable intercept | `false`, `true` |
70-
/// | [tolerance](Self::tolerance) | `1e-4` | Absolute change of any of the parameters | `(0, inf)` |
71-
/// | [max_iterations](Self::max_iterations) | `1000` | Maximum number of iterations | `[1, inf)` |
79+
/// | [penalty](Self::penalty()) | `1.0` | Overall parameter penalty | `[0, inf)` |
80+
/// | [l1_ratio](Self::l1_ratio()) | `0.5` | Distribution of penalty to L1 and L2 regularizations | `[0.0, 1.0]` |
81+
/// | [with_intercept](Self::with_intercept()) | `true` | Enable intercept | `false`, `true` |
82+
/// | [tolerance](Self::tolerance()) | `1e-4` | Absolute change of any of the parameters | `(0, inf)` |
83+
/// | [max_iterations](Self::max_iterations()) | `1000` | Maximum number of iterations | `[1, inf)` |
7284
///
7385
/// # Errors
7486
///
@@ -105,17 +117,55 @@ impl<F: Float> ElasticNetValidParams<F> {
105117
/// let model = checked_params.fit(&ds)?;
106118
/// # Ok::<(), ElasticNetError>(())
107119
/// ```
108-
#[derive(Clone, Debug, PartialEq)]
109-
pub struct ElasticNetParams<F>(ElasticNetValidParams<F>);
120+
pub type ElasticNetParams<F> = ElasticNetParamsBase<F, false>;
121+
122+
/// A hyper-parameter set for multi-task Elastic-Net
123+
///
124+
/// The multi-task version (Y becomes a measurement matrix) is also supported and
125+
/// solves the following objective function:
126+
/// ```ignore
127+
/// 1 / (2 * n_samples) * || Y - XW ||^2_F
128+
/// + penalty * l1_ratio * ||W||_2,1
129+
/// + 0.5 * penalty * (1 - l1_ratio) * ||W||^2_F
130+
/// ```
131+
///
132+
/// See [`ElasticNetParams`](crate::ElasticNetParams) for information on parameters and return
133+
/// values.
134+
///
135+
/// # Example
136+
///
137+
/// ```rust
138+
/// use linfa_elasticnet::{MultiTaskElasticNetParams, ElasticNetError};
139+
/// use linfa::prelude::*;
140+
/// use ndarray::array;
141+
///
142+
/// let ds = Dataset::new(array![[1.0, 0.0], [0.0, 1.0]], array![[3.0, 1.1], [2.0, 2.2]]);
143+
///
144+
/// // create a new parameter set with penalty equals `1e-5`
145+
/// let unchecked_params = MultiTaskElasticNetParams::new()
146+
/// .penalty(1e-5);
147+
///
148+
/// // fit model with unchecked parameter set
149+
/// let model = unchecked_params.fit(&ds)?;
150+
///
151+
/// // transform into a verified parameter set
152+
/// let checked_params = unchecked_params.check()?;
153+
///
154+
/// // Regenerate model with the verified parameters, this only returns
155+
/// // errors originating from the fitting process
156+
/// let model = checked_params.fit(&ds)?;
157+
/// # Ok::<(), ElasticNetError>(())
158+
/// ```
159+
pub type MultiTaskElasticNetParams<F> = ElasticNetParamsBase<F, true>;
110160

111-
impl<F: Float> Default for ElasticNetParams<F> {
161+
impl<F: Float, const MULTI_TASK: bool> Default for ElasticNetParamsBase<F, MULTI_TASK> {
112162
fn default() -> Self {
113163
Self::new()
114164
}
115165
}
116166

117167
/// Configure and fit a Elastic Net model
118-
impl<F: Float> ElasticNetParams<F> {
168+
impl<F: Float, const MULTI_TASK: bool> ElasticNetParamsBase<F, MULTI_TASK> {
119169
/// Create default elastic net hyper parameters
120170
///
121171
/// By default, an intercept will be fitted. To disable fitting an
@@ -124,8 +174,8 @@ impl<F: Float> ElasticNetParams<F> {
124174
/// To additionally normalize the feature matrix before fitting, call
125175
/// `fit_intercept_and_normalize()` before calling `fit()`. The feature
126176
/// matrix will not be normalized by default.
127-
pub fn new() -> ElasticNetParams<F> {
128-
Self(ElasticNetValidParams {
177+
pub fn new() -> ElasticNetParamsBase<F, MULTI_TASK> {
178+
Self(ElasticNetValidParamsBase {
129179
penalty: F::one(),
130180
l1_ratio: F::cast(0.5),
131181
with_intercept: true,
@@ -134,7 +184,7 @@ impl<F: Float> ElasticNetParams<F> {
134184
})
135185
}
136186

137-
/// Set the overall parameter penalty parameter of the elastic net.
187+
/// Set the overall parameter penalty parameter of the elastic net, otherwise known as `alpha`.
138188
/// Use `l1_ratio` to configure how the penalty distributed to L1 and L2
139189
/// regularization.
140190
pub fn penalty(mut self, penalty: F) -> Self {
@@ -180,8 +230,8 @@ impl<F: Float> ElasticNetParams<F> {
180230
}
181231
}
182232

183-
impl<F: Float> ParamGuard for ElasticNetParams<F> {
184-
type Checked = ElasticNetValidParams<F>;
233+
impl<F: Float, const MULTI_TASK: bool> ParamGuard for ElasticNetParamsBase<F, MULTI_TASK> {
234+
type Checked = ElasticNetValidParamsBase<F, MULTI_TASK>;
185235
type Error = ElasticNetError;
186236

187237
/// Validate the hyper parameters

‎algorithms/linfa-elasticnet/src/lib.rs

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#![doc = include_str!("../README.md")]
22

33
use linfa::Float;
4-
use ndarray::Array1;
4+
use ndarray::{Array1, Array2};
55

66
#[cfg(feature = "serde")]
77
use serde_crate::{Deserialize, Serialize};
@@ -11,7 +11,10 @@ mod error;
1111
mod hyperparams;
1212

1313
pub use error::{ElasticNetError, Result};
14-
pub use hyperparams::{ElasticNetParams, ElasticNetValidParams};
14+
pub use hyperparams::{
15+
ElasticNetParams, ElasticNetParamsBase, ElasticNetValidParams, ElasticNetValidParamsBase,
16+
MultiTaskElasticNetParams, MultiTaskElasticNetValidParams,
17+
};
1518

1619
#[cfg_attr(
1720
feature = "serde",
@@ -66,3 +69,45 @@ impl<F: Float> ElasticNet<F> {
6669
ElasticNetParams::new().l1_ratio(F::one())
6770
}
6871
}
72+
73+
#[cfg_attr(
74+
feature = "serde",
75+
derive(Serialize, Deserialize),
76+
serde(crate = "serde_crate")
77+
)]
78+
/// MultiTask Elastic Net model
79+
///
80+
/// This struct contains the parameters of a fitted multi-task elastic net model. This includes the
81+
/// coefficients (a 2-dimensional array), (optionally) intercept (a 1-dimensional array), duality gaps
82+
/// and the number of steps needed in the computation.
83+
///
84+
/// ## Model implementation
85+
///
86+
/// The block coordinate descent is widely used to solve generalized linear models optimization problems,
87+
/// like Group Lasso, MultiTask Ridge or MultiTask Lasso. It cycles through a group of parameters and update
88+
/// the groups separately, holding all the others fixed. The optimization routine stops when a criterion is
89+
/// satisfied (dual sub-optimality gap or change in coefficients).
90+
#[derive(Debug, Clone)]
91+
pub struct MultiTaskElasticNet<F> {
92+
hyperplane: Array2<F>,
93+
intercept: Array1<F>,
94+
duality_gap: F,
95+
n_steps: u32,
96+
variance: Result<Array1<F>>,
97+
}
98+
99+
impl<F: Float> MultiTaskElasticNet<F> {
100+
pub fn params() -> MultiTaskElasticNetParams<F> {
101+
MultiTaskElasticNetParams::new()
102+
}
103+
104+
/// Create a multi-task ridge only model
105+
pub fn ridge() -> MultiTaskElasticNetParams<F> {
106+
MultiTaskElasticNetParams::new().l1_ratio(F::zero())
107+
}
108+
109+
/// Create a multi-task Lasso only model
110+
pub fn lasso() -> MultiTaskElasticNetParams<F> {
111+
MultiTaskElasticNetParams::new().l1_ratio(F::one())
112+
}
113+
}

0 commit comments

Comments (0)
Please sign in to comment.