Skip to content

Commit

Permalink
correct gigabug creating tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Nov 23, 2024
1 parent d951617 commit f33dc43
Showing 1 changed file with 10 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@

import org.bytedeco.pytorch.Tensor;

import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.integer.LongType;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;

/**
* A utility class that converts {@link Tensor}s into {@link SharedMemoryArray}s for
Expand Down Expand Up @@ -79,7 +83,7 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws
if (CommonUtils.int32Overflows(arrayShape, 1))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new ByteType(), false, true);
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
if (PlatformDetection.isWindows()) shma.close();
}
Expand All @@ -90,7 +94,7 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws
if (CommonUtils.int32Overflows(arrayShape, 4))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
if (PlatformDetection.isWindows()) shma.close();
}
Expand All @@ -101,7 +105,7 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throw
if (CommonUtils.int32Overflows(arrayShape, 4))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
if (PlatformDetection.isWindows()) shma.close();
}
Expand All @@ -112,7 +116,7 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) thro
if (CommonUtils.int32Overflows(arrayShape, 8))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
if (PlatformDetection.isWindows()) shma.close();
}
Expand All @@ -123,7 +127,7 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws
if (CommonUtils.int32Overflows(arrayShape, 8))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
if (PlatformDetection.isWindows()) shma.close();
}
Expand Down

0 comments on commit f33dc43

Please sign in to comment.