Skip to content

Commit a8ca5fc

Browse files
committed
Add ResizeVector, PadColumns and PadRows for mat32 and mat64 packages
1 parent 2eaa22a commit a8ca5fc

File tree

4 files changed

+415
-2
lines changed

4 files changed

+415
-2
lines changed

pkg/mat32/dense.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,71 @@ func (d *Dense) DoNonZero(fn func(i, j int, v Float)) {
913913
}
914914
}
915915

916+
// ResizeVector returns a resized copy of the given vector x.
917+
//
918+
// The input x MUST be a vector.
919+
//
920+
// If the new size is smaller than the input vector, the remaining tail
921+
// elements are removed. If it's bigger, the additional tail elements
922+
// will are set to zero.
923+
func (d *Dense) ResizeVector(newSize int) Matrix {
924+
xSize := d.Size()
925+
xData := d.data
926+
if newSize <= xSize {
927+
return NewVecDense(xData[:newSize])
928+
}
929+
930+
y := NewEmptyVecDense(newSize)
931+
yData := y.data
932+
copy(yData[:xSize], xData)
933+
copy(y.data, yData)
934+
935+
return y
936+
}
937+
938+
// PadColumns returns a copy of the given matrix x with n additional tail columns.
939+
// The additional elements are set to zero.
940+
func (d *Dense) PadColumns(n int) Matrix {
941+
rows := d.rows
942+
xCols := d.cols
943+
yCols := xCols + n
944+
y := NewEmptyDense(rows, yCols)
945+
946+
if rows == 0 || xCols == 0 {
947+
return y
948+
}
949+
950+
xData := d.data
951+
yData := y.data
952+
for r, xi, yi := 0, 0, 0; r < rows; r, xi, yi = r+1, xi+xCols, yi+yCols {
953+
copy(yData[yi:yi+xCols], xData[xi:xi+xCols])
954+
}
955+
copy(y.data, yData)
956+
957+
return y
958+
}
959+
960+
// PadRows returns a copy of the given matrix x with n additional tail rows.
961+
// The additional elements are set to zero.
962+
func (d *Dense) PadRows(n int) Matrix {
963+
cols := d.cols
964+
xRows := d.rows
965+
yRows := xRows + n
966+
967+
y := NewEmptyDense(yRows, cols)
968+
969+
if cols == 0 || xRows == 0 {
970+
return y
971+
}
972+
973+
xData := d.data
974+
yData := y.data
975+
copy(yData[:len(xData)], xData)
976+
copy(y.data, yData)
977+
978+
return y
979+
}
980+
916981
// String returns a string representation of the matrix data.
917982
func (d *Dense) String() string {
918983
return fmt.Sprintf("%v", d.data)

pkg/mat32/dense_test.go

Lines changed: 143 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55
package mat32
66

77
import (
8-
"github.com/stretchr/testify/assert"
8+
"fmt"
9+
"reflect"
910
"testing"
11+
12+
"github.com/stretchr/testify/assert"
1013
)
1114

1215
func TestDense_AddScalar(t *testing.T) {
@@ -1277,6 +1280,141 @@ func TestDense_DoNonZero(t *testing.T) {
12771280
})
12781281
}
12791282

1283+
func TestResizeVector(t *testing.T) {
1284+
t.Parallel()
1285+
1286+
cases := []struct {
1287+
x []Float
1288+
size int
1289+
expected []Float
1290+
}{
1291+
{[]Float{}, 0, []Float{}},
1292+
{[]Float{}, 1, []Float{0}},
1293+
{[]Float{}, 2, []Float{0, 0}},
1294+
1295+
{[]Float{1}, 0, []Float{}},
1296+
{[]Float{1}, 1, []Float{1}},
1297+
{[]Float{1}, 2, []Float{1, 0}},
1298+
{[]Float{1}, 3, []Float{1, 0, 0}},
1299+
1300+
{[]Float{1, 2}, 0, []Float{}},
1301+
{[]Float{1, 2}, 1, []Float{1}},
1302+
{[]Float{1, 2}, 2, []Float{1, 2}},
1303+
{[]Float{1, 2}, 3, []Float{1, 2, 0}},
1304+
{[]Float{1, 2}, 4, []Float{1, 2, 0, 0}},
1305+
}
1306+
1307+
for _, c := range cases {
1308+
testName := fmt.Sprintf("ResizeVector(%#v, %d) == %#v", c.x, c.size, c.expected)
1309+
t.Run(testName, func(t *testing.T) {
1310+
x := NewVecDense(c.x)
1311+
y := x.ResizeVector(c.size)
1312+
if !reflect.DeepEqual(y.Data(), c.expected) {
1313+
t.Fatalf("expected %#v, actual %#v", c.expected, y.Data())
1314+
}
1315+
if !reflect.DeepEqual(x.Data(), c.x) {
1316+
t.Fatalf("the input vector was modified: %#v", x.Data())
1317+
}
1318+
})
1319+
}
1320+
}
1321+
1322+
func TestPadColumns(t *testing.T) {
1323+
t.Parallel()
1324+
1325+
cases := []struct {
1326+
x Matrix
1327+
cols int
1328+
expected Matrix
1329+
}{
1330+
{NewDense(0, 0, []Float{}), 0, NewDense(0, 0, []Float{})},
1331+
{NewDense(0, 0, []Float{}), 1, NewDense(0, 1, []Float{})},
1332+
{NewDense(0, 0, []Float{}), 2, NewDense(0, 2, []Float{})},
1333+
1334+
{NewDense(0, 1, []Float{}), 0, NewDense(0, 1, []Float{})},
1335+
{NewDense(0, 1, []Float{}), 1, NewDense(0, 2, []Float{})},
1336+
{NewDense(0, 1, []Float{}), 2, NewDense(0, 3, []Float{})},
1337+
1338+
{NewDense(1, 0, []Float{}), 0, NewDense(1, 0, []Float{})},
1339+
{NewDense(1, 0, []Float{}), 1, NewDense(1, 1, []Float{0})},
1340+
{NewDense(1, 0, []Float{}), 2, NewDense(1, 2, []Float{0, 0})},
1341+
1342+
{NewDense(1, 1, []Float{1}), 0, NewDense(1, 1, []Float{1})},
1343+
{NewDense(1, 1, []Float{1}), 1, NewDense(1, 2, []Float{1, 0})},
1344+
{NewDense(1, 1, []Float{1}), 2, NewDense(1, 3, []Float{1, 0, 0})},
1345+
1346+
{NewDense(2, 1, []Float{1, 2}), 0, NewDense(2, 1, []Float{1, 2})},
1347+
{NewDense(2, 1, []Float{1, 2}), 1, NewDense(2, 2, []Float{1, 0, 2, 0})},
1348+
{NewDense(2, 1, []Float{1, 2}), 2, NewDense(2, 3, []Float{1, 0, 0, 2, 0, 0})},
1349+
1350+
{NewDense(2, 2, []Float{1, 2, 3, 4}), 0, NewDense(2, 2, []Float{1, 2, 3, 4})},
1351+
{NewDense(2, 2, []Float{1, 2, 3, 4}), 1, NewDense(2, 3, []Float{1, 2, 0, 3, 4, 0})},
1352+
{NewDense(2, 2, []Float{1, 2, 3, 4}), 2, NewDense(2, 4, []Float{1, 2, 0, 0, 3, 4, 0, 0})},
1353+
}
1354+
1355+
for _, c := range cases {
1356+
testName := fmt.Sprintf("PadColumns(%.0f, %d) == %.0f", c.x, c.cols, c.expected)
1357+
t.Run(testName, func(t *testing.T) {
1358+
xCopy := c.x.Clone()
1359+
y := c.x.(*Dense).PadColumns(c.cols)
1360+
if !matrixEqual(y, c.expected) {
1361+
t.Fatalf("expected %.0f, actual %.0f", c.expected, y)
1362+
}
1363+
if !matrixEqual(c.x, xCopy) {
1364+
t.Fatalf("the input vector was modified: %.0f", c.x)
1365+
}
1366+
})
1367+
}
1368+
}
1369+
1370+
func TestPadRows(t *testing.T) {
1371+
t.Parallel()
1372+
1373+
cases := []struct {
1374+
x Matrix
1375+
cols int
1376+
expected Matrix
1377+
}{
1378+
{NewDense(0, 0, []Float{}), 0, NewDense(0, 0, []Float{})},
1379+
{NewDense(0, 0, []Float{}), 1, NewDense(1, 0, []Float{})},
1380+
{NewDense(0, 0, []Float{}), 2, NewDense(2, 0, []Float{})},
1381+
1382+
{NewDense(1, 0, []Float{}), 0, NewDense(1, 0, []Float{})},
1383+
{NewDense(1, 0, []Float{}), 1, NewDense(2, 0, []Float{})},
1384+
{NewDense(1, 0, []Float{}), 2, NewDense(3, 0, []Float{})},
1385+
1386+
{NewDense(0, 1, []Float{}), 0, NewDense(0, 1, []Float{})},
1387+
{NewDense(0, 1, []Float{}), 1, NewDense(1, 1, []Float{0})},
1388+
{NewDense(0, 1, []Float{}), 2, NewDense(2, 1, []Float{0, 0})},
1389+
1390+
{NewDense(1, 1, []Float{1}), 0, NewDense(1, 1, []Float{1})},
1391+
{NewDense(1, 1, []Float{1}), 1, NewDense(2, 1, []Float{1, 0})},
1392+
{NewDense(1, 1, []Float{1}), 2, NewDense(3, 1, []Float{1, 0, 0})},
1393+
1394+
{NewDense(1, 2, []Float{1, 2}), 0, NewDense(1, 2, []Float{1, 2})},
1395+
{NewDense(1, 2, []Float{1, 2}), 1, NewDense(2, 2, []Float{1, 2, 0, 0})},
1396+
{NewDense(1, 2, []Float{1, 2}), 2, NewDense(3, 2, []Float{1, 2, 0, 0, 0, 0})},
1397+
1398+
{NewDense(2, 2, []Float{1, 2, 3, 4}), 0, NewDense(2, 2, []Float{1, 2, 3, 4})},
1399+
{NewDense(2, 2, []Float{1, 2, 3, 4}), 1, NewDense(3, 2, []Float{1, 2, 3, 4, 0, 0})},
1400+
{NewDense(2, 2, []Float{1, 2, 3, 4}), 2, NewDense(4, 2, []Float{1, 2, 3, 4, 0, 0, 0, 0})},
1401+
}
1402+
1403+
for _, c := range cases {
1404+
testName := fmt.Sprintf("PadRows(%.0f, %d) == %.0f", c.x, c.cols, c.expected)
1405+
t.Run(testName, func(t *testing.T) {
1406+
xCopy := c.x.Clone()
1407+
y := c.x.(*Dense).PadRows(c.cols)
1408+
if !matrixEqual(y, c.expected) {
1409+
t.Fatalf("expected %.0f, actual %.0f", c.expected, y)
1410+
}
1411+
if !matrixEqual(c.x, xCopy) {
1412+
t.Fatalf("the input vector was modified: %.0f", c.x)
1413+
}
1414+
})
1415+
}
1416+
}
1417+
12801418
func TestDense_String(t *testing.T) {
12811419
d := NewVecDense([]Float{1, 2, 3})
12821420
assert.Equal(t, "[1 2 3]", d.String())
@@ -1291,3 +1429,7 @@ func assertSliceEqualApprox(t *testing.T, expected, actual []Float) {
12911429
t.Helper()
12921430
assert.InDeltaSlice(t, expected, actual, 1.0e-04)
12931431
}
1432+
1433+
func matrixEqual(a, b Matrix) bool {
1434+
return SameDims(a, b) && reflect.DeepEqual(a.Data(), b.Data())
1435+
}

pkg/mat64/dense.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,71 @@ func (d *Dense) DoNonZero(fn func(i, j int, v Float)) {
915915
}
916916
}
917917

918+
// ResizeVector returns a resized copy of the given vector x.
919+
//
920+
// The input x MUST be a vector.
921+
//
922+
// If the new size is smaller than the input vector, the remaining tail
923+
// elements are removed. If it's bigger, the additional tail elements
924+
// will are set to zero.
925+
func (d *Dense) ResizeVector(newSize int) Matrix {
926+
xSize := d.Size()
927+
xData := d.data
928+
if newSize <= xSize {
929+
return NewVecDense(xData[:newSize])
930+
}
931+
932+
y := NewEmptyVecDense(newSize)
933+
yData := y.data
934+
copy(yData[:xSize], xData)
935+
copy(y.data, yData)
936+
937+
return y
938+
}
939+
940+
// PadColumns returns a copy of the given matrix x with n additional tail columns.
941+
// The additional elements are set to zero.
942+
func (d *Dense) PadColumns(n int) Matrix {
943+
rows := d.rows
944+
xCols := d.cols
945+
yCols := xCols + n
946+
y := NewEmptyDense(rows, yCols)
947+
948+
if rows == 0 || xCols == 0 {
949+
return y
950+
}
951+
952+
xData := d.data
953+
yData := y.data
954+
for r, xi, yi := 0, 0, 0; r < rows; r, xi, yi = r+1, xi+xCols, yi+yCols {
955+
copy(yData[yi:yi+xCols], xData[xi:xi+xCols])
956+
}
957+
copy(y.data, yData)
958+
959+
return y
960+
}
961+
962+
// PadRows returns a copy of the given matrix x with n additional tail rows.
963+
// The additional elements are set to zero.
964+
func (d *Dense) PadRows(n int) Matrix {
965+
cols := d.cols
966+
xRows := d.rows
967+
yRows := xRows + n
968+
969+
y := NewEmptyDense(yRows, cols)
970+
971+
if cols == 0 || xRows == 0 {
972+
return y
973+
}
974+
975+
xData := d.data
976+
yData := y.data
977+
copy(yData[:len(xData)], xData)
978+
copy(y.data, yData)
979+
980+
return y
981+
}
982+
918983
// String returns a string representation of the matrix data.
919984
func (d *Dense) String() string {
920985
return fmt.Sprintf("%v", d.data)

0 commit comments

Comments
 (0)