Skip to content

Commit 310c079

Browse files
committed
remove redundant "minibatch" algorithm
1 parent 00aa651 commit 310c079

File tree

4 files changed

+21
-26
lines changed

4 files changed

+21
-26
lines changed

examples/minibatch_algorithms.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ def progress_callback(nmf_instance: TransformInvariantNMF, iteration: int) -> bo
9898
for params in (
9999
dict(),
100100
#
101-
dict(algorithm=MiniBatchAlgorithm.Basic_MU, batch_size=10),
102101
dict(algorithm=MiniBatchAlgorithm.Cyclic_MU, batch_size=10),
103102
dict(algorithm=MiniBatchAlgorithm.ASG_MU, batch_size=10),
104103
dict(algorithm=MiniBatchAlgorithm.GSG_MU, batch_size=10),

tnmf/TransformInvariantNMF.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ class MiniBatchAlgorithm(Enum):
4848
r"""
4949
MiniBatch algorithms that can be used with :meth:`.TransformInvariantNMF.fit_minibatch`.
5050
"""
51-
Basic_MU = 3 # Algorithm 3 Basic alternating scheme for MU rules
5251
Cyclic_MU = 4 # Algorithm 4 Cyclic mini-batch for MU rules
5352
ASG_MU = 5 # Algorithm 5 Asymmetric SG mini-batch MU rules (ASG-MU)
5453
GSG_MU = 6 # Algorithm 6 Greedy SG mini-batch MU rules (GSG-MU)
@@ -352,9 +351,9 @@ def fit_batch(
352351
def fit_minibatches(
353352
self,
354353
V: np.ndarray,
355-
algorithm: MiniBatchAlgorithm = MiniBatchAlgorithm.Basic_MU,
354+
algorithm: MiniBatchAlgorithm = MiniBatchAlgorithm.ASG_MU,
356355
batch_size: int = 3,
357-
n_epochs: int = 1000, # corresponds to max_iter if algorithm == MiniBatchAlgorithm.Basic_MU
356+
n_epochs: int = 1000,
358357
sag_lambda: float = 0.2,
359358
keep_W: bool = False,
360359
sparsity_H: float = 0.,
@@ -375,10 +374,9 @@ def fit_minibatches(
375374
algorithm: MiniBatchAlgorithm
376375
MiniBatch update scheme to be used. See :class:`MiniBatchAlgorithm` and [3]_ for the different choices.
377376
batch_size: int, default = 3
378-
Number of samples per mini batch. Ignored if algorithm==MiniBatchAlgorithm.Basic_MU
377+
Number of samples per mini batch.
379378
n_epochs: int, default = 1000
380-
Maximum number of epochs (iterations if algorithm==MiniBatchAlgorithm.Basic_MU) across the full
381-
sample set to be performed.
379+
Maximum number of epochs across the full sample set to be performed.
382380
sag_lambda: float, default = 0.2
383381
Exponential forgetting factor for for the stochastic _average_ gradient updates, i.e.
384382
MiniBatchAlgorithm.ASAG_MU and MiniBatchAlgorithm.GSAG_MU
@@ -416,7 +414,6 @@ def fit_minibatches(
416414
batches = list(_compute_sequential_minibatches(len(self._V), batch_size))
417415

418416
epoch_update = {
419-
MiniBatchAlgorithm.Basic_MU: self._epoch_update_algorithm_3,
420417
MiniBatchAlgorithm.Cyclic_MU: self._epoch_update_algorithm_4,
421418
MiniBatchAlgorithm.ASG_MU: self._epoch_update_algorithm_5,
422419
MiniBatchAlgorithm.GSG_MU: self._epoch_update_algorithm_6,
@@ -441,8 +438,7 @@ def fit_minibatches(
441438
if not progress_callback(self, epoch):
442439
break
443440
else:
444-
self._logger.info(f"{'Iteration' if algorithm == MiniBatchAlgorithm.Basic_MU else 'Epoch'}: {epoch}\t"
445-
f"Energy function: {self._energy_function()}")
441+
self._logger.info(f"{'Epoch'}: {epoch}\tEnergy function: {self._energy_function()}")
446442

447443
self._logger.info("MiniBatch TNMF finished.")
448444

@@ -459,12 +455,6 @@ def _accumulate_gradient_W(self, gradW_neg, gradW_pos, sag_lambda: float, s: sli
459455

460456
return gradW_neg, gradW_pos
461457

462-
def _epoch_update_algorithm_3(self, _, ___, args_update_H, __):
463-
# update H for all samples
464-
self._update_H(**args_update_H)
465-
# update W after processing all batches using all samples
466-
self._update_W()
467-
468458
def _epoch_update_algorithm_4(self, _, batches, args_update_H, __):
469459
gradW_neg, gradW_pos = 0, 0
470460
for batch in batches:

tnmf/tests/test_minibatch.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
# hard-coded expected energy levels for the different algorithms
1818
expected_energies = {
19-
MiniBatchAlgorithm.Basic_MU: 14434.02658,
19+
'full_batch': 14434.02658,
2020
MiniBatchAlgorithm.Cyclic_MU: 14434.02658,
2121
MiniBatchAlgorithm.ASG_MU: 4558.86695,
2222
MiniBatchAlgorithm.GSG_MU: 14223.14454,
@@ -57,13 +57,20 @@ def fit_nmf(backend, algorithm):
5757
verbose=3,
5858
reconstruction_mode='valid',
5959
)
60-
nmf.fit_minibatches(
61-
V,
62-
sparsity_H=0.1,
63-
algorithm=algorithm,
64-
batch_size=3,
65-
n_epochs=5,
66-
sag_lambda=0.8)
60+
if isinstance(algorithm, MiniBatchAlgorithm):
61+
nmf.fit_minibatches(
62+
V,
63+
sparsity_H=0.1,
64+
algorithm=algorithm,
65+
batch_size=3,
66+
n_epochs=5,
67+
sag_lambda=0.8)
68+
else:
69+
nmf.fit_batch(
70+
V,
71+
sparsity_H=0.1,
72+
n_iterations=5,
73+
)
6774

6875
return nmf
6976

tnmf/tests/test_stream.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
# hard-coded expected energy levels for the different algorithms
2020
expected_energies = {
2121
# no need to test all algorithms here
22-
MiniBatchAlgorithm.Basic_MU: 136.84096550,
2322
# MiniBatchAlgorithm.Cyclic_MU: 136.8409655,
2423
# MiniBatchAlgorithm.ASG_MU: 97.0072791,
2524
# MiniBatchAlgorithm.GSG_MU: 136.43285833,
@@ -98,7 +97,7 @@ def test_with_generator_limited():
9897
nmf.fit(
9998
V,
10099
sparsity_H=0.1,
101-
algorithm=MiniBatchAlgorithm.Basic_MU,
100+
algorithm=MiniBatchAlgorithm.Cyclic_MU,
102101
subsample_size=50,
103102
max_subsamples=5,
104103
batch_size=3,

0 commit comments

Comments
 (0)