Skip to content

Commit dd014dc

Browse files
committed
Correct PCA component counting logic
BREAKING CHANGE: The change in component counting directly affects the number of clusters requested in the `main` auto-scaling function. For the same input data, this version may produce a different number of clusters and therefore different final scaling factors compared to previous versions. The new behavior is considered more accurate.
1 parent bc3e7d3 commit dd014dc

File tree

5 files changed

+900
-684
lines changed

5 files changed

+900
-684
lines changed

src/ert/analysis/misfit_preprocessor.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,19 @@ def get_nr_primary_components(
3232
responses: npt.NDArray[np.float64], threshold: float
3333
) -> int:
3434
"""
35-
Calculate the number of principal components needed to achieve a cumulative
36-
variance less than a specified threshold using Singular Value Decomposition (SVD).
35+
Calculate the number of principal components required
36+
to explain a given amount of variance in the responses.
3737
38-
responses should be on form (n_realizations, n_observations)
38+
Args:
39+
responses: A 2D array of data with shape
40+
(n_realizations, n_observations).
41+
threshold: The cumulative variance threshold to meet or exceed.
42+
For example, a value of 0.95 will find the number of
43+
components needed to explain at least 95% of the total variance.
44+
45+
Returns:
46+
The minimum number of principal components required to meet or exceed
47+
the specified variance threshold.
3948
"""
4049
data_matrix = responses - responses.mean(axis=0)
4150
_, singulars, _ = np.linalg.svd(data_matrix.astype(float), full_matrices=False)
@@ -45,7 +54,10 @@ def get_nr_primary_components(
4554
# sum to get the cumulative proportion of variance explained by each successive
4655
# component.
4756
variance_ratio = np.cumsum(singulars**2) / np.sum(singulars**2)
48-
return max(len([1 for i in variance_ratio[:-1] if i < threshold]), 1)
57+
58+
num_components = np.searchsorted(variance_ratio, threshold, side="left") + 1
59+
60+
return int(num_components)
4961

5062

5163
def cluster_responses(

0 commit comments

Comments
 (0)