Skip to content

Commit 29032a7

Browse files
committed
ML-385 code refactoring, added ValidatesShapes
1 parent 5d5bf45 commit 29032a7

File tree

6 files changed

+62
-39
lines changed

6 files changed

+62
-39
lines changed

src/NeuralNet/CostFunctions/CrossEntropy/CrossEntropy.php

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44

55
namespace Rubix\ML\NeuralNet\CostFunctions\CrossEntropy;
66

7-
use NumPower;
87
use NDArray;
9-
use Rubix\ML\Exceptions\InvalidArgumentException;
8+
use NumPower;
109
use Rubix\ML\NeuralNet\CostFunctions\Base\Contracts\ClassificationLoss;
11-
10+
use Rubix\ML\Traits\ValidatesShapes;
1211
use const Rubix\ML\EPSILON;
1312

1413
/**
@@ -28,6 +27,8 @@
2827
*/
2928
class CrossEntropy implements ClassificationLoss
3029
{
30+
use ValidatesShapes;
31+
3132
/**
3233
* Compute the loss score.
3334
*
@@ -39,9 +40,7 @@ class CrossEntropy implements ClassificationLoss
3940
*/
4041
public function compute(NDArray $output, NDArray $target) : float
4142
{
42-
if ($output->shape() !== $target->shape()) {
43-
throw new InvalidArgumentException('Output and target must have the same shape.');
44-
}
43+
$this->validateShapes($output, $target);
4544

4645
// Clip values to avoid log(0)
4746
$output = NumPower::clip($output, EPSILON, 1.0);
@@ -64,9 +63,7 @@ public function compute(NDArray $output, NDArray $target) : float
6463
*/
6564
public function differentiate(NDArray $output, NDArray $target) : NDArray
6665
{
67-
if ($output->shape() !== $target->shape()) {
68-
throw new InvalidArgumentException('Output and target must have the same shape.');
69-
}
66+
$this->validateShapes($output, $target);
7067

7168
// Numerator = ŷ - y (calculate before clipping to preserve zeros)
7269
$numerator = NumPower::subtract($output, $target);

src/NeuralNet/CostFunctions/HuberLoss/HuberLoss.php

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
use NDArray;
88
use NumPower;
99
use Rubix\ML\NeuralNet\CostFunctions\Base\Contracts\RegressionLoss;
10-
use Rubix\ML\Exceptions\InvalidArgumentException;
1110
use Rubix\ML\NeuralNet\CostFunctions\HuberLoss\Exceptions\InvalidAlphaException;
11+
use Rubix\ML\Traits\ValidatesShapes;
1212

1313
/**
1414
* Huber Loss
@@ -25,6 +25,8 @@
2525
*/
2626
class HuberLoss implements RegressionLoss
2727
{
28+
use ValidatesShapes;
29+
2830
/**
2931
* The alpha quantile i.e the pivot point at which numbers larger will be
3032
* evalutated with an L1 loss while number smaller will be evalutated with
@@ -68,9 +70,7 @@ public function __construct(float $alpha = 0.9)
6870
*/
6971
public function compute(NDArray $output, NDArray $target) : float
7072
{
71-
if ($output->shape() !== $target->shape()) {
72-
throw new InvalidArgumentException('Output and target must have the same shape.');
73-
}
73+
$this->validateShapes($output, $target);
7474

7575
$difference = NumPower::subtract($target, $output);
7676
$scaled = NumPower::divide($difference, $this->alpha);
@@ -94,9 +94,7 @@ public function compute(NDArray $output, NDArray $target) : float
9494
*/
9595
public function differentiate(NDArray $output, NDArray $target) : NDArray
9696
{
97-
if ($output->shape() !== $target->shape()) {
98-
throw new InvalidArgumentException('Output and target must have the same shape.');
99-
}
97+
$this->validateShapes($output, $target);
10098

10199
$difference = NumPower::subtract($output, $target);
102100
$squared = NumPower::pow($difference, 2);

src/NeuralNet/CostFunctions/LeastSquares/LeastSquares.php

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
namespace Rubix\ML\NeuralNet\CostFunctions\LeastSquares;
66

7-
use InvalidArgumentException;
87
use NDArray;
98
use NumPower;
109
use Rubix\ML\NeuralNet\CostFunctions\Base\Contracts\RegressionLoss;
10+
use Rubix\ML\Traits\ValidatesShapes;
1111

1212
/**
1313
* Least Squares
@@ -22,6 +22,8 @@
2222
*/
2323
class LeastSquares implements RegressionLoss
2424
{
25+
use ValidatesShapes;
26+
2527
/**
2628
* Compute the loss score.
2729
*
@@ -33,9 +35,7 @@ class LeastSquares implements RegressionLoss
3335
*/
3436
public function compute(NDArray $output, NDArray $target) : float
3537
{
36-
if ($output->shape() !== $target->shape()) {
37-
throw new InvalidArgumentException('Output and target must have the same shape.');
38-
}
38+
$this->validateShapes($output, $target);
3939

4040
$difference = NumPower::subtract($output, $target);
4141
$squared = NumPower::pow($difference, 2);
@@ -55,9 +55,7 @@ public function compute(NDArray $output, NDArray $target) : float
5555
*/
5656
public function differentiate(NDArray $output, NDArray $target) : NDArray
5757
{
58-
if ($output->shape() !== $target->shape()) {
59-
throw new InvalidArgumentException('Output and target must have the same shape.');
60-
}
58+
$this->validateShapes($output, $target);
6159

6260
return NumPower::subtract($output, $target);
6361
}

src/NeuralNet/CostFunctions/MeanAbsoluteError/MeanAbsoluteError.php

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
namespace Rubix\ML\NeuralNet\CostFunctions\MeanAbsoluteError;
66

7-
use InvalidArgumentException;
87
use NDArray;
98
use NumPower;
109
use Rubix\ML\NeuralNet\CostFunctions\Base\Contracts\RegressionLoss;
10+
use Rubix\ML\Traits\ValidatesShapes;
1111

1212
/**
1313
* Mean Absolute Error
@@ -23,6 +23,8 @@
2323
*/
2424
class MeanAbsoluteError implements RegressionLoss
2525
{
26+
use ValidatesShapes;
27+
2628
/**
2729
* Compute the loss score.
2830
*
@@ -34,9 +36,7 @@ class MeanAbsoluteError implements RegressionLoss
3436
*/
3537
public function compute(NDArray $output, NDArray $target) : float
3638
{
37-
if ($output->shape() !== $target->shape()) {
38-
throw new InvalidArgumentException('Output and target must have the same shape.');
39-
}
39+
$this->validateShapes($output, $target);
4040

4141
$difference = NumPower::subtract($output, $target);
4242
$absolute = NumPower::abs($difference);
@@ -55,9 +55,7 @@ public function compute(NDArray $output, NDArray $target) : float
5555
*/
5656
public function differentiate(NDArray $output, NDArray $target) : NDArray
5757
{
58-
if ($output->shape() !== $target->shape()) {
59-
throw new InvalidArgumentException('Output and target must have the same shape.');
60-
}
58+
$this->validateShapes($output, $target);
6159

6260
$difference = NumPower::subtract($output, $target);
6361

src/NeuralNet/CostFunctions/RelativeEntropy/RelativeEntropy.php

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44

55
namespace Rubix\ML\NeuralNet\CostFunctions\RelativeEntropy;
66

7-
use NumPower;
87
use NDArray;
9-
use Rubix\ML\Exceptions\InvalidArgumentException;
8+
use NumPower;
109
use Rubix\ML\NeuralNet\CostFunctions\Base\Contracts\ClassificationLoss;
11-
10+
use Rubix\ML\Traits\ValidatesShapes;
1211
use const Rubix\ML\EPSILON;
1312

1413
/**
@@ -24,6 +23,8 @@
2423
*/
2524
class RelativeEntropy implements ClassificationLoss
2625
{
26+
use ValidatesShapes;
27+
2728
/**
2829
* Compute the loss.
2930
*
@@ -37,9 +38,7 @@ class RelativeEntropy implements ClassificationLoss
3738
*/
3839
public function compute(NDArray $output, NDArray $target) : float
3940
{
40-
if ($output->shape() !== $target->shape()) {
41-
throw new InvalidArgumentException('Output and target must have the same shape.');
42-
}
41+
$this->validateShapes($output, $target);
4342

4443
// Clip values to avoid log(0)
4544
$target = NumPower::clip($target, EPSILON, 1.0);
@@ -65,9 +64,7 @@ public function compute(NDArray $output, NDArray $target) : float
6564
*/
6665
public function differentiate(NDArray $output, NDArray $target) : NDArray
6766
{
68-
if ($output->shape() !== $target->shape()) {
69-
throw new InvalidArgumentException('Output and target must have the same shape.');
70-
}
67+
$this->validateShapes($output, $target);
7168

7269
// Clip values to avoid division by zero
7370
$target = NumPower::clip($target, EPSILON, 1.0);

src/Traits/ValidatesShapes.php

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Rubix\ML\Traits;
6+
7+
use InvalidArgumentException;
8+
use NDArray;
9+
10+
/**
11+
* Validates Shapes
12+
*
13+
* A trait that provides shape validation for cost functions to ensure
14+
* output and target arrays have matching dimensions.
15+
*
16+
* @category Machine Learning
17+
* @package Rubix/ML
18+
* @author Samuel Akopyan <[email protected]>
19+
*/
20+
trait ValidatesShapes
21+
{
22+
/**
23+
* Validate that output and target have the same shape.
24+
*
25+
* @param NDArray $output
26+
* @param NDArray $target
27+
* @throws InvalidArgumentException
28+
*/
29+
protected function validateShapes(NDArray $output, NDArray $target) : void
30+
{
31+
if ($output->shape() !== $target->shape()) {
32+
throw new InvalidArgumentException('Output and target must have the same shape.');
33+
}
34+
}
35+
}

0 commit comments

Comments
 (0)