Skip to content

Commit

Permalink
Using ConstMatrix in more locations where the matrix isn't modified.
Browse files Browse the repository at this point in the history
  • Loading branch information
lessthanoptimal committed Feb 10, 2023
1 parent e1fcbc5 commit 174f998
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,20 @@
public class AutomaticSimpleMatrixConvert {
MatrixType commonType;

public void specify0( SimpleBase a, SimpleBase... inputs ) {
SimpleBase array[] = new SimpleBase[inputs.length + 1];
public void specify0( ConstMatrix<?> a, ConstMatrix<?>... inputs ) {
var array = new ConstMatrix[inputs.length + 1];
System.arraycopy(inputs, 0, array, 0, inputs.length);
array[inputs.length] = a;
specify(array);
}

public void specify( SimpleBase... inputs ) {
public void specify( ConstMatrix<?>... inputs ) {
boolean dense = false;
boolean real = true;
int bits = 32;

for (SimpleBase s : inputs) {
MatrixType t = s.mat.getType();
for (ConstMatrix<?> s : inputs) {
MatrixType t = s.getType();
if (t.isDense())
dense = true;
if (!t.isReal())
Expand All @@ -55,7 +55,7 @@ public void specify( SimpleBase... inputs ) {
commonType = MatrixType.lookup(dense, real, bits);
}

public <T extends SimpleBase<T>> T convert( SimpleBase matrix ) {
public <T extends SimpleBase<T>> T convert( SimpleBase<?> matrix ) {
if (matrix.getType() == commonType)
return (T)matrix;

Expand Down
29 changes: 26 additions & 3 deletions main/ejml-simple/src/org/ejml/simple/ConstMatrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
* local copy that can't be accessed externally.
* </p>
*
* <p>NOTE: Implementations of ConstMatrix must extend {@link SimpleBase} or else it won't work when given as
* input to any class based off of {@link SimpleBase}.</p>
*
* @author Peter Abeles
*/
public interface ConstMatrix<T extends ConstMatrix<T>> {
Expand Down Expand Up @@ -65,7 +68,7 @@ public interface ConstMatrix<T extends ConstMatrix<T>> {
* where c is the returned matrix, a is this matrix, and b is the passed in matrix.
* </p>
*
* @param B A matrix that is n by bn. Not modified.
* @param B A matrix that is n by p. Not modified.
* @return The results of this operation.
* @see CommonOps_DDRM#mult(DMatrix1Row, DMatrix1Row, DMatrix1Row)
*/
Expand Down Expand Up @@ -98,7 +101,6 @@ public interface ConstMatrix<T extends ConstMatrix<T>> {
*/
T plus( ConstMatrix<?> B );


/**
* <p>
* Returns the result of matrix subtraction:<br>
Expand Down Expand Up @@ -141,7 +143,6 @@ public interface ConstMatrix<T extends ConstMatrix<T>> {
*/
T minusComplex( double real, double imag );


/**
* <p>
* Returns the result of scalar addition:<br>
Expand Down Expand Up @@ -767,6 +768,28 @@ default int getNumElements() {
*/
T cols( int begin, int end );

/**
* <p>Concatenates all the matrices together along their columns. If the rows do not match the upper elements
* are set to zero.</p>
*
* A = [ this, m[0] , ... , m[n-1] ]
*
* @param matrices Set of matrices
* @return Resulting matrix
*/
T concatColumns( ConstMatrix<?>... matrices );

/**
* <p>Concatenates all the matrices together along their columns. If the rows do not match the upper elements
* are set to zero.</p>
*
* A = [ this; m[0] ; ... ; m[n-1] ]
*
* @param matrices Set of matrices
* @return Resulting matrix
*/
T concatRows( ConstMatrix<?>... matrices );

/**
* Returns the type of matrix it is wrapping.
*/
Expand Down
60 changes: 23 additions & 37 deletions main/ejml-simple/src/org/ejml/simple/SimpleBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ protected static SimpleOperations lookupOps( MatrixType type ) {
}

/** {@inheritDoc} */
@Override public T plus( ConstMatrix _B ) {
@Override public T plus( ConstMatrix<?> _B ) {
T B = (T)_B;
convertType.specify(this, B);
T A = convertType.convert(this);
Expand Down Expand Up @@ -279,7 +279,7 @@ protected static SimpleOperations lookupOps( MatrixType type ) {
}

/** {@inheritDoc} */
@Override public T plus( double beta, ConstMatrix _B ) {
@Override public T plus( double beta, ConstMatrix<?> _B ) {
T B = (T)_B;
convertType.specify(this, B);
T A = convertType.convert(this);
Expand All @@ -291,7 +291,7 @@ protected static SimpleOperations lookupOps( MatrixType type ) {
}

/** {@inheritDoc} */
@Override public double dot( ConstMatrix _v ) {
@Override public double dot( ConstMatrix<?> _v ) {
T v = (T)_v;
convertType.specify(this, v);
T A = convertType.convert(this);
Expand Down Expand Up @@ -581,7 +581,7 @@ public void setRow( int row, int startColumn, double... values ) {
* @param row Row in 'this'
* @param src Vector which is to be copied into the row
*/
public void setRow( int row, SimpleMatrix src ) {
public void setRow( int row, ConstMatrix<?> src ) {
if (!src.isVector())
throw new IllegalArgumentException("Input matrix must be a vector");
if (src.getNumElements() != numCols())
Expand All @@ -596,9 +596,10 @@ public void setRow( int row, SimpleMatrix src ) {
}

// See if it's a row or column vector and grab the appropriate elements.
double[] vector = src.numRows() < src.numCols() ?
src.ops.getRow(src.mat, 0, 0, src.getNumElements()) :
src.ops.getColumn(src.mat, 0, 0, src.getNumElements());
var bsrc = (SimpleBase)src;
double[] vector = src.getNumRows() < src.getNumCols() ?
bsrc.ops.getRow(bsrc.mat, 0, 0, src.getNumElements()) :
bsrc.ops.getColumn(bsrc.mat, 0, 0, src.getNumElements());

// If src is real but output is complex, convert the vector.
if (src.getType().isReal() && !getType().isReal()) {
Expand Down Expand Up @@ -631,7 +632,7 @@ public void setColumn( int column, int startRow, double... values ) {
* @param column Column in 'this'
* @param src Vector which is to be copied into the column
*/
public void setColumn( int column, SimpleMatrix src ) {
public void setColumn( int column, ConstMatrix<?> src ) {
if (!src.isVector())
throw new IllegalArgumentException("Input matrix must be a vector");
if (src.getNumElements() != numRows())
Expand All @@ -646,9 +647,10 @@ public void setColumn( int column, SimpleMatrix src ) {
}

// See if it's a row or column vector and grab the appropriate elements.
double[] vector = src.numRows() < src.numCols() ?
src.ops.getRow(src.mat, 0, 0, src.getNumElements()) :
src.ops.getColumn(src.mat, 0, 0, src.getNumElements());
var bsrc = (SimpleBase)src;
double[] vector = src.getNumRows() < src.getNumCols() ?
bsrc.ops.getRow(bsrc.mat, 0, 0, src.getNumElements()) :
bsrc.ops.getColumn(bsrc.mat, 0, 0, src.getNumElements());

// If src is real but output is complex, convert the vector.
if (src.getType().isReal() && !getType().isReal()) {
Expand Down Expand Up @@ -1217,32 +1219,24 @@ public void printDimensions() {
return mat.getType().getBits();
}

/**
* <p>Concatenates all the matrices together along their columns. If the rows do not match the upper elements
* are set to zero.</p>
*
* A = [ this, m[0] , ... , m[n-1] ]
*
* @param matrices Set of matrices
* @return Resulting matrix
*/
public T concatColumns( SimpleBase<?>... matrices ) {
/** {@inheritDoc} */
@Override public T concatColumns( ConstMatrix<?>... matrices ) {
convertType.specify0(this, matrices);
T A = convertType.convert(this);

int numCols = A.numCols();
int numRows = A.numRows();
for (int i = 0; i < matrices.length; i++) {
numRows = Math.max(numRows, matrices[i].numRows());
numCols += matrices[i].numCols();
numRows = Math.max(numRows, matrices[i].getNumRows());
numCols += matrices[i].getNumCols();
}

SimpleMatrix combined = SimpleMatrix.wrap(convertType.commonType.create(numRows, numCols));

A.ops.extract(A.mat, 0, A.numRows(), 0, A.numCols(), combined.mat, 0, 0);
int col = A.numCols();
for (int i = 0; i < matrices.length; i++) {
Matrix m = convertType.convert(matrices[i]).mat;
Matrix m = convertType.convert((SimpleBase)matrices[i]).mat;
int cols = m.getNumCols();
int rows = m.getNumRows();
A.ops.extract(m, 0, rows, 0, cols, combined.mat, 0, col);
Expand All @@ -1252,32 +1246,24 @@ public T concatColumns( SimpleBase<?>... matrices ) {
return (T)combined;
}

/**
* <p>Concatenates all the matrices together along their columns. If the rows do not match the upper elements
* are set to zero.</p>
*
* A = [ this; m[0] ; ... ; m[n-1] ]
*
* @param matrices Set of matrices
* @return Resulting matrix
*/
public T concatRows( SimpleBase<?>... matrices ) {
/** {@inheritDoc} */
@Override public T concatRows( ConstMatrix<?>... matrices ) {
convertType.specify0(this, matrices);
T A = convertType.convert(this);

int numCols = A.numCols();
int numRows = A.numRows();
for (int i = 0; i < matrices.length; i++) {
numRows += matrices[i].numRows();
numCols = Math.max(numCols, matrices[i].numCols());
numRows += matrices[i].getNumRows();
numCols = Math.max(numCols, matrices[i].getNumCols());
}

SimpleMatrix combined = SimpleMatrix.wrap(convertType.commonType.create(numRows, numCols));

A.ops.extract(A.mat, 0, A.numRows(), 0, A.numCols(), combined.mat, 0, 0);
int row = A.numRows();
for (int i = 0; i < matrices.length; i++) {
Matrix m = convertType.convert(matrices[i]).mat;
Matrix m = convertType.convert((SimpleBase)matrices[i]).mat;
int cols = m.getNumCols();
int rows = m.getNumRows();
A.ops.extract(m, 0, rows, 0, cols, combined.mat, row, 0);
Expand Down

0 comments on commit 174f998

Please sign in to comment.