-
Notifications
You must be signed in to change notification settings - Fork 101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use seed! to put every copy of rng into a unique state #198
Conversation
Using `rand(_rng, i)` didn't really put all copies of `rng` into a unique state, the states were still interlocked (all the generators produced same sequence of random numbers with some offset). Calling ` seed!` with a deterministic, pseudo-random seed for each thread produces much better results, which is also visible in the classification and regression accuracies produced by the tests.
Thanks @dhanak for this valuable contribution. @rikhuijzer I really think you are in the best position to review this PR, if you don't mind? |
I think it's safe to assume that Below are the accuracy comparisons of what is currently the _rng = Random.seed!(copy(rng), i) diff --git a/tree-old.txt b/tree-new.txt
index 5086247..f7fe469 100644
--- a/tree-old.txt
+++ b/tree-new.txt
@@ -48,64 +48,64 @@ Mean Accuracy: 0.8688688688688688
##### nfoldCV Classification Forest #####
Testing nfoldCV_forest
-Mean Accuracy: 0.8748748748748749
+Mean Accuracy: 0.8998998998998999
-Mean Accuracy: 0.8448448448448449
+Mean Accuracy: 0.908908908908909
-Mean Accuracy: 0.8748748748748749
+Mean Accuracy: 0.8998998998998999
-Mean Accuracy: 0.8748748748748749
+Mean Accuracy: 0.8998998998998999
-Mean Accuracy: 0.8608608608608609
+Mean Accuracy: 0.913913913913914
-Mean Accuracy: 0.8608608608608609
+Mean Accuracy: 0.913913913913914
-Mean Accuracy: 0.8358358358358359
+Mean Accuracy: 0.9059059059059059
-Mean Accuracy: 0.8358358358358359
+Mean Accuracy: 0.9059059059059059
-Mean Accuracy: 0.8818818818818818
+Mean Accuracy: 0.9089089089089089
-Mean Accuracy: 0.8818818818818818
+Mean Accuracy: 0.9089089089089089
-Mean Accuracy: 0.8788788788788789
+Mean Accuracy: 0.9019019019019018
-Mean Accuracy: 0.8788788788788789
+Mean Accuracy: 0.9019019019019018
Fold 1
Classes: [-2, -1, 0, 1]
Matrix: 4×4 Matrix{Int64}:
- 6 3 0 0
- 0 121 29 0
- 0 26 136 7
- 0 0 4 1
+ 4 5 0 0
+ 1 124 25 0
+ 0 4 165 0
+ 0 0 5 0
-Accuracy: 0.7927927927927928
-Kappa: 0.6153446948136739
+Accuracy: 0.8798798798798799
+Kappa: 0.7701030394035107
Fold 2
Classes: [-2, -1, 0, 1]
Matrix: 4×4 Matrix{Int64}:
- 9 1 0 0
- 5 123 15 0
- 0 16 156 2
- 0 0 6 0
+ 8 2 0 0
+ 0 128 15 0
+ 0 10 164 0
+ 0 0 4 2
-Accuracy: 0.8648648648648649
-Kappa: 0.7499123817153157
+Accuracy: 0.9069069069069069
+Kappa: 0.8248409264443879
Fold 3
Classes: [-2, -1, 0, 1]
Matrix: 4×4 Matrix{Int64}:
- 7 3 0 0
- 3 121 33 0
- 0 11 144 4
+ 2 8 0 0
+ 0 141 16 0
+ 0 4 155 0
0 0 6 1
-Accuracy: 0.8198198198198198
-Kappa: 0.6695445072938374
+Accuracy: 0.8978978978978979
+Kappa: 0.807114382091383
-Mean Accuracy: 0.8258258258258259
+Mean Accuracy: 0.8948948948948949
##### nfoldCV Adaboosted Stumps #####
Testing nfoldCV_stumps
@@ -179,37 +179,37 @@ Mean Accuracy: 0.9629629629629629
Fold 1
Classes: Int32[-2, -1, 0, 1]
Matrix: 4×4 Matrix{Int64}:
- 14 9 0 0
- 2 130 7 0
- 0 10 140 2
- 0 1 2 16
+ 17 6 0 0
+ 0 135 4 0
+ 0 8 144 0
+ 0 0 4 15
-Accuracy: 0.9009009009009009
-Kappa: 0.83520043190714
+Accuracy: 0.933933933933934
+Kappa: 0.8896653513660049
Fold 2
Classes: Int32[-2, -1, 0, 1]
Matrix: 4×4 Matrix{Int64}:
- 10 13 0 0
- 1 139 16 0
- 0 13 120 1
- 0 0 10 10
+ 13 10 0 0
+ 0 143 13 0
+ 0 7 125 2
+ 0 0 1 19
-Accuracy: 0.8378378378378378
-Kappa: 0.7238297088094361
+Accuracy: 0.9009009009009009
+Kappa: 0.8349603508350355
Fold 3
Classes: Int32[-2, -1, 0, 1]
Matrix: 4×4 Matrix{Int64}:
16 1 0 0
- 1 126 10 0
+ 0 127 10 0
0 1 150 0
- 0 0 7 21
+ 0 0 10 18
-Accuracy: 0.93993993993994
-Kappa: 0.9009797945256397
+Accuracy: 0.933933933933934
+Kappa: 0.8902800658978584
-Mean Accuracy: 0.8928928928928929
+Mean Accuracy: 0.9229229229229229
##### nfoldCV Adaboosted Stumps #####
@@ -265,13 +265,13 @@ Feature 3 < 2.45 ?
└─ Iris-virginica : 1/1
└─ Feature 4 < 1.55 ?
├─ Iris-virginica : 3/3
- └─ Feature 1 < 6.95 ?
+ └─ Feature 3 < 5.45 ?
├─ Iris-versicolor : 2/2
└─ Iris-virginica : 1/1
└─ Feature 3 < 4.85 ?
- ├─ Feature 2 < 3.1 ?
- ├─ Iris-virginica : 2/2
- └─ Iris-versicolor : 1/1
+ ├─ Feature 1 < 5.95 ?
+ ├─ Iris-versicolor : 1/1
+ └─ Iris-virginica : 2/2
└─ Iris-virginica : 43/43
##### nfoldCV Classification Tree #####
@@ -314,33 +314,33 @@ Fold 1
Classes: ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
Matrix: 3×3 Matrix{Int64}:
20 0 0
- 0 20 1
+ 0 18 3
0 1 8
-Accuracy: 0.96
-Kappa: 0.9366286438529784
+Accuracy: 0.92
+Kappa: 0.8751560549313357
Fold 2
Classes: ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
Matrix: 3×3 Matrix{Int64}:
15 0 0
0 15 1
- 0 3 16
+ 0 2 17
-Accuracy: 0.92
-Kappa: 0.8798076923076925
+Accuracy: 0.94
+Kappa: 0.9096929560505719
Fold 3
Classes: ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
Matrix: 3×3 Matrix{Int64}:
15 0 0
- 0 12 1
- 0 7 15
+ 0 13 0
+ 0 3 19
-Accuracy: 0.84
-Kappa: 0.7613365155131264
+Accuracy: 0.94
+Kappa: 0.9090357792601576
-Mean Accuracy: 0.9066666666666666
+Mean Accuracy: 0.9333333333333332
##### nfoldCV Classification Adaboosted Stumps #####
@@ -385,7 +385,7 @@ Mean Accuracy: 0.8109892809975735
##### 3 foldCV Classification Forest #####
-Mean Accuracy: 0.8270217144261188
+Mean Accuracy: 0.8429005804846587
##### nfoldCV Classification Adaboosted Stumps #####
@@ -422,21 +422,21 @@ Mean Coeff of Determination: 0.821479058935842
##### nfoldCV Regression Forest #####
Fold 1
-Mean Squared Error: 2.0183096134238294
-Correlation Coeff: 0.8903914722230327
-Coeff of Determination: 0.7924911697044006
+Mean Squared Error: 1.3577742526795888
+Correlation Coeff: 0.9396271935146402
+Coeff of Determination: 0.8604029108789377
Fold 2
-Mean Squared Error: 1.9714838724549328
-Correlation Coeff: 0.910241766877058
-Coeff of Determination: 0.8011434924520122
+Mean Squared Error: 1.3034832328733625
+Correlation Coeff: 0.9529278684745566
+Coeff of Determination: 0.8685223212027657
Fold 3
-Mean Squared Error: 1.6739772387561769
-Correlation Coeff: 0.9029059136519314
-Coeff of Determination: 0.813068012307753
+Mean Squared Error: 1.1485186853278506
+Correlation Coeff: 0.9420191589030741
+Coeff of Determination: 0.8717456392002396
-Mean Coeff of Determination: 0.8022342248213886
+Mean Coeff of Determination: 0.8668902904273144
==================================================
TEST: regression/digits.jl
@@ -447,7 +447,7 @@ Mean Coeff of Determination: 0.6349826429860214
##### 3 foldCV Regression Forest #####
-Mean Coeff of Determination: 0.5825527898815513
+Mean Coeff of Determination: 0.6477805012747754
==================================================
TEST: regression/scikitlearn.jl
@@ -496,5 +496,5 @@ TEST: miscellaneous/feature_importance_test.jl
==================================================
Test Summary: | Pass Total Time
-Test Suites | 9658 9658 53.0s
+Test Suites | 9612 9612 53.6s
Testing DecisionTree tests passed What do you think @dhanak? |
I agree on the assumption, that is why using In your version, every tree with a specific index draws the same sequence of numbers for each invocation, given a specific class of |
Now I get it. Thanks, David 😄 @ablaom Can you merge this and create a release? I don't yet understand how to create releases in the MLJ-style, unfortunately. |
Sure, I'll take care of it. FYI: new release instructions. |
Thanks @dhanak for this valuable contribution. Thank you @rikhuijzer for your generous engagement and review. 🙏🏾 |
Fixes #194.
Using
rand(_rng, i)
didn't really put all copies ofrng
into a unique state, the states were still interlocked (all the generators produced same sequence of random numbers with some offset). Callingseed!
with a deterministic, pseudo-random seed for each thread produces much better results, which is also visible in the classification and regression accuracies produced by the tests.Here is the diff output of the unit tests before and after the change: