Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/glum/_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,9 @@ def _convert_from_pandas(

return X

# expose the method as public so can create predictions outside of class
convert_from_pandas = _convert_from_pandas

def _set_up_for_fit(self, y: np.ndarray) -> None:
#######################################################################
# 1. input validation #
Expand Down
10 changes: 7 additions & 3 deletions src/glum/_glm_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def fit(
_stype = ["csc"]
else:
_stype = ["csc", "csr"]

def _fit_path(
self,
train_idx,
Expand Down Expand Up @@ -571,6 +571,8 @@ def _fit_path(
y[test_idx],
sample_weight[test_idx],
)
# test weights need to sum to 1 too, else deviance is not properly scaled
w_test /= w_test.sum()

if offset is not None:
offset_train = offset[train_idx]
Expand Down Expand Up @@ -667,8 +669,8 @@ def _get_deviance(coef):
)
deviance_path_ = [_get_deviance(_coef) for _coef in coef_path_]

return intercept_path_, coef_path_, deviance_path_

return intercept_path_, coef_path_, deviance_path_, train_idx
jobs = (
joblib.delayed(_fit_path)(
self,
Expand Down Expand Up @@ -706,6 +708,8 @@ def _get_deviance(coef):
(cv.get_n_splits(), len(l1_ratio), len(alphas[0])),
)

self.train_indices_ = [elmt[3] for elmt in paths_data]

avg_deviance = self.deviance_path_.mean(axis=0) # type: ignore

best_l1, best_alpha = np.unravel_index(
Expand Down
Loading