-
Notifications
You must be signed in to change notification settings - Fork 63
数据结构
Birdy edited this page Aug 14, 2019
·
11 revisions
- 迭代器(MathIterator)
- 元素
- 标量迭代器(ScalarIterator)
- 单元迭代器(CellIterator)
- 单元(Cell)
- 表单(MathTable)
元素代表迭代器可访问的数量,每个元素代表N个标量或者N个单元.(N>=1)
JStarCraft AI框架支持三种类型的标量
- 默认标量(DefaultScalar)
- 向量标量(VectorScalar)
- 矩阵标量(MatrixScalar)
JStarCraft AI框架支持八种类型的向量
- 稀疏型
- 数组向量(ArrayVector)
- 哈希向量(HashVector)
- 稀疏向量(SparseVector)
- 稠密型
- 稠密向量(DenseVector)
- ND4J向量(Nd4jVector)
- 对称向量(SymmetryVector)
- 包装型
- 整体向量(GlobalVector)
- 局部向量(LocalVector)
JStarCraft AI框架支持八种类型的矩阵
- 稠密型
- 稠密矩阵(DenseMatrix)
- ND4J矩阵(Nd4jMatrix)
- 对称矩阵(SymmetryMatrix)
- 稀疏型
- 数组矩阵(ArrayMatrix)
- 哈希矩阵(HashMatrix)
- 稀疏矩阵(SparseMatrix)
- 包装型
- 整体矩阵(GlobalMatrix)
- 局部矩阵(LocalMatrix)
- 获取/设置/缩放与偏移标量
MathMatrix dataMatrix = getRandomMatrix(dimension);
float oldSum = dataMatrix.getSum(false);
// 缩放所有的元素
dataMatrix.scaleValues(2F);
float newSum = dataMatrix.getSum(false);
Assert.assertThat(newSum, CoreMatchers.equalTo(oldSum * 2F));
oldSum = newSum;
// 偏移所有的元素
dataMatrix.shiftValues(1F);
newSum = dataMatrix.getSum(false);
Assert.assertThat(newSum, CoreMatchers.equalTo(oldSum + dataMatrix.getElementSize()));
// 设置所有的元素
dataMatrix.setValues(0F);
newSum = dataMatrix.getSum(false);
Assert.assertThat(newSum, CoreMatchers.equalTo(0F));
- 四则运算与拷贝
int dimension = 10;
MathMatrix dataMatrix = getZeroMatrix(dimension);
// 迭代矩阵所有的元素
dataMatrix.iterateElement(MathCalculator.SERIAL, (scalar) -> {
scalar.setValue(RandomUtility.randomFloat(10F));
});
MathMatrix copyMatrix = getZeroMatrix(dimension);
float sum = dataMatrix.getSum(false);
// 矩阵与矩阵拷贝操作
copyMatrix.copyMatrix(dataMatrix, false);
Assert.assertThat(copyMatrix.getSum(false), CoreMatchers.equalTo(sum));
// 矩阵与矩阵减法运算
dataMatrix.subtractMatrix(copyMatrix, false);
Assert.assertThat(dataMatrix.getSum(false), CoreMatchers.equalTo(0F));
// 矩阵与矩阵加法运算
dataMatrix.addMatrix(copyMatrix, false);
Assert.assertThat(dataMatrix.getSum(false), CoreMatchers.equalTo(sum));
// 矩阵与矩阵除法运算
dataMatrix.divideMatrix(copyMatrix, false);
Assert.assertThat(dataMatrix.getSum(false), CoreMatchers.equalTo(dataMatrix.getElementSize() + 0F));
// 矩阵与矩阵乘法运算
dataMatrix.multiplyMatrix(copyMatrix, false);
Assert.assertThat(dataMatrix.getSum(false), CoreMatchers.equalTo(sum));
- 点积与累积(支持串行与并行计算)
int dimension = 10;
MathMatrix leftMatrix = getRandomMatrix(dimension);
MathMatrix rightMatrix = getRandomMatrix(dimension);
MathMatrix dataMatrix = getZeroMatrix(dimension);
MathMatrix markMatrix = DenseMatrix.valueOf(dimension, dimension);
MathVector dataVector = dataMatrix.getRowVector(0);
MathVector markVector = markMatrix.getRowVector(0);
// 矩阵与矩阵的点积运算(串行)
markMatrix.dotProduct(leftMatrix, false, rightMatrix, true, MathCalculator.SERIAL);
dataMatrix.dotProduct(leftMatrix, false, rightMatrix, true, MathCalculator.SERIAL);
Assert.assertTrue(equalMatrix(dataMatrix, markMatrix));
// 矩阵与矩阵的点积运算(并行)
dataMatrix.dotProduct(leftMatrix, false, rightMatrix, true, MathCalculator.PARALLEL);
Assert.assertTrue(equalMatrix(dataMatrix, markMatrix));
MathVector leftVector = leftMatrix.getRowVector(RandomUtility.randomInteger(dimension));
MathVector rightVector = rightMatrix.getRowVector(RandomUtility.randomInteger(dimension));
// 向量与向量的点积运算(串行)
markMatrix.dotProduct(leftVector, rightVector, MathCalculator.SERIAL);
dataMatrix.dotProduct(leftVector, rightVector, MathCalculator.SERIAL);
Assert.assertTrue(equalMatrix(dataMatrix, markMatrix));
// 向量与向量的点积运算(并行)
dataMatrix.dotProduct(leftVector, rightVector, MathCalculator.PARALLEL);
Assert.assertTrue(equalMatrix(dataMatrix, markMatrix));
// 矩阵与向量的点积运算(串行)
markVector.dotProduct(leftMatrix, false, rightVector, MathCalculator.SERIAL);
dataVector.dotProduct(leftMatrix, false, rightVector, MathCalculator.SERIAL);
Assert.assertTrue(equalVector(dataVector, markVector));
// 矩阵与向量的点积运算(并行)
dataVector.dotProduct(leftMatrix, false, rightVector, MathCalculator.PARALLEL);
Assert.assertTrue(equalVector(dataVector, markVector));
// 向量与矩阵的点积运算(串行)
markVector.dotProduct(leftVector, rightMatrix, false, MathCalculator.SERIAL);
dataVector.dotProduct(leftVector, rightMatrix, false, MathCalculator.SERIAL);
Assert.assertTrue(equalVector(dataVector, markVector));
// 向量与矩阵的点积运算(并行)
dataVector.dotProduct(leftVector, rightMatrix, false, MathCalculator.PARALLEL);
Assert.assertTrue(equalVector(dataVector, markVector));