diff --git a/main/ejml-simple/src/org/ejml/simple/AutomaticSimpleMatrixConvert.java b/main/ejml-simple/src/org/ejml/simple/AutomaticSimpleMatrixConvert.java index 93c933e6..5cc16a3c 100644 --- a/main/ejml-simple/src/org/ejml/simple/AutomaticSimpleMatrixConvert.java +++ b/main/ejml-simple/src/org/ejml/simple/AutomaticSimpleMatrixConvert.java @@ -30,20 +30,20 @@ public class AutomaticSimpleMatrixConvert { MatrixType commonType; - public void specify0( SimpleBase a, SimpleBase... inputs ) { - SimpleBase array[] = new SimpleBase[inputs.length + 1]; + public void specify0( ConstMatrix a, ConstMatrix... inputs ) { + var array = new ConstMatrix[inputs.length + 1]; System.arraycopy(inputs, 0, array, 0, inputs.length); array[inputs.length] = a; specify(array); } - public void specify( SimpleBase... inputs ) { + public void specify( ConstMatrix... inputs ) { boolean dense = false; boolean real = true; int bits = 32; - for (SimpleBase s : inputs) { - MatrixType t = s.mat.getType(); + for (ConstMatrix s : inputs) { + MatrixType t = s.getType(); if (t.isDense()) dense = true; if (!t.isReal()) @@ -55,7 +55,7 @@ public void specify( SimpleBase... inputs ) { commonType = MatrixType.lookup(dense, real, bits); } - public > T convert( SimpleBase matrix ) { + public > T convert( SimpleBase matrix ) { if (matrix.getType() == commonType) return (T)matrix; diff --git a/main/ejml-simple/src/org/ejml/simple/ConstMatrix.java b/main/ejml-simple/src/org/ejml/simple/ConstMatrix.java index f704e082..7eb155fe 100644 --- a/main/ejml-simple/src/org/ejml/simple/ConstMatrix.java +++ b/main/ejml-simple/src/org/ejml/simple/ConstMatrix.java @@ -36,6 +36,9 @@ * local copy that can't be accessed externally. *

* + *

NOTE: Implementations of ConstMatrix must extend {@link SimpleBase} or else it won't work when given as + * input to any class based off of {@link SimpleBase}.

+ * * @author Peter Abeles */ public interface ConstMatrix> { @@ -65,7 +68,7 @@ public interface ConstMatrix> { * where c is the returned matrix, a is this matrix, and b is the passed in matrix. *

* - * @param B A matrix that is n by bn. Not modified. + * @param B A matrix that is n by p. Not modified. * @return The results of this operation. * @see CommonOps_DDRM#mult(DMatrix1Row, DMatrix1Row, DMatrix1Row) */ @@ -98,7 +101,6 @@ public interface ConstMatrix> { */ T plus( ConstMatrix B ); - /** *

* Returns the result of matrix subtraction:
@@ -141,7 +143,6 @@ public interface ConstMatrix> { */ T minusComplex( double real, double imag ); - /** *

* Returns the result of scalar addition:
@@ -767,6 +768,28 @@ default int getNumElements() { */ T cols( int begin, int end ); + /** + *

Concatenates all the matrices together along their columns. If the rows do not match the upper elements + * are set to zero.

+ * + * A = [ this, m[0] , ... , m[n-1] ] + * + * @param matrices Set of matrices + * @return Resulting matrix + */ + T concatColumns( ConstMatrix... matrices ); + + /** + *

Concatenates all the matrices together along their columns. If the rows do not match the upper elements + * are set to zero.

+ * + * A = [ this; m[0] ; ... ; m[n-1] ] + * + * @param matrices Set of matrices + * @return Resulting matrix + */ + T concatRows( ConstMatrix... matrices ); + /** * Returns the type of matrix it is wrapping. */ diff --git a/main/ejml-simple/src/org/ejml/simple/SimpleBase.java b/main/ejml-simple/src/org/ejml/simple/SimpleBase.java index f855442d..ec712bfd 100644 --- a/main/ejml-simple/src/org/ejml/simple/SimpleBase.java +++ b/main/ejml-simple/src/org/ejml/simple/SimpleBase.java @@ -211,7 +211,7 @@ protected static SimpleOperations lookupOps( MatrixType type ) { } /** {@inheritDoc} */ - @Override public T plus( ConstMatrix _B ) { + @Override public T plus( ConstMatrix _B ) { T B = (T)_B; convertType.specify(this, B); T A = convertType.convert(this); @@ -279,7 +279,7 @@ protected static SimpleOperations lookupOps( MatrixType type ) { } /** {@inheritDoc} */ - @Override public T plus( double beta, ConstMatrix _B ) { + @Override public T plus( double beta, ConstMatrix _B ) { T B = (T)_B; convertType.specify(this, B); T A = convertType.convert(this); @@ -291,7 +291,7 @@ protected static SimpleOperations lookupOps( MatrixType type ) { } /** {@inheritDoc} */ - @Override public double dot( ConstMatrix _v ) { + @Override public double dot( ConstMatrix _v ) { T v = (T)_v; convertType.specify(this, v); T A = convertType.convert(this); @@ -581,7 +581,7 @@ public void setRow( int row, int startColumn, double... values ) { * @param row Row in 'this' * @param src Vector which is to be copied into the row */ - public void setRow( int row, SimpleMatrix src ) { + public void setRow( int row, ConstMatrix src ) { if (!src.isVector()) throw new IllegalArgumentException("Input matrix must be a vector"); if (src.getNumElements() != numCols()) @@ -596,9 +596,10 @@ public void setRow( int row, SimpleMatrix src ) { } // See if it's a row or column vector and grab the appropriate elements. - double[] vector = src.numRows() < src.numCols() ? - src.ops.getRow(src.mat, 0, 0, src.getNumElements()) : - src.ops.getColumn(src.mat, 0, 0, src.getNumElements()); + var bsrc = (SimpleBase)src; + double[] vector = src.getNumRows() < src.getNumCols() ? + bsrc.ops.getRow(bsrc.mat, 0, 0, src.getNumElements()) : + bsrc.ops.getColumn(bsrc.mat, 0, 0, src.getNumElements()); // If src is real but output is complex, convert the vector. if (src.getType().isReal() && !getType().isReal()) { @@ -631,7 +632,7 @@ public void setColumn( int column, int startRow, double... values ) { * @param column Column in 'this' * @param src Vector which is to be copied into the column */ - public void setColumn( int column, SimpleMatrix src ) { + public void setColumn( int column, ConstMatrix src ) { if (!src.isVector()) throw new IllegalArgumentException("Input matrix must be a vector"); if (src.getNumElements() != numRows()) @@ -646,9 +647,10 @@ public void setColumn( int column, SimpleMatrix src ) { } // See if it's a row or column vector and grab the appropriate elements. - double[] vector = src.numRows() < src.numCols() ? - src.ops.getRow(src.mat, 0, 0, src.getNumElements()) : - src.ops.getColumn(src.mat, 0, 0, src.getNumElements()); + var bsrc = (SimpleBase)src; + double[] vector = src.getNumRows() < src.getNumCols() ? + bsrc.ops.getRow(bsrc.mat, 0, 0, src.getNumElements()) : + bsrc.ops.getColumn(bsrc.mat, 0, 0, src.getNumElements()); // If src is real but output is complex, convert the vector. if (src.getType().isReal() && !getType().isReal()) { @@ -1217,24 +1219,16 @@ public void printDimensions() { return mat.getType().getBits(); } - /** - *

Concatenates all the matrices together along their columns. If the rows do not match the upper elements - * are set to zero.

- * - * A = [ this, m[0] , ... , m[n-1] ] - * - * @param matrices Set of matrices - * @return Resulting matrix - */ - public T concatColumns( SimpleBase... matrices ) { + /** {@inheritDoc} */ + @Override public T concatColumns( ConstMatrix... matrices ) { convertType.specify0(this, matrices); T A = convertType.convert(this); int numCols = A.numCols(); int numRows = A.numRows(); for (int i = 0; i < matrices.length; i++) { - numRows = Math.max(numRows, matrices[i].numRows()); - numCols += matrices[i].numCols(); + numRows = Math.max(numRows, matrices[i].getNumRows()); + numCols += matrices[i].getNumCols(); } SimpleMatrix combined = SimpleMatrix.wrap(convertType.commonType.create(numRows, numCols)); @@ -1242,7 +1236,7 @@ public T concatColumns( SimpleBase... matrices ) { A.ops.extract(A.mat, 0, A.numRows(), 0, A.numCols(), combined.mat, 0, 0); int col = A.numCols(); for (int i = 0; i < matrices.length; i++) { - Matrix m = convertType.convert(matrices[i]).mat; + Matrix m = convertType.convert((SimpleBase)matrices[i]).mat; int cols = m.getNumCols(); int rows = m.getNumRows(); A.ops.extract(m, 0, rows, 0, cols, combined.mat, 0, col); @@ -1252,24 +1246,16 @@ public T concatColumns( SimpleBase... matrices ) { return (T)combined; } - /** - *

Concatenates all the matrices together along their columns. If the rows do not match the upper elements - * are set to zero.

- * - * A = [ this; m[0] ; ... ; m[n-1] ] - * - * @param matrices Set of matrices - * @return Resulting matrix - */ - public T concatRows( SimpleBase... matrices ) { + /** {@inheritDoc} */ + @Override public T concatRows( ConstMatrix... matrices ) { convertType.specify0(this, matrices); T A = convertType.convert(this); int numCols = A.numCols(); int numRows = A.numRows(); for (int i = 0; i < matrices.length; i++) { - numRows += matrices[i].numRows(); - numCols = Math.max(numCols, matrices[i].numCols()); + numRows += matrices[i].getNumRows(); + numCols = Math.max(numCols, matrices[i].getNumCols()); } SimpleMatrix combined = SimpleMatrix.wrap(convertType.commonType.create(numRows, numCols)); @@ -1277,7 +1263,7 @@ public T concatRows( SimpleBase... matrices ) { A.ops.extract(A.mat, 0, A.numRows(), 0, A.numCols(), combined.mat, 0, 0); int row = A.numRows(); for (int i = 0; i < matrices.length; i++) { - Matrix m = convertType.convert(matrices[i]).mat; + Matrix m = convertType.convert((SimpleBase)matrices[i]).mat; int cols = m.getNumCols(); int rows = m.getNumRows(); A.ops.extract(m, 0, rows, 0, cols, combined.mat, row, 0);