@@ -48,7 +48,6 @@ class MiniBatchAlgorithm(Enum):
48
48
r"""
49
49
MiniBatch algorithms that can be used with :meth:`.TransformInvariantNMF.fit_minibatch`.
50
50
"""
51
- Basic_MU = 3 # Algorithm 3 Basic alternating scheme for MU rules
52
51
Cyclic_MU = 4 # Algorithm 4 Cyclic mini-batch for MU rules
53
52
ASG_MU = 5 # Algorithm 5 Asymmetric SG mini-batch MU rules (ASG-MU)
54
53
GSG_MU = 6 # Algorithm 6 Greedy SG mini-batch MU rules (GSG-MU)
@@ -352,9 +351,9 @@ def fit_batch(
352
351
def fit_minibatches (
353
352
self ,
354
353
V : np .ndarray ,
355
- algorithm : MiniBatchAlgorithm = MiniBatchAlgorithm .Basic_MU ,
354
+ algorithm : MiniBatchAlgorithm = MiniBatchAlgorithm .ASG_MU ,
356
355
batch_size : int = 3 ,
357
- n_epochs : int = 1000 , # corresponds to max_iter if algorithm == MiniBatchAlgorithm.Basic_MU
356
+ n_epochs : int = 1000 ,
358
357
sag_lambda : float = 0.2 ,
359
358
keep_W : bool = False ,
360
359
sparsity_H : float = 0. ,
@@ -375,10 +374,9 @@ def fit_minibatches(
375
374
algorithm: MiniBatchAlgorithm
376
375
MiniBatch update scheme to be used. See :class:`MiniBatchAlgorithm` and [3]_ for the different choices.
377
376
batch_size: int, default = 3
378
- Number of samples per mini batch. Ignored if algorithm==MiniBatchAlgorithm.Basic_MU
377
+ Number of samples per mini batch.
379
378
n_epochs: int, default = 1000
380
- Maximum number of epochs (iterations if algorithm==MiniBatchAlgorithm.Basic_MU) across the full
381
- sample set to be performed.
379
+ Maximum number of epochs across the full sample set to be performed.
382
380
sag_lambda: float, default = 0.2
383
381
Exponential forgetting factor for for the stochastic _average_ gradient updates, i.e.
384
382
MiniBatchAlgorithm.ASAG_MU and MiniBatchAlgorithm.GSAG_MU
@@ -416,7 +414,6 @@ def fit_minibatches(
416
414
batches = list (_compute_sequential_minibatches (len (self ._V ), batch_size ))
417
415
418
416
epoch_update = {
419
- MiniBatchAlgorithm .Basic_MU : self ._epoch_update_algorithm_3 ,
420
417
MiniBatchAlgorithm .Cyclic_MU : self ._epoch_update_algorithm_4 ,
421
418
MiniBatchAlgorithm .ASG_MU : self ._epoch_update_algorithm_5 ,
422
419
MiniBatchAlgorithm .GSG_MU : self ._epoch_update_algorithm_6 ,
@@ -441,8 +438,7 @@ def fit_minibatches(
441
438
if not progress_callback (self , epoch ):
442
439
break
443
440
else :
444
- self ._logger .info (f"{ 'Iteration' if algorithm == MiniBatchAlgorithm .Basic_MU else 'Epoch' } : { epoch } \t "
445
- f"Energy function: { self ._energy_function ()} " )
441
+ self ._logger .info (f"{ 'Epoch' } : { epoch } \t Energy function: { self ._energy_function ()} " )
446
442
447
443
self ._logger .info ("MiniBatch TNMF finished." )
448
444
@@ -459,12 +455,6 @@ def _accumulate_gradient_W(self, gradW_neg, gradW_pos, sag_lambda: float, s: sli
459
455
460
456
return gradW_neg , gradW_pos
461
457
462
- def _epoch_update_algorithm_3 (self , _ , ___ , args_update_H , __ ):
463
- # update H for all samples
464
- self ._update_H (** args_update_H )
465
- # update W after processing all batches using all samples
466
- self ._update_W ()
467
-
468
458
def _epoch_update_algorithm_4 (self , _ , batches , args_update_H , __ ):
469
459
gradW_neg , gradW_pos = 0 , 0
470
460
for batch in batches :
0 commit comments