@@ -12,19 +12,26 @@ use super::Result;
12
12
derive( Serialize , Deserialize ) ,
13
13
serde( crate = "serde_crate" )
14
14
) ]
15
- /// A verified hyper-parameter set ready for the estimation of a ElasticNet regression model
16
- ///
17
- /// See [`ElasticNetParams`](crate::ElasticNetParams) for more informations.
18
- #[ derive( Clone , Debug , PartialEq ) ]
19
- pub struct ElasticNetValidParams < F > {
15
+ #[ derive( Clone , Debug , PartialEq , Eq ) ]
16
+ pub struct ElasticNetValidParamsBase < F , const MULTI_TASK : bool > {
20
17
penalty : F ,
21
18
l1_ratio : F ,
22
19
with_intercept : bool ,
23
20
max_iterations : u32 ,
24
21
tolerance : F ,
25
22
}
26
23
27
- impl < F : Float > ElasticNetValidParams < F > {
24
+ /// A verified hyper-parameter set ready for the estimation of a ElasticNet regression model
25
+ ///
26
+ /// See [`ElasticNetParams`](crate::ElasticNetParams) for more information.
27
+ pub type ElasticNetValidParams < F > = ElasticNetValidParamsBase < F , false > ;
28
+
29
+ /// A verified hyper-parameter set ready for the estimation of a multi-task ElasticNet regression model
30
+ ///
31
+ /// See [`MultiTaskElasticNetParams`](crate::MultiTaskElasticNetParams) for more information.
32
+ pub type MultiTaskElasticNetValidParams < F > = ElasticNetValidParamsBase < F , true > ;
33
+
34
+ impl < F : Float , const MULTI_TASK : bool > ElasticNetValidParamsBase < F , MULTI_TASK > {
28
35
pub fn penalty ( & self ) -> F {
29
36
self . penalty
30
37
}
@@ -46,7 +53,12 @@ impl<F: Float> ElasticNetValidParams<F> {
46
53
}
47
54
}
48
55
49
- /// A hyper-parameter set during construction
56
+ #[ derive( Clone , Debug , PartialEq , Eq ) ]
57
+ pub struct ElasticNetParamsBase < F , const MULTI_TASK : bool > (
58
+ ElasticNetValidParamsBase < F , MULTI_TASK > ,
59
+ ) ;
60
+
61
+ /// A hyper-parameter set for Elastic-Net
50
62
///
51
63
/// Configures and minimizes the following objective function:
52
64
/// ```ignore
@@ -57,18 +69,18 @@ impl<F: Float> ElasticNetValidParams<F> {
57
69
///
58
70
/// The parameter set can be verified into a
59
71
/// [`ElasticNetValidParams`](crate::hyperparams::ElasticNetValidParams) by calling
60
- /// [ParamGuard::check](Self::check). It is also possible to directly fit a model with
72
+ /// [ParamGuard::check](Self::check() ). It is also possible to directly fit a model with
61
73
/// [Fit::fit](linfa::traits::Fit::fit) which implicitely verifies the parameter set prior to the
62
74
/// model estimation and forwards any error.
63
75
///
64
76
/// # Parameters
65
77
/// | Name | Default | Purpose | Range |
66
78
/// | :--- | :--- | :---| :--- |
67
- /// | [penalty](Self::penalty) | `1.0` | Overall parameter penalty | `[0, inf)` |
68
- /// | [l1_ratio](Self::l1_ratio) | `0.5` | Distribution of penalty to L1 and L2 regularizations | `[0.0, 1.0]` |
69
- /// | [with_intercept](Self::with_intercept) | `true` | Enable intercept | `false`, `true` |
70
- /// | [tolerance](Self::tolerance) | `1e-4` | Absolute change of any of the parameters | `(0, inf)` |
71
- /// | [max_iterations](Self::max_iterations) | `1000` | Maximum number of iterations | `[1, inf)` |
79
+ /// | [penalty](Self::penalty() ) | `1.0` | Overall parameter penalty | `[0, inf)` |
80
+ /// | [l1_ratio](Self::l1_ratio() ) | `0.5` | Distribution of penalty to L1 and L2 regularizations | `[0.0, 1.0]` |
81
+ /// | [with_intercept](Self::with_intercept() ) | `true` | Enable intercept | `false`, `true` |
82
+ /// | [tolerance](Self::tolerance() ) | `1e-4` | Absolute change of any of the parameters | `(0, inf)` |
83
+ /// | [max_iterations](Self::max_iterations() ) | `1000` | Maximum number of iterations | `[1, inf)` |
72
84
///
73
85
/// # Errors
74
86
///
@@ -105,17 +117,55 @@ impl<F: Float> ElasticNetValidParams<F> {
105
117
/// let model = checked_params.fit(&ds)?;
106
118
/// # Ok::<(), ElasticNetError>(())
107
119
/// ```
108
- #[ derive( Clone , Debug , PartialEq ) ]
109
- pub struct ElasticNetParams < F > ( ElasticNetValidParams < F > ) ;
120
+ pub type ElasticNetParams < F > = ElasticNetParamsBase < F , false > ;
121
+
122
+ /// A hyper-parameter set for multi-task Elastic-Net
123
+ ///
124
+ /// The multi-task version (Y becomes a measurement matrix) is also supported and
125
+ /// solves the following objective function:
126
+ /// ```ignore
127
+ /// 1 / (2 * n_samples) * || Y - XW ||^2_F
128
+ /// + penalty * l1_ratio * ||W||_2,1
129
+ /// + 0.5 * penalty * (1 - l1_ratio) * ||W||^2_F
130
+ /// ```
131
+ ///
132
+ /// See [`ElasticNetParams`](crate::ElasticNetParams) for information on parameters and return
133
+ /// values.
134
+ ///
135
+ /// # Example
136
+ ///
137
+ /// ```rust
138
+ /// use linfa_elasticnet::{MultiTaskElasticNetParams, ElasticNetError};
139
+ /// use linfa::prelude::*;
140
+ /// use ndarray::array;
141
+ ///
142
+ /// let ds = Dataset::new(array![[1.0, 0.0], [0.0, 1.0]], array![[3.0, 1.1], [2.0, 2.2]]);
143
+ ///
144
+ /// // create a new parameter set with penalty equals `1e-5`
145
+ /// let unchecked_params = MultiTaskElasticNetParams::new()
146
+ /// .penalty(1e-5);
147
+ ///
148
+ /// // fit model with unchecked parameter set
149
+ /// let model = unchecked_params.fit(&ds)?;
150
+ ///
151
+ /// // transform into a verified parameter set
152
+ /// let checked_params = unchecked_params.check()?;
153
+ ///
154
+ /// // Regenerate model with the verified parameters, this only returns
155
+ /// // errors originating from the fitting process
156
+ /// let model = checked_params.fit(&ds)?;
157
+ /// # Ok::<(), ElasticNetError>(())
158
+ /// ```
159
+ pub type MultiTaskElasticNetParams < F > = ElasticNetParamsBase < F , true > ;
110
160
111
- impl < F : Float > Default for ElasticNetParams < F > {
161
+ impl < F : Float , const MULTI_TASK : bool > Default for ElasticNetParamsBase < F , MULTI_TASK > {
112
162
fn default ( ) -> Self {
113
163
Self :: new ( )
114
164
}
115
165
}
116
166
117
167
/// Configure and fit a Elastic Net model
118
- impl < F : Float > ElasticNetParams < F > {
168
+ impl < F : Float , const MULTI_TASK : bool > ElasticNetParamsBase < F , MULTI_TASK > {
119
169
/// Create default elastic net hyper parameters
120
170
///
121
171
/// By default, an intercept will be fitted. To disable fitting an
@@ -124,8 +174,8 @@ impl<F: Float> ElasticNetParams<F> {
124
174
/// To additionally normalize the feature matrix before fitting, call
125
175
/// `fit_intercept_and_normalize()` before calling `fit()`. The feature
126
176
/// matrix will not be normalized by default.
127
- pub fn new ( ) -> ElasticNetParams < F > {
128
- Self ( ElasticNetValidParams {
177
+ pub fn new ( ) -> ElasticNetParamsBase < F , MULTI_TASK > {
178
+ Self ( ElasticNetValidParamsBase {
129
179
penalty : F :: one ( ) ,
130
180
l1_ratio : F :: cast ( 0.5 ) ,
131
181
with_intercept : true ,
@@ -134,7 +184,7 @@ impl<F: Float> ElasticNetParams<F> {
134
184
} )
135
185
}
136
186
137
- /// Set the overall parameter penalty parameter of the elastic net.
187
+ /// Set the overall parameter penalty parameter of the elastic net, otherwise known as `alpha` .
138
188
/// Use `l1_ratio` to configure how the penalty distributed to L1 and L2
139
189
/// regularization.
140
190
pub fn penalty ( mut self , penalty : F ) -> Self {
@@ -180,8 +230,8 @@ impl<F: Float> ElasticNetParams<F> {
180
230
}
181
231
}
182
232
183
- impl < F : Float > ParamGuard for ElasticNetParams < F > {
184
- type Checked = ElasticNetValidParams < F > ;
233
+ impl < F : Float , const MULTI_TASK : bool > ParamGuard for ElasticNetParamsBase < F , MULTI_TASK > {
234
+ type Checked = ElasticNetValidParamsBase < F , MULTI_TASK > ;
185
235
type Error = ElasticNetError ;
186
236
187
237
/// Validate the hyper parameters
0 commit comments