Skip to content

Commit c436df0

Browse files
Added convergence flag to fit() (#189)
* Added convergence flag to fit() * [pre-commit.ci] auto fixes from pre-commit hooks * Added news * Update snmf_class.py Was giving me errors due to use of python 3.11 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d56e1b0 commit c436df0

File tree

2 files changed

+35
-6
lines changed

2 files changed

+35
-6
lines changed

news/add-convergence-flag.rst

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
**Added:**
2+
3+
* SNMFOptimizer.converged_ attribute to indicate whether the optimization
4+
successfully reached the convergence tolerance (True) or stopped because the
5+
maximum number of iterations was reached (False).
6+
7+
**Changed:**
8+
9+
* <news item>
10+
11+
**Deprecated:**
12+
13+
* <news item>
14+
15+
**Removed:**
16+
17+
* <news item>
18+
19+
**Fixed:**
20+
21+
* <news item>
22+
23+
**Security:**
24+
25+
* <news item>

src/diffpy/stretched_nmf/snmf_class.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ def fit(self, rho=0, eta=0, reset=True):
210210
the output of the previous
211211
fit() as their input.
212212
"""
213+
self.converged_ = False
213214

214215
if reset:
215216
self.components_ = self.init_components.copy()
@@ -251,11 +252,12 @@ def fit(self, rho=0, eta=0, reset=True):
251252
sparsity_term = self.eta * np.sum(
252253
np.sqrt(self.components_)
253254
) # Square root penalty
255+
obj_diff = (
256+
self.objective_function - regularization_term - sparsity_term
257+
)
254258
print(
255259
f"Start, Objective function: {self.objective_function:.5e}"
256-
f", Obj - reg/sparse: {self.objective_function
257-
- regularization_term
258-
- sparsity_term:.5e}"
260+
f", Obj - reg/sparse: {obj_diff:.5e}"
259261
)
260262

261263
# Main optimization loop
@@ -274,11 +276,12 @@ def fit(self, rho=0, eta=0, reset=True):
274276
sparsity_term = self.eta * np.sum(
275277
np.sqrt(self.components_)
276278
) # Square root penalty
279+
obj_diff = (
280+
self.objective_function - regularization_term - sparsity_term
281+
)
277282
print(
278283
f"Obj fun: {self.objective_function:.5e}, "
279-
f"Obj - reg/sparse: {self.objective_function
280-
- regularization_term
281-
- sparsity_term:.5e}, "
284+
f", Obj - reg/sparse: {obj_diff:.5e}"
282285
f"Iter: {self.outiter}"
283286
)
284287

@@ -294,6 +297,7 @@ def fit(self, rho=0, eta=0, reset=True):
294297
self.objective_difference < self.objective_function * self.tol
295298
and outiter >= self.min_iter
296299
):
300+
self.converged_ = True
297301
break
298302

299303
self.normalize_results()

0 commit comments

Comments
 (0)