Skip to content

Commit

Permalink
correct bug wrt numpy format in shm arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Dec 30, 2024
1 parent 2e4dafd commit 2e3f763
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,14 @@ private SharedMemoryArrayLinux(String name, int size, String dtype, long[] shape
INSTANCE_C.close(shmFd);
throw new RuntimeException("mmap failed, errno: " + Native.getLastError());
}

if (!alreadyExists && this.isNumpyFormat) {
byte[] header = getNpyHeader(dtype, shape, this.isFortran);
long offset = 0;
for (byte b : header) {
this.pSharedMemory.setByte(offset ++, b);
}
}
}

/**
Expand Down Expand Up @@ -670,6 +678,35 @@ byte[] getNpyHeader(RandomAccessibleInterval<T> tensor, boolean fortranOrder) {
return total;
}

@SuppressWarnings("unchecked")
private static <T extends RealType<T> & NativeType<T>>
byte[] getNpyHeader(String dtype, long[] shape, boolean fortranOrder) {
String strHeader = "{'descr': '<";
strHeader += DecodeNumpy.getDataType((T) CommonUtils.getImgLib2DataType(dtype));
strHeader += "', 'fortran_order': " + (fortranOrder ? "True" : "False") + ", 'shape': (";
for (long ll : shape) strHeader += ll + ", ";
strHeader = strHeader.substring(0, strHeader.length() - 2);
strHeader += "), }" + System.lineSeparator();
byte[] bufInverse = strHeader.getBytes(StandardCharsets.UTF_8);
byte[] major = {1};
byte[] minor = {0};
byte[] len = new byte[2];
len[0] = (byte) (short) strHeader.length();
len[1] = (byte) (((short) strHeader.length()) >> 8);
int totalLen = DecodeNumpy.NUMPY_PREFIX.length + 2 + 2 + bufInverse.length;
byte[] total = new byte[totalLen];
int c = 0;
for (int i = 0; i < DecodeNumpy.NUMPY_PREFIX.length; i ++)
total[c ++] = DecodeNumpy.NUMPY_PREFIX[i];
total[c ++] = major[0];
total[c ++] = minor[0];
total[c ++] = len[0];
total[c ++] = len[1];
for (int i = 0; i < bufInverse.length; i ++)
total[c ++] = bufInverse[i];
return total;
}

/**
* {@inheritDoc}
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,16 @@ protected SharedMemoryArrayMacOS(String name, int size, String dtype, long[] sha
INSTANCE.close(shmFd);
throw new RuntimeException("mmap failed, errno: " + Native.getLastError());
}



if (!alreadyExists && this.isNumpyFormat) {
byte[] header = getNpyHeader(dtype, shape, this.isFortran);
long offset = 0;
for (byte b : header) {
this.pSharedMemory.setByte(offset ++, b);
}
}
}

/**
Expand Down Expand Up @@ -597,6 +607,35 @@ private void buildFloat64(RandomAccessibleInterval<DoubleType> tensor, boolean i
}
}

@SuppressWarnings("unchecked")
private static <T extends RealType<T> & NativeType<T>>
byte[] getNpyHeader(String dtype, long[] shape, boolean fortranOrder) {
String strHeader = "{'descr': '<";
strHeader += DecodeNumpy.getDataType((T) CommonUtils.getImgLib2DataType(dtype));
strHeader += "', 'fortran_order': " + (fortranOrder ? "True" : "False") + ", 'shape': (";
for (long ll : shape) strHeader += ll + ", ";
strHeader = strHeader.substring(0, strHeader.length() - 2);
strHeader += "), }" + System.lineSeparator();
byte[] bufInverse = strHeader.getBytes(StandardCharsets.UTF_8);
byte[] major = {1};
byte[] minor = {0};
byte[] len = new byte[2];
len[0] = (byte) (short) strHeader.length();
len[1] = (byte) (((short) strHeader.length()) >> 8);
int totalLen = DecodeNumpy.NUMPY_PREFIX.length + 2 + 2 + bufInverse.length;
byte[] total = new byte[totalLen];
int c = 0;
for (int i = 0; i < DecodeNumpy.NUMPY_PREFIX.length; i ++)
total[c ++] = DecodeNumpy.NUMPY_PREFIX[i];
total[c ++] = major[0];
total[c ++] = minor[0];
total[c ++] = len[0];
total[c ++] = len[1];
for (int i = 0; i < bufInverse.length; i ++)
total[c ++] = bufInverse[i];
return total;
}

private static <T extends RealType<T> & NativeType<T>>
byte[] getNpyHeader(RandomAccessibleInterval<T> tensor, boolean isFortran) {
String strHeader = "{'descr': '<";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ protected SharedMemoryArrayWin(String name, int size, String dtype, long[] shape
this.isFortran = isFortran;
int flag = WinNT.PAGE_READWRITE;
boolean write = true;
if (checkSHMExists(memoryName)) {
boolean alreadyExists = checkSHMExists(memoryName);
if (alreadyExists) {
long prevSize = getSHMSize(name);
if (prevSize != 0 && prevSize != DEFAULT_RESERVED_MEMORY && prevSize < size)
throw new FileAlreadyExistsException("Shared memory segment already exists with different dimensions, data type or format. "
Expand Down Expand Up @@ -271,6 +272,16 @@ protected SharedMemoryArrayWin(String name, int size, String dtype, long[] shape
+ "" + Kernel32.INSTANCE.GetLastError());
}
}



if (!alreadyExists && this.isNumpyFormat) {
byte[] header = getNpyHeader(dtype, shape, this.isFortran);
long offset = 0;
for (byte b : header) {
this.mappedPointer.setByte(offset, b);
}
}
}

private boolean checkSHMExists(String memoryName) {
Expand Down Expand Up @@ -486,6 +497,35 @@ private void addByteArray(byte[] arr) {
}
}

@SuppressWarnings("unchecked")
private static <T extends RealType<T> & NativeType<T>>
byte[] getNpyHeader(String dtype, long[] shape, boolean fortranOrder) {
String strHeader = "{'descr': '<";
strHeader += DecodeNumpy.getDataType((T) CommonUtils.getImgLib2DataType(dtype));
strHeader += "', 'fortran_order': " + (fortranOrder ? "True" : "False") + ", 'shape': (";
for (long ll : shape) strHeader += ll + ", ";
strHeader = strHeader.substring(0, strHeader.length() - 2);
strHeader += "), }" + System.lineSeparator();
byte[] bufInverse = strHeader.getBytes(StandardCharsets.UTF_8);
byte[] major = {1};
byte[] minor = {0};
byte[] len = new byte[2];
len[0] = (byte) (short) strHeader.length();
len[1] = (byte) (((short) strHeader.length()) >> 8);
int totalLen = DecodeNumpy.NUMPY_PREFIX.length + 2 + 2 + bufInverse.length;
byte[] total = new byte[totalLen];
int c = 0;
for (int i = 0; i < DecodeNumpy.NUMPY_PREFIX.length; i ++)
total[c ++] = DecodeNumpy.NUMPY_PREFIX[i];
total[c ++] = major[0];
total[c ++] = minor[0];
total[c ++] = len[0];
total[c ++] = len[1];
for (int i = 0; i < bufInverse.length; i ++)
total[c ++] = bufInverse[i];
return total;
}

private static <T extends RealType<T> & NativeType<T>>
byte[] getNpyHeader(RandomAccessibleInterval<T> tensor, boolean isFortran) {
String strHeader = "{'descr': '<";
Expand Down

0 comments on commit 2e3f763

Please sign in to comment.