Skip to content

Commit

Permalink
keep improving the robsutness of shm creation
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Mar 13, 2024
1 parent a626bf3 commit 3ed15ff
Showing 1 changed file with 50 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,23 @@
*/
public class SharedMemoryArrayLinux implements SharedMemoryArray
{
/**
* Instance of the LibRT JNI containing the methods to interact with the Shared memory segments
*/
private static final LibRt INSTANCE_RT = LibRt.INSTANCE;
/**
* Instance of the CLibrary JNI containing the methods to interact with the Shared memory segments
*/
private static final LibRt INSTANCE = LibRt.INSTANCE;
private static final CLibrary INSTANCE_C = CLibrary.INSTANCE;
/**
* Depending on the computer, some might work with LibRT or LibC to create SHM segments.
* Thus if true use librt if false, use libc instance
*/
private boolean useLibRT = true;
/**
* File descriptor value of the shared memory segment
*/
private final int shmFd;
private int shmFd;
/**
* Pointer referencing the shared memory byte array
*/
Expand Down Expand Up @@ -148,20 +157,33 @@ protected SharedMemoryArrayLinux(String name, int size, String dtype, long[] sha
this.originalDims = shape;
this.size = size;
this.memoryName = name;

shmFd = INSTANCE.shm_open(this.memoryName, O_RDWR | O_CREAT, 0700);
try {
shmFd = INSTANCE_RT.shm_open(this.memoryName, O_RDWR | O_CREAT, 0700);
} catch (Exception ex) {
this.useLibRT = false;
shmFd = INSTANCE_C.shm_open(this.memoryName, O_RDWR | O_CREAT, 0700);
}
if (shmFd < 0) {
throw new RuntimeException("shm_open failed, errno: " + Native.getLastError());
}

if (INSTANCE.ftruncate(shmFd, this.size) == -1) {
INSTANCE.close(shmFd);
if (this.useLibRT && INSTANCE_RT.ftruncate(shmFd, this.size) == -1) {
INSTANCE_RT.close(shmFd);
throw new RuntimeException("ftruncate failed, errno: " + Native.getLastError());
} else if (!this.useLibRT && INSTANCE_C.ftruncate(shmFd, this.size) == -1) {
INSTANCE_C.close(shmFd);
throw new RuntimeException("ftruncate failed, errno: " + Native.getLastError());
}

pSharedMemory = INSTANCE.mmap(Pointer.NULL, this.size, PROT_READ | PROT_WRITE, MAP_SHARED, shmFd, 0);
if (pSharedMemory == Pointer.NULL) {
INSTANCE.close(shmFd);
if (this.useLibRT)
pSharedMemory = INSTANCE_RT.mmap(Pointer.NULL, this.size, PROT_READ | PROT_WRITE, MAP_SHARED, shmFd, 0);
else
pSharedMemory = INSTANCE_C.mmap(Pointer.NULL, this.size, PROT_READ | PROT_WRITE, MAP_SHARED, shmFd, 0);

if (this.useLibRT && pSharedMemory == Pointer.NULL) {
INSTANCE_RT.close(shmFd);
throw new RuntimeException("mmap failed, errno: " + Native.getLastError());
} else if (!this.useLibRT && pSharedMemory == Pointer.NULL) {
INSTANCE_C.close(shmFd);
throw new RuntimeException("mmap failed, errno: " + Native.getLastError());
}
}
Expand Down Expand Up @@ -459,25 +481,32 @@ private void buildFloat64(RandomAccessibleInterval<DoubleType> tensor)
*/
public void close() {
if (this.unlinked) return;
int checkhmFd;
if (this.useLibRT) checkhmFd = INSTANCE_RT.shm_open(this.memoryName, O_RDONLY, 0700);
else checkhmFd = INSTANCE_C.shm_open(this.memoryName, O_RDONLY, 0700);

int checkhmFd = INSTANCE.shm_open(this.memoryName, O_RDONLY, 0700);
if (checkhmFd < 0) {
unlinked = true;
return;
}

// Unmap the shared memory
if (this.pSharedMemory != Pointer.NULL && INSTANCE.munmap(this.pSharedMemory, size) == -1) {
if (this.pSharedMemory != Pointer.NULL && this.useLibRT && INSTANCE_RT.munmap(this.pSharedMemory, size) == -1) {
throw new RuntimeException("munmap failed. Errno: " + Native.getLastError());
} else if (this.pSharedMemory != Pointer.NULL && !this.useLibRT && INSTANCE_C.munmap(this.pSharedMemory, size) == -1) {
throw new RuntimeException("munmap failed. Errno: " + Native.getLastError());
}

// Close the file descriptor
if (INSTANCE.close(this.shmFd) == -1) {
if (this.useLibRT && INSTANCE_RT.close(this.shmFd) == -1) {
throw new RuntimeException("close failed. Errno: " + Native.getLastError());
} else if (!this.useLibRT && INSTANCE_C.close(this.shmFd) == -1) {
throw new RuntimeException("close failed. Errno: " + Native.getLastError());
}

// Unlink the shared memory object
INSTANCE.shm_unlink(this.memoryName);
if (this.useLibRT) INSTANCE_RT.shm_unlink(memoryName);
else INSTANCE_C.shm_unlink(memoryName);
unlinked = true;
}

Expand Down Expand Up @@ -628,12 +657,19 @@ RandomAccessibleInterval<T> createImgLib2RaiFromSharedMemoryBlock(String memoryN
}
try {
RandomAccessibleInterval<T> rai = buildFromSharedMemoryBlock(pSharedMemory, shape, isFortran, dataType);
/** TODO decide
/** TODO decide
/** TODO decide
/** TODO decide
/** TODO decide
/** TODO decide
if (pSharedMemory != Pointer.NULL) {
INSTANCE.munmap(pSharedMemory, size);
}
if (shmFd >= 0) {
INSTANCE.close(shmFd);
}
*/
return rai;
} catch (Exception ex) {
if (pSharedMemory != Pointer.NULL) {
Expand Down

0 comments on commit 3ed15ff

Please sign in to comment.