From 27ff872c867cf3f9c7f4b86d2d89dbd0cf9ab65a Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Wed, 25 Sep 2024 00:22:30 +0200 Subject: [PATCH] fully adapt to persistent interprocessing --- .../javacpp/PytorchJavaCPPInterface.java | 54 ++--- .../javacpp/shm/NDArrayShmBuilder.java | 50 ---- .../pytorch/javacpp/shm/ShmBuilder.java | 207 ++++++++++++++++ .../pytorch/javacpp/shm/TensorBuilder.java | 224 ++++++++++++++++++ 4 files changed, 457 insertions(+), 78 deletions(-) delete mode 100644 src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/NDArrayShmBuilder.java create mode 100644 src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java create mode 100644 src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/TensorBuilder.java diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java index 1d648ab..43dcf40 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java @@ -20,17 +20,13 @@ */ package io.bioimage.modelrunner.pytorch.javacpp; -import java.io.BufferedReader; import java.io.File; import java.io.IOException; -import java.io.InputStreamReader; import java.io.UnsupportedEncodingException; -import java.lang.reflect.Type; import java.net.URISyntaxException; import java.net.URL; import java.net.URLDecoder; import java.nio.charset.StandardCharsets; -import java.nio.file.FileAlreadyExistsException; import java.security.ProtectionDomain; import java.util.ArrayList; import java.util.HashMap; @@ -45,7 +41,6 @@ import org.bytedeco.pytorch.TensorVector; import com.google.gson.Gson; -import com.google.gson.reflect.TypeToken; import io.bioimage.modelrunner.apposed.appose.Service; import io.bioimage.modelrunner.apposed.appose.Types; @@ -54,18 +49,17 @@ import io.bioimage.modelrunner.engine.DeepLearningEngineInterface; import io.bioimage.modelrunner.exceptions.LoadModelException; import io.bioimage.modelrunner.exceptions.RunModelException; -import io.bioimage.modelrunner.pytorch.javacpp.shm.NDArrayShmBuilder; import io.bioimage.modelrunner.pytorch.javacpp.tensor.ImgLib2Builder; import io.bioimage.modelrunner.pytorch.javacpp.tensor.JavaCPPTensorBuilder; +import io.bioimage.modelrunner.pytorch.javacpp.shm.ShmBuilder; +import io.bioimage.modelrunner.pytorch.javacpp.shm.TensorBuilder; import io.bioimage.modelrunner.system.PlatformDetection; import io.bioimage.modelrunner.tensor.Tensor; import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; import io.bioimage.modelrunner.utils.CommonUtils; import net.imglib2.RandomAccessibleInterval; -import net.imglib2.img.array.ArrayImgs; import net.imglib2.type.NativeType; import net.imglib2.type.numeric.RealType; -import net.imglib2.type.numeric.real.FloatType; import net.imglib2.util.Cast; import net.imglib2.util.Util; @@ -239,37 +233,41 @@ void run(List> inputTensors, List> outputTensors) throws Run protected void runFromShmas(List inputs, List outputs) throws IOException { - List inTensors = new ArrayList(); - int c = 0; + IValueVector inputsVector = new IValueVector(); for (String ee : inputs) { Map decoded = Types.decode(ee); SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY)); - TType inT = io.bioimage.modelrunner.tensorflow.v2.api030.shm.TensorBuilder.build(shma); + org.bytedeco.pytorch.Tensor inT = TensorBuilder.build(shma); + inputsVector.put(new IValue(inT)); if (PlatformDetection.isWindows()) shma.close(); - inTensors.add(inT); - String inputName = getModelInputName((String) decoded.get(NAME_KEY), c ++); - runner.feed(inputName, inT); } - - c = 0; - for (String ee : outputs) - runner = runner.fetch(getModelOutputName((String) Types.decode(ee).get(NAME_KEY), c ++)); - // Run runner - List resultPatchTensors = runner.run(); + // Run model + model.eval(); + IValue output = model.forward(inputsVector); + TensorVector outputTensorVector = null; + if (output.isTensorList()) { + outputTensorVector = output.toTensorVector(); + } else { + outputTensorVector = new TensorVector(); + outputTensorVector.put(output.toTensor()); + } // Fill the agnostic output tensors list with data from the inference result - c = 0; + int c = 0; for (String ee : outputs) { Map decoded = Types.decode(ee); - ShmBuilder.build((TType) resultPatchTensors.get(c ++), (String) decoded.get(MEM_NAME_KEY)); - } - // Close the remaining resources - for (TType tt : inTensors) { - tt.close(); + ShmBuilder.build(outputTensorVector.get(c ++), (String) decoded.get(MEM_NAME_KEY)); } - for (org.tensorflow.Tensor tt : resultPatchTensors) { - tt.close(); + outputTensorVector.close(); + outputTensorVector.deallocate(); + output.close(); + output.deallocate(); + for (int i = 0; i < inputsVector.size(); i ++) { + inputsVector.get(i).close(); + inputsVector.get(i).deallocate(); } + inputsVector.close(); + inputsVector.deallocate(); } /** diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/NDArrayShmBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/NDArrayShmBuilder.java deleted file mode 100644 index a8b44fe..0000000 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/NDArrayShmBuilder.java +++ /dev/null @@ -1,50 +0,0 @@ -/*- - * #%L - * This project complements the DL-model runner acting as the engine that works loading models - * and making inference with Java API for Pytorch. This project uses Pytorch with thanks to JavaCPP - * %% - * Copyright (C) 2023 Institut Pasteur and BioImage.IO developers. - * %% - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * #L% - */ - -package io.bioimage.modelrunner.pytorch.javacpp.shm; - -import io.bioimage.modelrunner.pytorch.javacpp.tensor.ImgLib2Builder; -import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; -import net.imglib2.util.Cast; - -import java.io.IOException; - -/** - * A helper class to build {@link SharedMemoryArray} from {@link org.bytedeco.pytorch.Tensor} - * - * @author Carlos Garcia Lopez de Haro - */ -public class NDArrayShmBuilder { - - - /** - * Build a shared memory segment from a Pytorch tensor - * @param tensor - * the Pytorch tensor created using JavaCPP - * @param memoryName - * the shared memory region name - * @return the {@link SharedMemoryArray} object created - * @throws IOException if there is any error creating the shared memory segment - */ - public static SharedMemoryArray buildShma(org.bytedeco.pytorch.Tensor tensor, String memoryName) throws IOException { - return SharedMemoryArray.createSHMAFromRAI(memoryName, Cast.unchecked(ImgLib2Builder.build(tensor)), false, true); - } -} diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java new file mode 100644 index 0000000..e5ac0c1 --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java @@ -0,0 +1,207 @@ +/*- + * #%L + * This project complements the DL-model runner acting as the engine that works loading models + * and making inference with Java 0.3.0 and newer API for Tensorflow 2. + * %% + * Copyright (C) 2022 - 2023 Institut Pasteur and BioImage.IO developers. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package io.bioimage.modelrunner.pytorch.javacpp.shm; + +import io.bioimage.modelrunner.system.PlatformDetection; +import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; +import io.bioimage.modelrunner.utils.CommonUtils; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; + +import org.bytedeco.pytorch.Tensor; + +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.type.numeric.integer.IntType; +import net.imglib2.type.numeric.integer.LongType; +import net.imglib2.type.numeric.integer.UnsignedByteType; +import net.imglib2.type.numeric.real.DoubleType; +import net.imglib2.type.numeric.real.FloatType; + +/** + * A {@link RandomAccessibleInterval} builder for TensorFlow {@link Tensor} objects. + * Build ImgLib2 objects (backend of {@link io.bioimage.modelrunner.tensor.Tensor}) + * from Tensorflow 2 {@link Tensor} + * + * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando + */ +public final class ShmBuilder +{ + /** + * Utility class. + */ + private ShmBuilder() + { + } + + /** + * Creates a {@link RandomAccessibleInterval} from a given {@link TType} tensor + * + * @param + * the possible ImgLib2 datatypes of the image + * @param tensor + * The {@link TType} tensor data is read from. + * @throws IllegalArgumentException If the {@link TType} tensor type is not supported. + * @throws IOException + */ + public static void build(Tensor tensor, String memoryName) throws IllegalArgumentException, IOException + { + if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Byte) + || tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Char)) { + buildFromTensorByte(tensor, memoryName); + } else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Int)) { + buildFromTensorInt(tensor, memoryName); + } else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Float)) { + buildFromTensorFloat(tensor, memoryName); + } else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Double)) { + buildFromTensorDouble(tensor, memoryName); + } else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Long)) { + buildFromTensorLong(tensor, memoryName); + } else { + throw new IllegalArgumentException("Unsupported tensor type: " + tensor.scalar_type()); + } + } + + /** + * Builds a {@link RandomAccessibleInterval} from a unsigned byte-typed {@link TUint8} tensor. + * + * @param tensor + * The {@link TUint8} tensor data is read from. + * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link UnsignedByteType}. + * @throws IOException + */ + private static void buildFromTensorByte(Tensor tensor, String memoryName) throws IOException + { + long[] arrayShape = tensor.shape(); + if (CommonUtils.int32Overflows(arrayShape, 1)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1); + long flatSize = 1; + for (long l : arrayShape) {flatSize *= l;} + byte[] flatArr = new byte[(int) flatSize]; + tensor.data_ptr_byte().get(flatArr); + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true); + shma.setBuffer(ByteBuffer.wrap(flatArr)); + if (PlatformDetection.isWindows()) shma.close(); + } + + /** + * Builds a {@link RandomAccessibleInterval} from a unsigned int32-typed {@link TInt32} tensor. + * + * @param tensor + * The {@link TInt32} tensor data is read from. + * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link IntType}. + * @throws IOException + */ + private static void buildFromTensorInt(Tensor tensor, String memoryName) throws IOException + { + long[] arrayShape = tensor.shape(); + if (CommonUtils.int32Overflows(arrayShape, 4)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4); + long flatSize = 1; + for (long l : arrayShape) {flatSize *= l;} + int[] flatArr = new int[(int) flatSize]; + tensor.data_ptr_int().get(flatArr); + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true); + ByteBuffer byteBuffer = ByteBuffer.allocate(flatArr.length * Integer.BYTES); + byteBuffer.asIntBuffer().put(flatArr); + shma.setBuffer(byteBuffer); + if (PlatformDetection.isWindows()) shma.close(); + } + + /** + * Builds a {@link RandomAccessibleInterval} from a unsigned float32-typed {@link TFloat32} tensor. + * + * @param tensor + * The {@link TFloat32} tensor data is read from. + * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link FloatType}. + * @throws IOException + */ + private static void buildFromTensorFloat(Tensor tensor, String memoryName) throws IOException + { + long[] arrayShape = tensor.shape(); + if (CommonUtils.int32Overflows(arrayShape, 4)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4); + long flatSize = 1; + for (long l : arrayShape) {flatSize *= l;} + float[] flatArr = new float[(int) flatSize]; + // TODO check what we get with tensor.data_ptr_byte().get(flatArr); + tensor.data_ptr_float().get(flatArr); + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true); + ByteBuffer byteBuffer = ByteBuffer.allocate(flatArr.length * Float.BYTES); + byteBuffer.asFloatBuffer().put(flatArr); + shma.setBuffer(byteBuffer); + if (PlatformDetection.isWindows()) shma.close(); + } + + /** + * Builds a {@link RandomAccessibleInterval} from a unsigned float64-typed {@link TFloat64} tensor. + * + * @param tensor + * The {@link TFloat64} tensor data is read from. + * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link DoubleType}. + * @throws IOException + */ + private static void buildFromTensorDouble(Tensor tensor, String memoryName) throws IOException + { + long[] arrayShape = tensor.shape(); + if (CommonUtils.int32Overflows(arrayShape, 8)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8); + long flatSize = 1; + for (long l : arrayShape) {flatSize *= l;} + double[] flatArr = new double[(int) flatSize]; + tensor.data_ptr_double().get(flatArr); + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true); + ByteBuffer byteBuffer = ByteBuffer.allocate(flatArr.length * Double.BYTES); + byteBuffer.asDoubleBuffer().put(flatArr); + shma.setBuffer(byteBuffer); + if (PlatformDetection.isWindows()) shma.close(); + } + + /** + * Builds a {@link RandomAccessibleInterval} from a unsigned int64-typed {@link TInt64} tensor. + * + * @param tensor + * The {@link TInt64} tensor data is read from. + * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link LongType}. + * @throws IOException + */ + private static void buildFromTensorLong(Tensor tensor, String memoryName) throws IOException + { + long[] arrayShape = tensor.shape(); + if (CommonUtils.int32Overflows(arrayShape, 8)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8); + long flatSize = 1; + for (long l : arrayShape) {flatSize *= l;} + long[] flatArr = new long[(int) flatSize]; + tensor.data_ptr_long().get(flatArr); + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true); + ByteBuffer byteBuffer = ByteBuffer.allocate(flatArr.length * Long.BYTES); + byteBuffer.asLongBuffer().put(flatArr); + shma.setBuffer(byteBuffer); + if (PlatformDetection.isWindows()) shma.close(); + } +} diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/TensorBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/TensorBuilder.java new file mode 100644 index 0000000..71fe77b --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/TensorBuilder.java @@ -0,0 +1,224 @@ +/*- + * #%L + * This project complements the DL-model runner acting as the engine that works loading models + * and making inference with Java 0.3.0 and newer API for Tensorflow 2. + * %% + * Copyright (C) 2022 - 2023 Institut Pasteur and BioImage.IO developers. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +package io.bioimage.modelrunner.pytorch.javacpp.shm; + +import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; +import io.bioimage.modelrunner.utils.CommonUtils; +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.img.Img; +import net.imglib2.type.numeric.integer.IntType; +import net.imglib2.type.numeric.integer.LongType; +import net.imglib2.type.numeric.integer.UnsignedByteType; +import net.imglib2.type.numeric.real.DoubleType; +import net.imglib2.type.numeric.real.FloatType; +import net.imglib2.util.Cast; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import java.util.Arrays; + +import org.bytedeco.pytorch.Tensor; + +/** + * A TensorFlow 2 {@link Tensor} builder from {@link Img} and + * {@link io.bioimage.modelrunner.tensor.Tensor} objects. + * + * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando + */ +public final class TensorBuilder { + + /** + * Utility class. + */ + private TensorBuilder() {} + + /** + * Creates {@link TType} instance with the same size and information as the + * given {@link RandomAccessibleInterval}. + * + * @param + * the ImgLib2 data types the {@link RandomAccessibleInterval} can be + * @param array + * the {@link RandomAccessibleInterval} that is going to be converted into + * a {@link TType} tensor + * @return a {@link TType} tensor + * @throws IllegalArgumentException if the type of the {@link RandomAccessibleInterval} + * is not supported + */ + public static org.bytedeco.pytorch.Tensor build(SharedMemoryArray array) throws IllegalArgumentException + { + // Create an Icy sequence of the same type of the tensor + if (array.getOriginalDataType().equals("int8")) { + return buildByte(Cast.unchecked(array)); + } + else if (array.getOriginalDataType().equals("int32")) { + return buildInt(Cast.unchecked(array)); + } + else if (array.getOriginalDataType().equals("float32")) { + return buildFloat(Cast.unchecked(array)); + } + else if (array.getOriginalDataType().equals("float64")) { + return buildDouble(Cast.unchecked(array)); + } + else if (array.getOriginalDataType().equals("int64")) { + return buildLong(Cast.unchecked(array)); + } + else { + throw new IllegalArgumentException("Unsupported tensor type: " + array.getOriginalDataType()); + } + } + + /** + * Creates a {@link TType} tensor of type {@link TUint8} from an + * {@link RandomAccessibleInterval} of type {@link UnsignedByteType} + * + * @param tensor + * The {@link RandomAccessibleInterval} to fill the tensor with. + * @return The {@link TType} tensor filled with the {@link RandomAccessibleInterval} data. + * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is + * not compatible + */ + public static org.bytedeco.pytorch.Tensor buildByte(SharedMemoryArray tensor) + throws IllegalArgumentException + { + long[] ogShape = tensor.getOriginalShape(); + if (CommonUtils.int32Overflows(ogShape, 1)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE); + if (!tensor.isNumpyFormat()) + throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); + ByteBuffer buff = tensor.getDataBufferNoHeader(); + Tensor ndarray = Tensor.create(buff.array(), ogShape); + return ndarray; + } + + /** + * Creates a {@link TInt32} tensor of type {@link TInt32} from an + * {@link RandomAccessibleInterval} of type {@link IntType} + * + * @param tensor + * The {@link RandomAccessibleInterval} to fill the tensor with. + * @return The {@link TInt32} tensor filled with the {@link RandomAccessibleInterval} data. + * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is + * not compatible + */ + public static Tensor buildInt(SharedMemoryArray tensor) + throws IllegalArgumentException + { + long[] ogShape = tensor.getOriginalShape(); + if (CommonUtils.int32Overflows(ogShape, 1)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE); + if (!tensor.isNumpyFormat()) + throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); + ByteBuffer buff = tensor.getDataBufferNoHeader(); + IntBuffer intBuff = buff.asIntBuffer(); + int[] intArray = new int[intBuff.capacity()]; + intBuff.get(intArray); + Tensor ndarray = Tensor.create(intBuff.array(), ogShape); + return ndarray; + } + + /** + * Creates a {@link TInt64} tensor of type {@link TInt64} from an + * {@link RandomAccessibleInterval} of type {@link LongType} + * + * @param tensor + * The {@link RandomAccessibleInterval} to fill the tensor with. + * @return The {@link TInt64} tensor filled with the {@link RandomAccessibleInterval} data. + * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is + * not compatible + */ + private static org.bytedeco.pytorch.Tensor buildLong(SharedMemoryArray tensor) + throws IllegalArgumentException + { + long[] ogShape = tensor.getOriginalShape(); + if (CommonUtils.int32Overflows(ogShape, 1)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE); + if (!tensor.isNumpyFormat()) + throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); + ByteBuffer buff = tensor.getDataBufferNoHeader(); + LongBuffer longBuff = buff.asLongBuffer(); + long[] longArray = new long[longBuff.capacity()]; + longBuff.get(longArray); + Tensor ndarray = Tensor.create(longBuff.array(), ogShape); + return ndarray; + } + + /** + * Creates a {@link TFloat32} tensor of type {@link TFloat32} from an + * {@link RandomAccessibleInterval} of type {@link FloatType} + * + * @param tensor + * The {@link RandomAccessibleInterval} to fill the tensor with. + * @return The {@link TFloat32} tensor filled with the {@link RandomAccessibleInterval} data. + * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is + * not compatible + */ + public static org.bytedeco.pytorch.Tensor buildFloat(SharedMemoryArray tensor) + throws IllegalArgumentException + { + long[] ogShape = tensor.getOriginalShape(); + if (CommonUtils.int32Overflows(ogShape, 1)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE); + if (!tensor.isNumpyFormat()) + throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); + ByteBuffer buff = tensor.getDataBufferNoHeader(); + FloatBuffer floatBuff = buff.asFloatBuffer(); + float[] floatArray = new float[floatBuff.capacity()]; + floatBuff.get(floatArray); + Tensor ndarray = Tensor.create(floatBuff.array(), ogShape); + return ndarray; + } + + /** + * Creates a {@link TFloat64} tensor of type {@link TFloat64} from an + * {@link RandomAccessibleInterval} of type {@link DoubleType} + * + * @param tensor + * The {@link RandomAccessibleInterval} to fill the tensor with. + * @return The {@link TFloat64} tensor filled with the {@link RandomAccessibleInterval} data. + * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is + * not compatible + */ + private static org.bytedeco.pytorch.Tensor buildDouble(SharedMemoryArray tensor) + throws IllegalArgumentException + { + long[] ogShape = tensor.getOriginalShape(); + if (CommonUtils.int32Overflows(ogShape, 1)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE); + if (!tensor.isNumpyFormat()) + throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); + ByteBuffer buff = tensor.getDataBufferNoHeader(); + DoubleBuffer doubleBuff = buff.asDoubleBuffer(); + double[] doubleArray = new double[doubleBuff.capacity()]; + doubleBuff.get(doubleArray); + Tensor ndarray = Tensor.create(doubleBuff.array(), ogShape); + return ndarray; + } +}