Skip to content

Commit

Permalink
Handle of shape changes in dynamic datasets for max shape and chunks
Browse files Browse the repository at this point in the history
Add unit test
  • Loading branch information
PeterC-DLS committed Jan 21, 2025
1 parent bde924f commit d17a0a5
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

package org.eclipse.january.dataset;

import static org.junit.Assert.assertArrayEquals;

import java.util.Arrays;

import org.eclipse.january.asserts.TestUtils;
Expand All @@ -27,7 +29,7 @@ public void dataChangePerformed(DataEvent evt) {
}
}

IDynamicDataset createDynamic() {
private static IDynamicDataset createDynamic() {
IDynamicDataset lazy = new LazyDynamicDataset(null, "test", 1, IntegerDataset.class, new int[] {0,4}, new int[] {IDynamicDataset.UNLIMITED, 4});
return lazy;
}
Expand Down Expand Up @@ -82,4 +84,36 @@ public void testShapeChecker() {
}
Assert.assertEquals(repeat, counter.count);
}

@Test
public void testShapeChanges() {
IDynamicDataset lazy = new LazyDynamicDataset(null, "test", 1, IntegerDataset.class, new int[] {3, 4}, new int[] {IDynamicDataset.UNLIMITED, 4});
lazy.setChunking(2, 4);

IDynamicDataset t = (IDynamicDataset) lazy.getSliceView(new Slice(2));
assertArrayEquals(new int[] {2, 4}, t.getShape());
assertArrayEquals(new int[] {IDynamicDataset.UNLIMITED, 4}, t.getMaxShape());
assertArrayEquals(new int[] {2, 4}, t.getChunking());

t = (IDynamicDataset) lazy.getTransposedView(1, 0);
assertArrayEquals(new int[] {4, 3}, t.getShape());
assertArrayEquals(new int[] {4, IDynamicDataset.UNLIMITED}, t.getMaxShape());
assertArrayEquals(new int[] {4, 2}, t.getChunking());

t = (IDynamicDataset) lazy.getSliceView();
t.setShape(1, 3, 4);
assertArrayEquals(new int[] {1, 3, 4}, t.getShape());
assertArrayEquals(new int[] {1, IDynamicDataset.UNLIMITED, 4}, t.getMaxShape());
assertArrayEquals(new int[] {1, 2, 4}, t.getChunking());

t = (IDynamicDataset) t.getTransposedView(2, 1, 0);
assertArrayEquals(new int[] {4, 3, 1}, t.getShape());
assertArrayEquals(new int[] {4, IDynamicDataset.UNLIMITED, 1}, t.getMaxShape());
assertArrayEquals(new int[] {4, 2, 1}, t.getChunking());

t = (IDynamicDataset) t.getSliceView(null, new Slice(2), null);
assertArrayEquals(new int[] {4, 2, 1}, t.getShape());
assertArrayEquals(new int[] {4, IDynamicDataset.UNLIMITED, 1}, t.getMaxShape());
assertArrayEquals(new int[] {4, 2, 1}, t.getChunking());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public void testDynamicAxesMetadata() throws Exception {
slice.setSlice(0, shape[0]-1, shape[0], 1);
slice.setSlice(1, shape[1]-1, shape[1], 1);
IDataset s = dataset.getSlice(slice);
Assert.assertEquals((long) max, s.max().longValue());
Assert.assertEquals(max, s.max().longValue());
max+=400;

AxesMetadata axm = s.getFirstMetadata(AxesMetadata.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,7 @@ public LazyDataset getSliceView(Slice... slice) {
return internalGetSliceView(new SliceND(shape, slice));
}

/**
* @param nShape
*/
private void setShapeInternal(int... nShape) {
void setShapeInternal(int... nShape) {
// work out transposed (sliced) shape (instead of removing padding from current shape)
if (size != 0) {
int[] pShape = calcTransposed(map, sShape == null ? oShape : sShape);
Expand Down Expand Up @@ -469,6 +466,11 @@ public LazyDataset getTransposedView(final int... axes) {
return view;
}

internalTransposeView(view, naxes, axes);
return view;
}

void internalTransposeView(LazyDataset view, int[] naxes, final int... axes) {
view.shape = calcTransposed(naxes, shape);
if (view.size != 0 && padding != null) { // work out transpose by reverting effect of padding
int or = oShape.length;
Expand Down Expand Up @@ -550,7 +552,6 @@ public LazyDataset getTransposedView(final int... axes) {
}
view.storeMetadata(metadata, Transposable.class);
view.transposeMetadata(axes);
return view;
}

private static int find(int[] map, int m, int off) {
Expand All @@ -562,7 +563,7 @@ private static int find(int[] map, int m, int off) {
return -1;
}

private static int[] calcTransposed(int[] map, int[] values) {
static int[] calcTransposed(int[] map, int[] values) {
if (values == null) {
return null;
}
Expand Down Expand Up @@ -615,8 +616,8 @@ protected final SliceND calcTrueSlice(SliceND slice) {
int r = oShape.length;
if (padding == null) {
nshape = slice.getShape();
nstart = slice.getStart();
nstep = slice.getStep();
nstart = slice.getStart().clone();
nstep = slice.getStep().clone();
} else {
final int[] lshape = slice.getShape();
final int[] lstart = slice.getStart();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,11 +490,9 @@ public ILazyDataset run(ILazyDataset lz) {
nslice = slice.clone();
for (int i = 0; i < rank; i++) {
int s = shape[i];
if (s >= oShape[i]) {
continue;
} else if (s == 1) {
if (s == 1) {
nslice.setSlice(i, 0, 1, 1);
} else {
} else if (s < oShape[i]) {
throw new IllegalArgumentException("Sliceable dataset has non-unit dimension less than host!");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
public class LazyDynamicDataset extends LazyDataset implements IDynamicDataset {
private static final long serialVersionUID = -6296506563932840938L;

protected int[] oMaxShape;
protected int[] maxShape;
protected int[] chunks;

Expand Down Expand Up @@ -124,6 +125,7 @@ public LazyDynamicDataset(ILazyLoader loader, String name, int elements, Class<?
} else {
this.maxShape = maxShape.clone();
}
this.oMaxShape = this.maxShape;
this.chunks = chunks == null ? null : chunks.clone();

this.eventDelegate = new DataListenerDelegate();
Expand All @@ -137,6 +139,7 @@ protected LazyDynamicDataset(LazyDynamicDataset other) {
super(other);

maxShape = other.maxShape;
oMaxShape = other.oMaxShape;
chunks = other.chunks;
eventDelegate = other.eventDelegate;
checker = other.checker;
Expand All @@ -149,6 +152,7 @@ public int hashCode() {
int result = super.hashCode();
result = prime * result + ((checker == null) ? 0 : checker.hashCode());
result = prime * result + ((checkingThread == null) ? 0 : checkingThread.hashCode());
result = prime * result + Arrays.hashCode(oMaxShape);
result = prime * result + Arrays.hashCode(maxShape);
result = prime * result + Arrays.hashCode(chunks);
return result;
Expand All @@ -164,6 +168,9 @@ public boolean equals(Object obj) {
}

LazyDynamicDataset other = (LazyDynamicDataset) obj;
if (!Arrays.equals(oMaxShape, other.oMaxShape)) {
return false;
}
if (!Arrays.equals(maxShape, other.maxShape)) {
return false;
}
Expand Down Expand Up @@ -232,7 +239,7 @@ public boolean resize(int... newShape) {
if (maxShape != null) {
for (int i = 0; i < rank; i++) {
int m = maxShape[i];
if (m != -1 && newShape[i] > m) {
if (m != UNLIMITED && newShape[i] > m) {
throw new IllegalArgumentException("A dimension of new shape must not exceed maximum shape");
}
}
Expand All @@ -257,6 +264,7 @@ public int[] getMaxShape() {
@Override
public void setMaxShape(int... maxShape) {
this.maxShape = maxShape == null ? shape.clone() : maxShape.clone();
this.oMaxShape = this.maxShape;

if (this.maxShape.length > oShape.length) {
oShape = prependShapeWithOnes(this.maxShape.length, oShape);
Expand All @@ -267,6 +275,36 @@ public void setMaxShape(int... maxShape) {
}
}

@Override
void setShapeInternal(int... nShape) {
super.setShapeInternal(nShape);
int r = shape.length;
if (maxShape != null) {
maxShape = ShapeUtils.padShape(padding, r, oMaxShape);
}
if (chunks != null) {
chunks = ShapeUtils.padShape(padding, r, chunks);
}
}

/**
* @since 2.3
*/
@Override
public LazyDynamicDataset getTransposedView(int... axes) {
LazyDynamicDataset view = clone();

int[] naxes = checkPermutatedAxes(shape, axes);
if (naxes == null) {
return view;
}

internalTransposeView(view, naxes, axes);
view.maxShape = calcTransposed(naxes, maxShape);
view.chunks = calcTransposed(naxes, chunks);
return view;
}

@Override
public int[] getChunking() {
return chunks;
Expand Down Expand Up @@ -296,7 +334,7 @@ protected void checkSliceND(SliceND slice) {

@Override
protected SliceND createSlice(int[] nstart, int[] nstop, int[] nstep) {
return SliceND.createSlice(oShape, maxShape, nstart, nstop, nstep);
return SliceND.createSlice(oShape, oMaxShape, nstart, nstop, nstep);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ public LazyWriteableDataset(String name, Class<?> eClass, int[] shape, int[] max
protected LazyWriteableDataset(LazyWriteableDataset other) {
super(other);

chunks = other.chunks;
saver = other.saver;
fillValue = other.fillValue;
writeAsync = other.writeAsync;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@ public static int[] refreshDynamicAxesMetadata(List<AxesMetadata> axm, int[] sha
if (axm == null) return maxShape;

for (AxesMetadata a : axm) {
AxesMetadata ai = a;
int[] s = ai.refresh(shape);
int[] s = a.refresh(shape);
for (int i = 0; i < s.length; i++) {
if (maxShape[i] > s[i]) maxShape[i] = s[i];
}
}

return maxShape;
return maxShape;
}

}

0 comments on commit d17a0a5

Please sign in to comment.