@@ -233,7 +233,7 @@ def optimize_loop(self):
233
233
self .num_updates += 1
234
234
self .residuals = self .get_residual_matrix ()
235
235
self .objective_function = self .get_objective_function ()
236
- print (f"Objective function after updateX : { self .objective_function :.5e} " )
236
+ print (f"Objective function after update_comps : { self .objective_function :.5e} " )
237
237
self ._objective_history .append (self .objective_function )
238
238
if self .objective_difference is None :
239
239
self .objective_difference = self ._objective_history [- 1 ] - self .objective_function
@@ -243,7 +243,7 @@ def optimize_loop(self):
243
243
self .num_updates += 1
244
244
self .residuals = self .get_residual_matrix ()
245
245
self .objective_function = self .get_objective_function ()
246
- print (f"Objective function after updateY2 : { self .objective_function :.5e} " )
246
+ print (f"Objective function after update_weights : { self .objective_function :.5e} " )
247
247
self ._objective_history .append (self .objective_function )
248
248
249
249
# Now we update stretch
@@ -266,14 +266,16 @@ def apply_interpolation(self, a, x, return_derivatives=False):
266
266
a = np .atleast_1d (np .asarray (a )) # Ensures a is at least 1D
267
267
268
268
# Compute fractional indices, broadcasting over `a`
269
- ii = np .arange (x_len )[:, None ] / a # Shape (N, M)
269
+ fractional_indices = np .arange (x_len )[:, None ] / a # Shape (N, M)
270
270
271
- II = np .floor (ii ).astype (int ) # Integer part (still (N, M))
272
- valid_mask = II < (x_len - 1 ) # Ensure indices are within bounds
271
+ integer_indices = np .floor (fractional_indices ).astype (int ) # Integer part (still (N, M))
272
+ valid_mask = integer_indices < (x_len - 1 ) # Ensure indices are within bounds
273
273
274
274
# Apply valid_mask to keep correct indices
275
- idx_int = np .where (valid_mask , II , x_len - 2 ) # Prevent out-of-bounds indexing (previously "I")
276
- idx_frac = np .where (valid_mask , ii , II ) # Keep aligned (previously "i")
275
+ idx_int = np .where (
276
+ valid_mask , integer_indices , x_len - 2
277
+ ) # Prevent out-of-bounds indexing (previously "I")
278
+ idx_frac = np .where (valid_mask , fractional_indices , integer_indices ) # Keep aligned (previously "i")
277
279
278
280
# Ensure x is a 1D array
279
281
x = np .asarray (x ).ravel ()
@@ -351,7 +353,7 @@ def apply_interpolation_matrix(self, comps=None, weights=None, stretch=None, ret
351
353
stretch_tiled = np .tile (stretch_flat , (self ._signal_len , 1 ))
352
354
353
355
# Compute `ii` (MATLAB: ii = repmat((0:N-1)',1,K*M).*tiled_stretch)
354
- ii = (
356
+ fractional_indices = (
355
357
np .tile (np .arange (self ._signal_len )[:, None ], (1 , self ._num_conditions * self ._n_components ))
356
358
* stretch_tiled
357
359
)
@@ -368,44 +370,45 @@ def apply_interpolation_matrix(self, comps=None, weights=None, stretch=None, ret
368
370
).reshape (self ._signal_len , self ._n_components * self ._num_conditions )
369
371
370
372
# Handle boundary conditions for interpolation (MATLAB: X1=[X;X(end,:)])
371
- X1 = np .vstack ([comps , comps [- 1 , :]]) # Duplicate last row (like MATLAB)
373
+ comps_bounded = np .vstack ([comps , comps [- 1 , :]]) # Duplicate last row (like MATLAB)
372
374
373
375
# Compute floor indices (MATLAB: II = floor(ii); II1=min(II+1,N+1); II2=min(II1+1,N+1))
374
- II = np .floor (ii ).astype (int )
376
+ floor_indices = np .floor (fractional_indices ).astype (int )
375
377
376
- II1 = np .minimum (II + 1 , self ._signal_len )
377
- II2 = np .minimum (II1 + 1 , self ._signal_len )
378
+ floor_ind_1 = np .minimum (floor_indices + 1 , self ._signal_len )
379
+ floor_ind_2 = np .minimum (floor_ind_1 + 1 , self ._signal_len )
378
380
379
381
# Compute fractional part (MATLAB: iI = ii - II)
380
- iI = ii - II
382
+ fractional_floor_indices = fractional_indices - floor_indices
381
383
382
384
# Compute offset indices (MATLAB: II1_ = II1 + bias; II2_ = II2 + bias)
383
- II1_ = II1 + bias
384
- II2_ = II2 + bias
385
+ offset_floor_ind_1 = floor_ind_1 + bias
386
+ offset_floor_ind_2 = floor_ind_2 + bias
385
387
386
388
# Extract values (MATLAB: XI1 = reshape(X1(II1_), N, K*M); XI2 = reshape(X1(II2_), N, K*M))
387
389
# Note: this "-1" corrects an off-by-one error that may have originated in an earlier line
388
- XI1 = X1 .flatten (order = "F" )[(II1_ - 1 ).ravel ()].reshape (
390
+ # order = F uses FORTRAN, column major order
391
+ comps_val_1 = comps_bounded .flatten (order = "F" )[(offset_floor_ind_1 - 1 ).ravel ()].reshape (
389
392
self ._signal_len , self ._n_components * self ._num_conditions
390
- ) # order = F uses FORTRAN, column major order
391
- XI2 = X1 .flatten (order = "F" )[(II2_ - 1 ).ravel ()].reshape (
393
+ )
394
+ comps_val_2 = comps_bounded .flatten (order = "F" )[(offset_floor_ind_2 - 1 ).ravel ()].reshape (
392
395
self ._signal_len , self ._n_components * self ._num_conditions
393
396
)
394
397
395
398
# Interpolation (MATLAB: Ax2=XI1.*(1-iI)+XI2.*(iI); stretched_comps=Ax2.*YY)
396
- Ax2 = XI1 * (1 - iI ) + XI2 * iI
397
- stretched_comps = Ax2 * weights_tiled # Apply weighting
399
+ stretch_comps2 = comps_val_1 * (1 - fractional_floor_indices ) + comps_val_2 * fractional_floor_indices
400
+ stretched_comps = stretch_comps2 * weights_tiled # Apply weighting
398
401
399
402
if return_derivatives :
400
403
# Compute first derivative (MATLAB: Tx2=XI1.*(-di)+XI2.*di; d_str_cmps=Tx2.*YY)
401
- di = - ii * stretch_tiled
402
- d_x2 = XI1 * (- di ) + XI2 * di
403
- d_str_cmps = d_x2 * weights_tiled
404
+ di = - fractional_indices * stretch_tiled
405
+ d_comps2 = comps_val_1 * (- di ) + comps_val_2 * di
406
+ d_str_cmps = d_comps2 * weights_tiled
404
407
405
408
# Compute second derivative (MATLAB: Hx2=XI1.*(-ddi)+XI2.*ddi; dd_str_comps=Hx2.*YY)
406
409
ddi = - di * stretch_tiled * 2
407
- dd_x2 = XI1 * (- ddi ) + XI2 * ddi
408
- dd_str_cmps = dd_x2 * weights_tiled
410
+ dd_comps2 = comps_val_1 * (- ddi ) + comps_val_2 * ddi
411
+ dd_str_cmps = dd_comps2 * weights_tiled
409
412
else :
410
413
shape = stretched_comps .shape
411
414
d_str_cmps = np .empty (shape )
@@ -430,13 +433,17 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None
430
433
K = weights .shape [0 ]
431
434
432
435
# Compute scaling matrix (MATLAB: AA = repmat(reshape(A,1,M*K).^-1,Nindex,1))
433
- AA = np .tile (stretch .reshape (1 , M * K , order = "F" ) ** - 1 , (N , 1 ))
436
+ stretch_tiled = np .tile (
437
+ stretch .reshape (1 , self ._num_conditions * self ._n_components , order = "F" ) ** - 1 , (self ._signal_len , 1 )
438
+ )
434
439
435
440
# Compute indices (MATLAB: ii = repmat((index-1)',1,K*M).*AA)
436
- ii = np .arange (N )[:, None ] * AA # Shape (N, M*K), replacing `index`
441
+ ii = np .arange (self . _signal_len )[:, None ] * stretch_tiled # Shape (N, M*K), replacing `index`
437
442
438
443
# Weighting coefficients (MATLAB: YY = repmat(reshape(Y,1,M*K),Nindex,1))
439
- YY = np .tile (weights .reshape (1 , M * K , order = "F" ), (N , 1 ))
444
+ weights_tiled = np .tile (
445
+ weights .reshape (1 , self ._num_conditions * self ._n_components , order = "F" ), (self ._signal_len , 1 )
446
+ )
440
447
441
448
# Compute floor indices (MATLAB: II = floor(ii); II1 = min(II+1,N+1); II2 = min(II1+1,N+1))
442
449
II = np .floor (ii ).astype (int )
@@ -448,7 +455,7 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None
448
455
II2_ = II2
449
456
450
457
# Compute fractional part (MATLAB: iI = ii - II)
451
- iI = ii - II
458
+ fractional_indices = ii - II
452
459
453
460
# Expand row indices (MATLAB: repm = repmat(1:K, Nindex, M))
454
461
repm = np .tile (np .arange (K ), (N , M ))
@@ -457,12 +464,14 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None
457
464
kron = np .kron (residuals , np .ones ((1 , K )))
458
465
459
466
# (MATLAB: kroiI = kro .* (iI); iIYY = (iI-1) .* YY)
460
- kron_iI = kron * iI
461
- iIYY = (iI - 1 ) * YY
467
+ kron_iI = kron * fractional_indices
468
+ iIYY = (fractional_indices - 1 ) * weights_tiled
462
469
463
470
# Construct sparse matrices (MATLAB: sparse(II1_,repm,kro.*-iIYY,(N+1),K))
464
471
x2 = coo_matrix (((- kron * iIYY ).flatten (), (II1_ .flatten () - 1 , repm .flatten ())), shape = (N + 1 , K )).tocsc ()
465
- x3 = coo_matrix (((kron_iI * YY ).flatten (), (II2_ .flatten () - 1 , repm .flatten ())), shape = (N + 1 , K )).tocsc ()
472
+ x3 = coo_matrix (
473
+ ((kron_iI * weights_tiled ).flatten (), (II2_ .flatten () - 1 , repm .flatten ())), shape = (N + 1 , K )
474
+ ).tocsc ()
466
475
467
476
# Combine the last row into previous, then remove the last row
468
477
x2 [N - 1 , :] += x2 [N , :]
@@ -527,46 +536,53 @@ def hess(y):
527
536
528
537
def update_comps (self ):
529
538
"""
530
- Updates `comps` using gradient-based optimization with adaptive step size L .
539
+ Updates `comps` using gradient-based optimization with adaptive step size step_size .
531
540
"""
532
541
# Compute `stretched_comps` using the interpolation function
533
542
stretched_comps , _ , _ = self .apply_interpolation_matrix () # Skip the other two outputs (derivatives)
534
543
# Compute RA and RR
535
- intermediate_RA = stretched_comps .flatten (order = "F" ).reshape (
544
+ intermediate_reshaped = stretched_comps .flatten (order = "F" ).reshape (
536
545
(self ._signal_len * self ._num_conditions , self ._n_components ), order = "F"
537
546
)
538
- RA = intermediate_RA .sum (axis = 1 ).reshape ((self ._signal_len , self ._num_conditions ), order = "F" )
539
- RR = RA - self .source_matrix
547
+ reshaped_stretched_components = intermediate_reshaped .sum (axis = 1 ).reshape (
548
+ (self ._signal_len , self ._num_conditions ), order = "F"
549
+ )
550
+ component_residuals = reshaped_stretched_components - self .source_matrix
540
551
# Compute gradient `GraX`
541
552
self .grad_comps = self .apply_transformation_matrix (
542
- residuals = RR
543
- ).toarray () # toarray equivalent of full, make non-sparse
553
+ residuals = component_residuals
554
+ ).toarray () # toarray equivalent of MATLAB " full", makes non-sparse
544
555
545
- # Compute initial step size `L0 `
546
- L0 = np .linalg .eigvalsh (self .weights .T @ self .weights ).max () * np .max (
556
+ # Compute initial step size `initial_step_size `
557
+ initial_step_size = np .linalg .eigvalsh (self .weights .T @ self .weights ).max () * np .max (
547
558
[self .stretch .max (), 1 / self .stretch .min ()]
548
559
)
549
- # Compute adaptive step size `L `
560
+ # Compute adaptive step size `step_size `
550
561
if self ._prev_comps is None :
551
- L = L0
562
+ step_size = initial_step_size
552
563
else :
553
564
num = np .sum (
554
565
(self .grad_comps - self ._prev_grad_comps ) * (self .comps - self ._prev_comps )
555
566
) # Elem-wise multiply
556
567
denom = np .linalg .norm (self .comps - self ._prev_comps , "fro" ) ** 2 # Frobenius norm squared
557
- L = num / denom if denom > 0 else L0
558
- if L <= 0 :
559
- L = L0
568
+ step_size = num / denom if denom > 0 else initial_step_size
569
+ if step_size <= 0 :
570
+ step_size = initial_step_size
560
571
561
572
# Store our old component matrix before updating because it is used in step selection
562
573
self ._prev_comps = self .comps .copy ()
563
574
564
575
while True : # iterate updating components
565
- comps_step = self ._prev_comps - self .grad_comps / L
576
+ comps_step = self ._prev_comps - self .grad_comps / step_size
566
577
# Solve x^3 + p*x + q = 0 for the largest real root
567
- self .comps = np .square (cubic_largest_real_root (- comps_step , self .eta / (2 * L )))
578
+ self .comps = np .square (cubic_largest_real_root (- comps_step , self .eta / (2 * step_size )))
568
579
# Mask values that should be set to zero
569
- mask = self .comps ** 2 * L / 2 - L * self .comps * comps_step + self .eta * np .sqrt (self .comps ) < 0
580
+ mask = (
581
+ self .comps ** 2 * step_size / 2
582
+ - step_size * self .comps * comps_step
583
+ + self .eta * np .sqrt (self .comps )
584
+ < 0
585
+ )
570
586
self .comps = mask * self .comps
571
587
572
588
objective_improvement = self ._objective_history [- 1 ] - self .get_objective_function (
@@ -576,9 +592,9 @@ def update_comps(self):
576
592
# Check if objective function improves
577
593
if objective_improvement > 0 :
578
594
break
579
- # If not, increase L (step size)
580
- L *= 2
581
- if np .isinf (L ):
595
+ # If not, increase step_size (step size)
596
+ step_size *= 2
597
+ if np .isinf (step_size ):
582
598
break
583
599
584
600
def update_weights (self ):
@@ -587,7 +603,7 @@ def update_weights(self):
587
603
"""
588
604
589
605
for m in range (self ._num_conditions ):
590
- T = np .zeros ((self ._signal_len , self ._n_components )) # Initialize T as an (N, K) zero matrix
606
+ T = np .zeros ((self ._signal_len , self ._n_components ))
591
607
592
608
# Populate T using apply_interpolation
593
609
for k in range (self ._n_components ):
0 commit comments