Skip to content

Commit b520be6

Browse files
committed
ML-385 renamed ValidatesShapes into AssertsShapes::assertSameShape
1 parent c5899a6 commit b520be6

File tree

6 files changed

+27
-27
lines changed

6 files changed

+27
-27
lines changed

src/NeuralNet/CostFunctions/CrossEntropy/CrossEntropy.php

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
use NDArray;
88
use NumPower;
99
use Rubix\ML\NeuralNet\CostFunctions\Base\Contracts\ClassificationLoss;
10-
use Rubix\ML\Traits\ValidatesShapes;
10+
use Rubix\ML\Traits\AssertsShapes;
1111
use const Rubix\ML\EPSILON;
1212

1313
/**
@@ -27,7 +27,7 @@
2727
*/
2828
class CrossEntropy implements ClassificationLoss
2929
{
30-
use ValidatesShapes;
30+
use AssertsShapes;
3131

3232
/**
3333
* Compute the loss score.
@@ -40,7 +40,7 @@ class CrossEntropy implements ClassificationLoss
4040
*/
4141
public function compute(NDArray $output, NDArray $target) : float
4242
{
43-
$this->validateShapes($output, $target);
43+
$this->assertSameShape($output, $target);
4444

4545
// Clip values to avoid log(0)
4646
$output = NumPower::clip($output, EPSILON, 1.0);
@@ -63,7 +63,7 @@ public function compute(NDArray $output, NDArray $target) : float
6363
*/
6464
public function differentiate(NDArray $output, NDArray $target) : NDArray
6565
{
66-
$this->validateShapes($output, $target);
66+
$this->assertSameShape($output, $target);
6767

6868
// Numerator = ŷ - y (calculate before clipping to preserve zeros)
6969
$numerator = NumPower::subtract($output, $target);

src/NeuralNet/CostFunctions/HuberLoss/HuberLoss.php

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
use NumPower;
99
use Rubix\ML\NeuralNet\CostFunctions\Base\Contracts\RegressionLoss;
1010
use Rubix\ML\NeuralNet\CostFunctions\HuberLoss\Exceptions\InvalidAlphaException;
11-
use Rubix\ML\Traits\ValidatesShapes;
11+
use Rubix\ML\Traits\AssertsShapes;
1212

1313
/**
1414
* Huber Loss
@@ -25,7 +25,7 @@
2525
*/
2626
class HuberLoss implements RegressionLoss
2727
{
28-
use ValidatesShapes;
28+
use AssertsShapes;
2929

3030
/**
3131
* The alpha quantile i.e the pivot point at which numbers larger will be
@@ -70,7 +70,7 @@ public function __construct(float $alpha = 0.9)
7070
*/
7171
public function compute(NDArray $output, NDArray $target) : float
7272
{
73-
$this->validateShapes($output, $target);
73+
$this->assertSameShape($output, $target);
7474

7575
$difference = NumPower::subtract($target, $output);
7676
$scaled = NumPower::divide($difference, $this->alpha);
@@ -94,7 +94,7 @@ public function compute(NDArray $output, NDArray $target) : float
9494
*/
9595
public function differentiate(NDArray $output, NDArray $target) : NDArray
9696
{
97-
$this->validateShapes($output, $target);
97+
$this->assertSameShape($output, $target);
9898

9999
$difference = NumPower::subtract($output, $target);
100100
$squared = NumPower::pow($difference, 2);

src/NeuralNet/CostFunctions/LeastSquares/LeastSquares.php

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
use NDArray;
88
use NumPower;
99
use Rubix\ML\NeuralNet\CostFunctions\Base\Contracts\RegressionLoss;
10-
use Rubix\ML\Traits\ValidatesShapes;
10+
use Rubix\ML\Traits\AssertsShapes;
1111

1212
/**
1313
* Least Squares
@@ -22,7 +22,7 @@
2222
*/
2323
class LeastSquares implements RegressionLoss
2424
{
25-
use ValidatesShapes;
25+
use AssertsShapes;
2626

2727
/**
2828
* Compute the loss score.
@@ -35,7 +35,7 @@ class LeastSquares implements RegressionLoss
3535
*/
3636
public function compute(NDArray $output, NDArray $target) : float
3737
{
38-
$this->validateShapes($output, $target);
38+
$this->assertSameShape($output, $target);
3939

4040
$difference = NumPower::subtract($output, $target);
4141
$squared = NumPower::pow($difference, 2);
@@ -55,7 +55,7 @@ public function compute(NDArray $output, NDArray $target) : float
5555
*/
5656
public function differentiate(NDArray $output, NDArray $target) : NDArray
5757
{
58-
$this->validateShapes($output, $target);
58+
$this->assertSameShape($output, $target);
5959

6060
return NumPower::subtract($output, $target);
6161
}

src/NeuralNet/CostFunctions/MeanAbsoluteError/MeanAbsoluteError.php

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
use NDArray;
88
use NumPower;
99
use Rubix\ML\NeuralNet\CostFunctions\Base\Contracts\RegressionLoss;
10-
use Rubix\ML\Traits\ValidatesShapes;
10+
use Rubix\ML\Traits\AssertsShapes;
1111

1212
/**
1313
* Mean Absolute Error
@@ -23,7 +23,7 @@
2323
*/
2424
class MeanAbsoluteError implements RegressionLoss
2525
{
26-
use ValidatesShapes;
26+
use AssertsShapes;
2727

2828
/**
2929
* Compute the loss score.
@@ -36,7 +36,7 @@ class MeanAbsoluteError implements RegressionLoss
3636
*/
3737
public function compute(NDArray $output, NDArray $target) : float
3838
{
39-
$this->validateShapes($output, $target);
39+
$this->assertSameShape($output, $target);
4040

4141
$difference = NumPower::subtract($output, $target);
4242
$absolute = NumPower::abs($difference);
@@ -55,7 +55,7 @@ public function compute(NDArray $output, NDArray $target) : float
5555
*/
5656
public function differentiate(NDArray $output, NDArray $target) : NDArray
5757
{
58-
$this->validateShapes($output, $target);
58+
$this->assertSameShape($output, $target);
5959

6060
$difference = NumPower::subtract($output, $target);
6161

src/NeuralNet/CostFunctions/RelativeEntropy/RelativeEntropy.php

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
use NDArray;
88
use NumPower;
99
use Rubix\ML\NeuralNet\CostFunctions\Base\Contracts\ClassificationLoss;
10-
use Rubix\ML\Traits\ValidatesShapes;
10+
use Rubix\ML\Traits\AssertsShapes;
1111
use const Rubix\ML\EPSILON;
1212

1313
/**
@@ -23,7 +23,7 @@
2323
*/
2424
class RelativeEntropy implements ClassificationLoss
2525
{
26-
use ValidatesShapes;
26+
use AssertsShapes;
2727

2828
/**
2929
* Compute the loss.
@@ -38,7 +38,7 @@ class RelativeEntropy implements ClassificationLoss
3838
*/
3939
public function compute(NDArray $output, NDArray $target) : float
4040
{
41-
$this->validateShapes($output, $target);
41+
$this->assertSameShape($output, $target);
4242

4343
// Clip values to avoid log(0)
4444
$target = NumPower::clip($target, EPSILON, 1.0);
@@ -64,7 +64,7 @@ public function compute(NDArray $output, NDArray $target) : float
6464
*/
6565
public function differentiate(NDArray $output, NDArray $target) : NDArray
6666
{
67-
$this->validateShapes($output, $target);
67+
$this->assertSameShape($output, $target);
6868

6969
// Clip values to avoid division by zero
7070
$target = NumPower::clip($target, EPSILON, 1.0);

src/Traits/ValidatesShapes.php renamed to src/Traits/AssertsShapes.php

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,19 @@
1717
* @package Rubix/ML
1818
* @author Samuel Akopyan <[email protected]>
1919
*/
20-
trait ValidatesShapes
20+
trait AssertsShapes
2121
{
2222
/**
23-
* Validate that output and target have the same shape.
23+
* Assert that the output and target NDArrays have identical shapes.
2424
*
25-
* @param NDArray $output
26-
* @param NDArray $target
27-
* @throws InvalidArgumentException
25+
* @param NDArray $output The output array to check.
26+
* @param NDArray $target The target array to compare against.
27+
* @throws InvalidArgumentException If the shapes do not match.
2828
*/
29-
protected function validateShapes(NDArray $output, NDArray $target) : void
29+
protected function assertSameShape(NDArray $output, NDArray $target) : void
3030
{
3131
if ($output->shape() !== $target->shape()) {
32-
throw new InvalidArgumentException('Output and target must have the same shape.');
32+
throw new InvalidArgumentException('Output and target must have identical shapes.');
3333
}
3434
}
3535
}

0 commit comments

Comments
 (0)