diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface.java index 008b989..d9c5f01 100644 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface.java +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface.java @@ -37,10 +37,9 @@ import io.bioimage.modelrunner.system.PlatformDetection; import io.bioimage.modelrunner.tensor.Tensor; import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; +import io.bioimage.modelrunner.tensorflow.v2.api020.shm.ShmBuilder; import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.ImgLib2Builder; import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.TensorBuilder; -import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.mappedbuffer.ImgLib2ToMappedBuffer; -import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.mappedbuffer.MappedBufferToImgLib2; import io.bioimage.modelrunner.utils.CommonUtils; import io.bioimage.modelrunner.utils.Constants; import io.bioimage.modelrunner.utils.ZipUtils; @@ -50,24 +49,14 @@ import net.imglib2.util.Cast; import net.imglib2.util.Util; -import java.io.BufferedReader; import java.io.File; import java.io.IOException; -import java.io.InputStreamReader; -import java.io.RandomAccessFile; import java.io.UnsupportedEncodingException; import java.net.URISyntaxException; import java.net.URL; import java.net.URLDecoder; -import java.nio.ByteBuffer; -import java.nio.MappedByteBuffer; -import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Paths; import java.security.ProtectionDomain; -import java.time.LocalDateTime; -import java.time.format.DateTimeFormatter; import java.util.ArrayList; import java.util.HashMap; import java.util.LinkedHashMap; @@ -81,7 +70,6 @@ import org.tensorflow.proto.framework.MetaGraphDef; import org.tensorflow.proto.framework.SignatureDef; import org.tensorflow.proto.framework.TensorInfo; -import org.tensorflow.types.family.TType; /** * Class to that communicates with the dl-model runner, see @@ -290,28 +278,28 @@ void run(List> inputTensors, List> outputTensors) Session session = model.session(); Session.Runner runner = session.runner(); List inputListNames = new ArrayList(); - List inTensors = new ArrayList(); + List> inTensors = new ArrayList>(); int c = 0; - for (Tensor tt : inputTensors) { + for (Tensor tt : inputTensors) { inputListNames.add(tt.getName()); - TType inT = TensorBuilder.build(tt); + org.tensorflow.Tensor inT = TensorBuilder.build(tt); inTensors.add(inT); String inputName = getModelInputName(tt.getName(), c ++); runner.feed(inputName, inT); } c = 0; - for (Tensor tt : outputTensors) + for (Tensor tt : outputTensors) runner = runner.fetch(getModelOutputName(tt.getName(), c ++)); // Run runner - List resultPatchTensors = runner.run(); + List> resultPatchTensors = runner.run(); // Fill the agnostic output tensors list with data from the inference result fillOutputTensors(resultPatchTensors, outputTensors); // Close the remaining resources - for (TType tt : inTensors) { + for (org.tensorflow.Tensor tt : inTensors) { tt.close(); } - for (org.tensorflow.Tensor tt : resultPatchTensors) { + for (org.tensorflow.Tensor tt : resultPatchTensors) { tt.close(); } } @@ -320,12 +308,12 @@ protected void runFromShmas(List inputs, List outputs) throws IO Session session = model.session(); Session.Runner runner = session.runner(); - List inTensors = new ArrayList(); + List> inTensors = new ArrayList>(); int c = 0; for (String ee : inputs) { Map decoded = Types.decode(ee); SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY)); - TType inT = io.bioimage.modelrunner.tensorflow.v2.api030.shm.TensorBuilder.build(shma); + org.tensorflow.Tensor inT = io.bioimage.modelrunner.tensorflow.v2.api020.shm.TensorBuilder.build(shma); if (PlatformDetection.isWindows()) shma.close(); inTensors.add(inT); String inputName = getModelInputName((String) decoded.get(NAME_KEY), c ++); @@ -336,19 +324,19 @@ protected void runFromShmas(List inputs, List outputs) throws IO for (String ee : outputs) runner = runner.fetch(getModelOutputName((String) Types.decode(ee).get(NAME_KEY), c ++)); // Run runner - List resultPatchTensors = runner.run(); + List> resultPatchTensors = runner.run(); // Fill the agnostic output tensors list with data from the inference result c = 0; for (String ee : outputs) { Map decoded = Types.decode(ee); - ShmBuilder.build((TType) resultPatchTensors.get(c ++), (String) decoded.get(MEM_NAME_KEY)); + ShmBuilder.build((org.tensorflow.Tensor) resultPatchTensors.get(c ++), (String) decoded.get(MEM_NAME_KEY)); } // Close the remaining resources - for (TType tt : inTensors) { + for (org.tensorflow.Tensor tt : inTensors) { tt.close(); } - for (org.tensorflow.Tensor tt : resultPatchTensors) { + for (org.tensorflow.Tensor tt : resultPatchTensors) { tt.close(); } } diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface2_old.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface2_old.java deleted file mode 100644 index 64adaa7..0000000 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface2_old.java +++ /dev/null @@ -1,869 +0,0 @@ -/*- - * #%L - * This project complements the DL-model runner acting as the engine that works loading models - * and making inference with Java 0.2.0 API for Tensorflow 2. - * %% - * Copyright (C) 2022 - 2023 Institut Pasteur and BioImage.IO developers. - * %% - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * #L% - */ -package io.bioimage.modelrunner.tensorflow.v2.api020; - -import com.google.protobuf.InvalidProtocolBufferException; - -import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor; -import io.bioimage.modelrunner.bioimageio.download.DownloadModel; -import io.bioimage.modelrunner.engine.DeepLearningEngineInterface; -import io.bioimage.modelrunner.engine.EngineInfo; -import io.bioimage.modelrunner.exceptions.LoadModelException; -import io.bioimage.modelrunner.exceptions.RunModelException; -import io.bioimage.modelrunner.system.PlatformDetection; -import io.bioimage.modelrunner.tensor.Tensor; -import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.ImgLib2Builder; -import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.TensorBuilder; -import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.mappedbuffer.ImgLib2ToMappedBuffer; -import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.mappedbuffer.MappedBufferToImgLib2; -import io.bioimage.modelrunner.utils.Constants; -import io.bioimage.modelrunner.utils.ZipUtils; -import net.imglib2.type.NativeType; -import net.imglib2.type.numeric.RealType; - -import java.io.BufferedReader; -import java.io.File; -import java.io.IOException; -import java.io.InputStreamReader; -import java.io.RandomAccessFile; -import java.io.UnsupportedEncodingException; -import java.net.URISyntaxException; -import java.net.URL; -import java.net.URLDecoder; -import java.nio.ByteBuffer; -import java.nio.MappedByteBuffer; -import java.nio.channels.FileChannel; -import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.security.ProtectionDomain; -import java.time.LocalDateTime; -import java.time.format.DateTimeFormatter; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.LinkedList; -import java.util.List; -import java.util.stream.Collectors; - -import org.tensorflow.SavedModelBundle; -import org.tensorflow.Session; -import org.tensorflow.proto.framework.MetaGraphDef; -import org.tensorflow.proto.framework.SignatureDef; -import org.tensorflow.proto.framework.TensorInfo; - -/** - * Class to that communicates with the dl-model runner, see - * @see dlmodelrunner - * to execute Tensorflow 2 models. This class is compatible with TF2 Java API 0.2.0. - * This class implements the interface {@link DeepLearningEngineInterface} to get the - * agnostic {@link io.bioimage.modelrunner.tensor.Tensor}, convert them into - * {@link org.tensorflow.Tensor}, execute a Tensorflow 2 Deep Learning model on them and - * convert the results back to {@link io.bioimage.modelrunner.tensor.Tensor} to send them - * to the main program in an agnostic manner. - * - * {@link ImgLib2Builder}. Creates ImgLib2 images for the backend - * of {@link io.bioimage.modelrunner.tensor.Tensor} from {@link org.tensorflow.Tensor} - * {@link TensorBuilder}. Converts {@link io.bioimage.modelrunner.tensor.Tensor} into {@link org.tensorflow.Tensor} - * - * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando - */ -public class Tensorflow2Interface2_old implements DeepLearningEngineInterface { - - private static final String[] MODEL_TAGS = { "serve", "inference", "train", - "eval", "gpu", "tpu" }; - - private static final String[] TF_MODEL_TAGS = { - "tf.saved_model.tag_constants.SERVING", - "tf.saved_model.tag_constants.INFERENCE", - "tf.saved_model.tag_constants.TRAINING", - "tf.saved_model.tag_constants.EVAL", "tf.saved_model.tag_constants.GPU", - "tf.saved_model.tag_constants.TPU" }; - - private static final String[] SIGNATURE_CONSTANTS = { "serving_default", - "inputs", "tensorflow/serving/classify", "classes", "scores", "inputs", - "tensorflow/serving/predict", "outputs", "inputs", - "tensorflow/serving/regress", "outputs", "train", "eval", - "tensorflow/supervised/training", "tensorflow/supervised/eval" }; - - private static final String[] TF_SIGNATURE_CONSTANTS = { - "tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY", - "tf.saved_model.signature_constants.CLASSIFY_INPUTS", - "tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME", - "tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES", - "tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES", - "tf.saved_model.signature_constants.PREDICT_INPUTS", - "tf.saved_model.signature_constants.PREDICT_METHOD_NAME", - "tf.saved_model.signature_constants.PREDICT_OUTPUTS", - "tf.saved_model.signature_constants.REGRESS_INPUTS", - "tf.saved_model.signature_constants.REGRESS_METHOD_NAME", - "tf.saved_model.signature_constants.REGRESS_OUTPUTS", - "tf.saved_model.signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY", - "tf.saved_model.signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY", - "tf.saved_model.signature_constants.SUPERVISED_TRAIN_METHOD_NAME", - "tf.saved_model.signature_constants.SUPERVISED_EVAL_METHOD_NAME" }; - - /** - * Idetifier for the files that contain the data of the inputs - */ - final private static String INPUT_FILE_TERMINATION = "_model_input"; - - /** - * Idetifier for the files that contain the data of the outputs - */ - final private static String OUTPUT_FILE_TERMINATION = "_model_output"; - /** - * Key for the inputs in the map that retrieves the file names for interprocess communication - */ - final private static String INPUTS_MAP_KEY = "inputs"; - /** - * Key for the outputs in the map that retrieves the file names for interprocess communication - */ - final private static String OUTPUTS_MAP_KEY = "outputs"; - /** - * File extension for the temporal files used for interprocessing - */ - final private static String FILE_EXTENSION = ".dat"; - /** - * Name without vesion of the JAR created for this library - */ - private static final String JAR_FILE_NAME = "dl-modelrunner-tensorflow-"; - - /** - * The loaded Tensorflow 2 model - */ - private static SavedModelBundle model; - /** - * Internal object of the Tensorflow model - */ - private static SignatureDef sig; - /** - * Whether the execution needs interprocessing (MacOS Interl) or not - */ - private boolean interprocessing = false; - /** - * TEmporary dir where to store temporary files - */ - private String tmpDir; - /** - * Folde containing the model that is being executed - */ - private String modelFolder; - /** - * List of temporary files used for interprocessing communication - */ - private List listTempFiles; - /** - * HashMap that maps tensor to the temporal file name for interprocessing - */ - private HashMap tensorFilenameMap; - /** - * Process where the model is being loaded and executed - */ - Process process; - - /** - * TODO the interprocessing is executed for every OS - * Constructor that detects whether the operating system where it is being - * executed is Windows or Mac or not to know if it is going to need interprocessing - * or not - * @throws IOException if the temporary dir is not found - */ - public Tensorflow2Interface2_old() throws IOException - { - boolean isWin = PlatformDetection.isWindows(); - boolean isIntel = PlatformDetection.getArch().equals(PlatformDetection.ARCH_X86_64); - if (true || (isWin && isIntel)) { - interprocessing = true; - tmpDir = getTemporaryDir(); - } - } - - /** - * Private constructor that can only be launched from the class to create a separate - * process to avoid the conflicts that occur in the same process between TF2 and TF1/Pytorch - * in Windows and Mac - * @param doInterprocessing - * whether to do interprocessing or not - * @throws IOException if the temp dir is not found - */ - private Tensorflow2Interface2_old(boolean doInterprocessing) throws IOException - { - if (!doInterprocessing) { - interprocessing = false; - } else { - boolean isWin = PlatformDetection.isMacOS(); - boolean isIntel = PlatformDetection.getArch().equals(PlatformDetection.ARCH_X86_64); - if (isWin && isIntel) { - interprocessing = true; - tmpDir = getTemporaryDir(); - } - } - } - - /** - * {@inheritDoc} - * - * Load a Tensorflow 2 model. If the machine where the code is - * being executed is a MacOS Intel or Windows, the model will be loaded in - * a separate process each time the method {@link #run(List, List)} - * is called - */ - @Override - public void loadModel(String modelFolder, String modelSource) - throws LoadModelException - { - this.modelFolder = modelFolder; - if (interprocessing) - return; - try { - checkModelUnzipped(); - } catch (Exception e) { - throw new LoadModelException(e.toString()); - } - model = SavedModelBundle.load(this.modelFolder, "serve"); - byte[] byteGraph = model.metaGraphDef().toByteArray(); - try { - sig = MetaGraphDef.parseFrom(byteGraph).getSignatureDefOrThrow( - "serving_default"); - } - catch (InvalidProtocolBufferException e) { - System.out.println("Invalid graph"); - } - } - - /** - * Check if an unzipped tensorflow model exists in the model folder, - * and if not look for it and unzip it - * @throws LoadModelException if no model is found - * @throws IOException if there is any error unzipping the model - * @throws Exception if there is any error related to model packaging - */ - private void checkModelUnzipped() throws LoadModelException, IOException, Exception { - if (new File(modelFolder, "variables").isDirectory() - && new File(modelFolder, "saved_model.pb").isFile()) - return; - unzipTfWeights(ModelDescriptor.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME)); - } - - /** - * Method that unzips the tensorflow model zip into the variables - * folder and .pb file, if they are saved in a zip - * @throws LoadModelException if not zip file is found - * @throws IOException if there is any error unzipping - */ - private void unzipTfWeights(ModelDescriptor descriptor) throws LoadModelException, IOException { - if (new File(modelFolder, "tf_weights.zip").isFile()) { - System.out.println("Unzipping model..."); - ZipUtils.unzipFolder(modelFolder + File.separator + "tf_weights.zip", modelFolder); - } else if ( descriptor.getWeights().getAllSuportedWeightNames() - .contains(EngineInfo.getBioimageioTfKey()) ) { - String source = descriptor.getWeights().gettAllSupportedWeightObjects().stream() - .filter(ww -> ww.getFramework().equals(EngineInfo.getBioimageioTfKey())) - .findFirst().get().getSource(); - if (new File(source).isFile()) { - System.out.println("Unzipping model..."); - ZipUtils.unzipFolder(new File(source).getAbsolutePath(), modelFolder); - } else if (new File(modelFolder, source).isFile()) { - System.out.println("Unzipping model..."); - ZipUtils.unzipFolder(new File(modelFolder, source).getAbsolutePath(), modelFolder); - } else { - source = DownloadModel.getFileNameFromURLString(source); - System.out.println("Unzipping model..."); - ZipUtils.unzipFolder(modelFolder + File.separator + source, modelFolder); - } - } else { - throw new LoadModelException("No model file was found in the model folder"); - } - } - - /** - * {@inheritDoc} - * - * Run a Tensorflow2 model on the data provided by the {@link Tensor} input list - * and modifies the output list with the results obtained - */ - @Override - public void run(List> inputTensors, List> outputTensors) - throws RunModelException - { - if (interprocessing) { - runInterprocessing(inputTensors, outputTensors); - return; - } - Session session = model.session(); - Session.Runner runner = session.runner(); - List inputListNames = new ArrayList(); - List> inTensors = - new ArrayList>(); - int c = 0; - for (Tensor tt : inputTensors) { - inputListNames.add(tt.getName()); - org.tensorflow.Tensor inT = TensorBuilder.build(tt); - inTensors.add(inT); - String inputName = getModelInputName(tt.getName(), c ++); - runner.feed(inputName, inT); - } - c = 0; - for (Tensor tt : outputTensors) - runner = runner.fetch(getModelOutputName(tt.getName(), c ++)); - // Run runner - List> resultPatchTensors = runner.run(); - - // Fill the agnostic output tensors list with data from the inference result - fillOutputTensors(resultPatchTensors, outputTensors); - // Close the remaining resources - session.close(); - for (org.tensorflow.Tensor tt : inTensors) { - tt.close(); - } - for (org.tensorflow.Tensor tt : resultPatchTensors) { - tt.close(); - } - } - - /** - * MEthod only used in MacOS Intel and Windows systems that makes all the arrangements - * to create another process, communicate the model info and tensors to the other - * process and then retrieve the results of the other process - * @param inputTensors - * tensors that are going to be run on the model - * @param outputTensors - * expected results of the model - * @throws RunModelException if there is any issue running the model - */ - public void runInterprocessing(List> inputTensors, List> outputTensors) throws RunModelException { - createTensorsForInterprocessing(inputTensors); - createTensorsForInterprocessing(outputTensors); - try { - List args = getProcessCommandsWithoutArgs(); - for (Tensor tensor : inputTensors) {args.add(getFilename4Tensor(tensor.getName()) + INPUT_FILE_TERMINATION);} - for (Tensor tensor : outputTensors) {args.add(getFilename4Tensor(tensor.getName()) + OUTPUT_FILE_TERMINATION);} - ProcessBuilder builder = new ProcessBuilder(args); - builder.redirectOutput(ProcessBuilder.Redirect.INHERIT); - builder.redirectError(ProcessBuilder.Redirect.INHERIT); - process = builder.start(); - int result = process.waitFor(); - process.destroy(); - if (result != 0) - throw new RunModelException("Error executing the Tensorflow 2 model in" - + " a separate process. The process was not terminated correctly." - + System.lineSeparator() + readProcessStringOutput(process)); - } catch (RunModelException e) { - closeModel(); - throw e; - } catch (Exception e) { - closeModel(); - throw new RunModelException(e.getCause().toString()); - } - - retrieveInterprocessingTensors(outputTensors); - } - - /** - * Create the list a list of output tensors agnostic to the Deep Learning - * engine that can be readable by the dl-modelrunner - * - * @param outputTfTensors an List containing dl-modelrunner tensors - * @param outputTensors the names given to the tensors by the model - * @throws RunModelException If the number of tensors expected is not the same - * as the number of Tensors outputed by the model - */ - public static void fillOutputTensors( - List> outputTfTensors, - List> outputTensors) throws RunModelException - { - if (outputTfTensors.size() != outputTensors.size()) - throw new RunModelException(outputTfTensors.size(), outputTensors.size()); - for (int i = 0; i < outputTfTensors.size(); i++) { - outputTensors.get(i).setData(ImgLib2Builder.build(outputTfTensors.get(i))); - } - } - - /** - * {@inheritDoc} - * - * Close the Tensorflow 2 {@link #model} and {@link #sig}. For - * MacOS Intel and Windows systems it also deletes the temporary files created to - * communicate with the other process - */ - @Override - public void closeModel() { - sig = null; - if (model != null) { - model.session().close(); - model.close(); - } - model = null; - if (listTempFiles == null) - return; - for (File ff : listTempFiles) { - if (ff.exists()) - ff.delete(); - } - listTempFiles = null; - if (process != null) - process.destroyForcibly(); - } - - // TODO make only one - /** - * Retrieves the readable input name from the graph signature definition given - * the signature input name. - * - * @param inputName Signature input name. - * @param i position of the input of interest in the list of inputs - * @return The readable input name. - */ - public static String getModelInputName(String inputName, int i) { - TensorInfo inputInfo = sig.getInputsMap().getOrDefault(inputName, null); - if (inputInfo == null) { - inputInfo = sig.getInputsMap().values().stream().collect(Collectors.toList()).get(i); - } - if (inputInfo != null) { - String modelInputName = inputInfo.getName(); - if (modelInputName != null) { - if (modelInputName.endsWith(":0")) { - return modelInputName.substring(0, modelInputName.length() - 2); - } - else { - return modelInputName; - } - } - else { - return inputName; - } - } - return inputName; - } - - /** - * Retrieves the readable output name from the graph signature definition - * given the signature output name. - * - * @param outputName Signature output name. - * @param i position of the input of interest in the list of inputs - * @return The readable output name. - */ - public static String getModelOutputName(String outputName, int i) { - TensorInfo outputInfo = sig.getOutputsMap().getOrDefault(outputName, null); - if (outputInfo == null) { - outputInfo = sig.getOutputsMap().values().stream().collect(Collectors.toList()).get(i); - } - if (outputInfo != null) { - String modelOutputName = outputInfo.getName(); - if (modelOutputName.endsWith(":0")) { - return modelOutputName.substring(0, modelOutputName.length() - 2); - } - else { - return modelOutputName; - } - } - else { - return outputName; - } - } - - - /** - * Methods to run interprocessing and bypass the errors that occur in MacOS intel - * with the compatibility between TF2 and TF1/Pytorch - * This method checks that the arguments are correct, retrieves the input and output - * tensors, loads the model, makes inference with it and finally sends the tensors - * to the original process - * - * @param args - * arguments of the program: - * - Path to the model folder - * - Path to a temporary dir - * - Name of the input 0 - * - Name of the input 1 - * - ... - * - Name of the output n - * - Name of the output 0 - * - Name of the output 1 - * - ... - * - Name of the output n - * @throws LoadModelException if there is any error loading the model - * @throws IOException if there is any error reading or writing any file or with the paths - * @throws RunModelException if there is any error running the model - */ - public static void main(String[] args) throws LoadModelException, IOException, RunModelException { - Tensorflow2Interface2_old tt = new Tensorflow2Interface2_old(false); - - tt.loadModel("/home/carlos/Desktop/Fiji.app/models/model_03bioimageio", null); - // Unpack the args needed - if (args.length < 4) - throw new IllegalArgumentException("Error exectuting Tensorflow 2, " - + "at least 5 arguments are required:" + System.lineSeparator() - + " - Folder where the model is located" + System.lineSeparator() - + " - Temporary dir where the memory mapped files are located" + System.lineSeparator() - + " - Name of the model input followed by the String + '_model_input'" + System.lineSeparator() - + " - Name of the second model input (if it exists) followed by the String + '_model_input'" + System.lineSeparator() - + " - ...." + System.lineSeparator() - + " - Name of the nth model input (if it exists) followed by the String + '_model_input'" + System.lineSeparator() - + " - Name of the model output followed by the String + '_model_output'" + System.lineSeparator() - + " - Name of the second model output (if it exists) followed by the String + '_model_output'" + System.lineSeparator() - + " - ...." + System.lineSeparator() - + " - Name of the nth model output (if it exists) followed by the String + '_model_output'" + System.lineSeparator() - ); - String modelFolder = args[0]; - if (!(new File(modelFolder).isDirectory())) { - throw new IllegalArgumentException("Argument 0 of the main method, '" + modelFolder + "' " - + "should be an existing directory containing a Tensorflow 2 model."); - } - - Tensorflow2Interface2_old tfInterface = new Tensorflow2Interface2_old(false); - tfInterface.tmpDir = args[1]; - if (!(new File(args[1]).isDirectory())) { - throw new IllegalArgumentException("Argument 1 of the main method, '" + args[1] + "' " - + "should be an existing directory."); - } - - tfInterface.loadModel(modelFolder, modelFolder); - - HashMap> map = tfInterface.getInputTensorsFileNames(args); - List inputNames = map.get(INPUTS_MAP_KEY); - List> inputList = inputNames.stream().map(n -> { - try { - return tfInterface.retrieveInterprocessingTensorsByName(n); - } catch (RunModelException e) { - return null; - } - }).collect(Collectors.toList()); - List outputNames = map.get(OUTPUTS_MAP_KEY); - List> outputList = outputNames.stream().map(n -> { - try { - return tfInterface.retrieveInterprocessingTensorsByName(n); - } catch (RunModelException e) { - return null; - } - }).collect(Collectors.toList()); - tfInterface.run(inputList, outputList); - tfInterface.createTensorsForInterprocessing(outputList); - } - - /** - * Get the name of the temporary file associated to the tensor name - * @param name - * name of the tensor - * @return file name associated to the tensor - */ - private String getFilename4Tensor(String name) { - if (tensorFilenameMap == null) - tensorFilenameMap = new HashMap(); - if (tensorFilenameMap.get(name) != null) - return tensorFilenameMap.get(name); - LocalDateTime now = LocalDateTime.now(); - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyyMMddHHmmssSSS"); - String newName = name + "_" + now.format(formatter); - tensorFilenameMap.put(name, newName); - return tensorFilenameMap.get(name); - } - - /** - * Create a temporary file for each of the tensors in the list to communicate with - * the separate process in MacOS Intel and Windows systems - * @param tensors - * list of tensors to be sent - * @throws RunModelException if there is any error converting the tensors - */ - private void createTensorsForInterprocessing(List> tensors) throws RunModelException{ - if (this.listTempFiles == null) - this.listTempFiles = new ArrayList(); - for (Tensor tensor : tensors) { - long lenFile = ImgLib2ToMappedBuffer.findTotalLengthFile(tensor); - File ff = new File(tmpDir + File.separator + getFilename4Tensor(tensor.getName()) + FILE_EXTENSION); - if (!ff.exists()) { - ff.deleteOnExit(); - this.listTempFiles.add(ff); - } - try (RandomAccessFile rd = - new RandomAccessFile(ff, "rw"); - FileChannel fc = rd.getChannel();) { - MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_WRITE, 0, lenFile); - ByteBuffer byteBuffer = mem.duplicate(); - ImgLib2ToMappedBuffer.build(tensor, byteBuffer); - } catch (IOException e) { - closeModel(); - throw new RunModelException(e.getCause().toString()); - } - } - } - - /** - * Retrieves the data of the tensors contained in the input list from the output - * generated by the independent process - * @param tensors - * list of tensors that are going to be filled - * @throws RunModelException if there is any issue retrieving the data from the other process - */ - private void retrieveInterprocessingTensors(List> tensors) throws RunModelException{ - for (Tensor tensor : tensors) { - try (RandomAccessFile rd = - new RandomAccessFile(tmpDir + File.separator - + this.getFilename4Tensor(tensor.getName()) + FILE_EXTENSION, "r"); - FileChannel fc = rd.getChannel();) { - MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size()); - ByteBuffer byteBuffer = mem.duplicate(); - tensor.setData(MappedBufferToImgLib2.build(byteBuffer)); - } catch (IOException e) { - closeModel(); - throw new RunModelException(e.getCause().toString()); - } - } - } - - /** - * Create a tensor from the data contained in a file named as the parameter - * provided as an input + the file extension {@link #FILE_EXTENSION}. - * This file is produced by another process to communicate with the current process - * @param - * generic type of the tensor - * @param name - * name of the file without the extension ({@link #FILE_EXTENSION}). - * @return a tensor created with the data in the file - * @throws RunModelException if there is any problem retrieving the data and cerating the tensor - */ - private < T extends RealType< T > & NativeType< T > > Tensor - retrieveInterprocessingTensorsByName(String name) throws RunModelException { - try (RandomAccessFile rd = - new RandomAccessFile(tmpDir + File.separator - + this.getFilename4Tensor(name) + FILE_EXTENSION, "r"); - FileChannel fc = rd.getChannel();) { - MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size()); - ByteBuffer byteBuffer = mem.duplicate(); - return MappedBufferToImgLib2.buildTensor(byteBuffer); - } catch (IOException e) { - closeModel(); - throw new RunModelException(e.getCause().toString()); - } - } - - /** - * if java bin dir contains any special char, surround it by double quotes - * @param javaBin - * java bin dir - * @return impored java bin dir if needed - */ - private static String padSpecialJavaBin(String javaBin) { - String[] specialChars = new String[] {" "}; - for (String schar : specialChars) { - if (javaBin.contains(schar) && PlatformDetection.isWindows()) { - return "\"" + javaBin + "\""; - } - } - return javaBin; - } - - /** - * Create the arguments needed to execute tensorflow 2 in another - * process with the corresponding tensors - * @return the command used to call the separate process - * @throws IOException if the command needed to execute interprocessing is too long - * @throws URISyntaxException if there is any error with the URIs retrieved from the classes - */ - private List getProcessCommandsWithoutArgs() throws IOException, URISyntaxException { - String javaHome = System.getProperty("java.home"); - String javaBin = javaHome + File.separator + "bin" + File.separator + "java"; - - String modelrunnerPath = getPathFromClass(DeepLearningEngineInterface.class); - String imglib2Path = getPathFromClass(NativeType.class); - if (modelrunnerPath == null || (modelrunnerPath.endsWith("DeepLearningEngineInterface.class") - && !modelrunnerPath.contains(File.pathSeparator))) - modelrunnerPath = System.getProperty("java.class.path"); - String classpath = modelrunnerPath + File.pathSeparator + imglib2Path + File.pathSeparator; - ProtectionDomain protectionDomain = Tensorflow2Interface2_old.class.getProtectionDomain(); - String codeSource = protectionDomain.getCodeSource().getLocation().getPath(); - String f_name = URLDecoder.decode(codeSource, StandardCharsets.UTF_8.toString()); - f_name = new File(f_name).getAbsolutePath(); - for (File ff : new File(f_name).getParentFile().listFiles()) { - if (ff.getName().startsWith(JAR_FILE_NAME) && !ff.getAbsolutePath().equals(f_name)) - continue; - classpath += ff.getAbsolutePath() + File.pathSeparator; - } - String className = Tensorflow2Interface2_old.class.getName(); - List command = new LinkedList(); - command.add(padSpecialJavaBin(javaBin)); - command.add("-cp"); - command.add(classpath); - command.add(className); - command.add(modelFolder); - command.add(this.tmpDir); - return command; - } - - /** - * Method that gets the path to the JAR from where a specific class is being loaded - * @param clazz - * class of interest - * @return the path to the JAR that contains the class - * @throws UnsupportedEncodingException if the url of the JAR is not encoded in UTF-8 - */ - private static String getPathFromClass(Class clazz) throws UnsupportedEncodingException { - String classResource = clazz.getName().replace('.', '/') + ".class"; - URL resourceUrl = clazz.getClassLoader().getResource(classResource); - if (resourceUrl == null) { - return null; - } - String urlString = resourceUrl.toString(); - if (urlString.startsWith("jar:")) { - urlString = urlString.substring(4); - } - if (urlString.startsWith("file:/") && PlatformDetection.isWindows()) { - urlString = urlString.substring(6); - } else if (urlString.startsWith("file:/") && !PlatformDetection.isWindows()) { - urlString = urlString.substring(5); - } - urlString = URLDecoder.decode(urlString, "UTF-8"); - File file = new File(urlString); - String path = file.getAbsolutePath(); - if (path.lastIndexOf(".jar!") != -1) - path = path.substring(0, path.lastIndexOf(".jar!")) + ".jar"; - return path; - } - - /** - * Get temporary directory to perform the interprocessing communication in MacOSX - * intel and Windows - * @return the tmp dir - * @throws IOException if the files cannot be written in any of the temp dirs - */ - private static String getTemporaryDir() throws IOException { - String tmpDir; - String enginesDir = getEnginesDir(); - if (enginesDir != null && Files.isWritable(Paths.get(enginesDir))) { - tmpDir = enginesDir + File.separator + "temp"; - if (!(new File(tmpDir).isDirectory()) && !(new File(tmpDir).mkdirs())) - tmpDir = enginesDir; - } else if (System.getenv("temp") != null - && Files.isWritable(Paths.get(System.getenv("temp")))) { - return System.getenv("temp"); - } else if (System.getenv("TEMP") != null - && Files.isWritable(Paths.get(System.getenv("TEMP")))) { - return System.getenv("TEMP"); - } else if (System.getenv("tmp") != null - && Files.isWritable(Paths.get(System.getenv("tmp")))) { - return System.getenv("tmp"); - } else if (System.getenv("TMP") != null - && Files.isWritable(Paths.get(System.getenv("TMP")))) { - return System.getenv("TMP"); - } else if (System.getProperty("java.io.tmpdir") != null - && Files.isWritable(Paths.get(System.getProperty("java.io.tmpdir")))) { - return System.getProperty("java.io.tmpdir"); - } else { - throw new IOException("Unable to find temporal directory with writting rights. " - + "Please either allow writting on the system temporal folder or on '" + enginesDir + "'."); - } - return tmpDir; - } - - /** - * GEt the directory where the TF2 engine is located if a temporary dir is not found - * @return directory of the engines - */ - private static String getEnginesDir() { - String dir; - try { - dir = getPathFromClass(Tensorflow2Interface2_old.class); - } catch (UnsupportedEncodingException e) { - String classResource = Tensorflow2Interface2_old.class.getName().replace('.', '/') + ".class"; - URL resourceUrl = Tensorflow2Interface2_old.class.getClassLoader().getResource(classResource); - if (resourceUrl == null) { - return null; - } - String urlString = resourceUrl.toString(); - if (urlString.startsWith("jar:")) { - urlString = urlString.substring(4); - } - if (urlString.startsWith("file:/") && PlatformDetection.isWindows()) { - urlString = urlString.substring(6); - } else if (urlString.startsWith("file:/") && !PlatformDetection.isWindows()) { - urlString = urlString.substring(5); - } - File file = new File(urlString); - String path = file.getAbsolutePath(); - if (path.lastIndexOf(".jar!") != -1) - path = path.substring(0, path.lastIndexOf(".jar!")) + ".jar"; - dir = path; - } - return new File(dir).getParent(); - } - - /** - * Retrieve the file names used for interprocess communication - * @param args - * args provided to the main method - * @return a map with a list of input and output names - */ - private HashMap> getInputTensorsFileNames(String[] args) { - List inputNames = new ArrayList(); - List outputNames = new ArrayList(); - if (this.tensorFilenameMap == null) - this.tensorFilenameMap = new HashMap(); - for (int i = 2; i < args.length; i ++) { - if (args[i].endsWith(INPUT_FILE_TERMINATION)) { - String nameWTimestamp = args[i].substring(0, args[i].length() - INPUT_FILE_TERMINATION.length()); - String onlyName = nameWTimestamp.substring(0, nameWTimestamp.lastIndexOf("_")); - inputNames.add(onlyName); - tensorFilenameMap.put(onlyName, nameWTimestamp); - } else if (args[i].endsWith(OUTPUT_FILE_TERMINATION)) { - String nameWTimestamp = args[i].substring(0, args[i].length() - OUTPUT_FILE_TERMINATION.length()); - String onlyName = nameWTimestamp.substring(0, nameWTimestamp.lastIndexOf("_")); - outputNames.add(onlyName); - tensorFilenameMap.put(onlyName, nameWTimestamp); - - } - } - if (inputNames.size() == 0) - throw new IllegalArgumentException("The args to the main method of '" - + Tensorflow2Interface2_old.class.toString() + "' should contain at " - + "least one input, defined as ' + '" + INPUT_FILE_TERMINATION + "'."); - if (outputNames.size() == 0) - throw new IllegalArgumentException("The args to the main method of '" - + Tensorflow2Interface2_old.class.toString() + "' should contain at " - + "least one output, defined as ' + '" + OUTPUT_FILE_TERMINATION + "'."); - HashMap> map = new HashMap>(); - map.put(INPUTS_MAP_KEY, inputNames); - map.put(OUTPUTS_MAP_KEY, outputNames); - return map; - } - - /** - * MEthod to obtain the String output of the process in case something goes wrong - * @param process - * the process that executed the TF2 model - * @return the String output that we would have seen on the terminal - * @throws IOException if the output of the terminal cannot be seen - */ - private static String readProcessStringOutput(Process process) throws IOException { - BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(process.getInputStream())); - BufferedReader bufferedErrReader = new BufferedReader(new InputStreamReader(process.getErrorStream())); - String text = ""; - String line; - while ((line = bufferedErrReader.readLine()) != null) { - text += line + System.lineSeparator(); - } - while ((line = bufferedReader.readLine()) != null) { - text += line + System.lineSeparator(); - } - return text; - } -} diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/ShmBuilder.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/ShmBuilder.java new file mode 100644 index 0000000..1b5dccd --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/ShmBuilder.java @@ -0,0 +1,221 @@ +/*- + * #%L + * This project complements the DL-model runner acting as the engine that works loading models + * and making inference with Java 0.3.0 and newer API for Tensorflow 2. + * %% + * Copyright (C) 2022 - 2023 Institut Pasteur and BioImage.IO developers. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package io.bioimage.modelrunner.tensorflow.v2.api020.shm; + +import io.bioimage.modelrunner.system.PlatformDetection; +import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; +import io.bioimage.modelrunner.utils.CommonUtils; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; + +import org.tensorflow.Tensor; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TUint8; +import org.tensorflow.types.family.TType; + +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.type.numeric.integer.IntType; +import net.imglib2.type.numeric.integer.LongType; +import net.imglib2.type.numeric.integer.UnsignedByteType; +import net.imglib2.type.numeric.real.DoubleType; +import net.imglib2.type.numeric.real.FloatType; + +/** + * A {@link RandomAccessibleInterval} builder for TensorFlow {@link Tensor} objects. + * Build ImgLib2 objects (backend of {@link io.bioimage.modelrunner.tensor.Tensor}) + * from Tensorflow 2 {@link Tensor} + * + * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando + */ +public final class ShmBuilder +{ + /** + * Utility class. + */ + private ShmBuilder() + { + } + + /** + * Creates a {@link RandomAccessibleInterval} from a given {@link TType} tensor + * + * @param + * the possible ImgLib2 datatypes of the image + * @param tensor + * The {@link TType} tensor data is read from. + * @throws IllegalArgumentException If the {@link TType} tensor type is not supported. + * @throws IOException + */ + @SuppressWarnings("unchecked") + public static void build(Tensor tensor, String memoryName) throws IllegalArgumentException, IOException + { + switch (tensor.dataType().name()) + { + case TUint8.NAME: + buildFromTensorUByte((Tensor) tensor, memoryName); + case TInt32.NAME: + buildFromTensorInt((Tensor) tensor, memoryName); + case TFloat32.NAME: + buildFromTensorFloat((Tensor) tensor, memoryName); + case TFloat64.NAME: + buildFromTensorDouble((Tensor) tensor, memoryName); + case TInt64.NAME: + buildFromTensorLong((Tensor) tensor, memoryName); + default: + throw new IllegalArgumentException("Unsupported tensor type: " + tensor.dataType().name()); + } + } + + /** + * Builds a {@link RandomAccessibleInterval} from a unsigned byte-typed {@link TUint8} tensor. + * + * @param tensor + * The {@link TUint8} tensor data is read from. + * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link UnsignedByteType}. + * @throws IOException + */ + private static void buildFromTensorUByte(Tensor tensor, String memoryName) throws IOException + { + long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 1)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1); + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true); + ByteBuffer buff = shma.getDataBuffer(); + int totalSize = 1; + for (long i : arrayShape) {totalSize *= i;} + byte[] flatArr = new byte[buff.capacity()]; + buff.get(flatArr); + tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize); + shma.setBuffer(ByteBuffer.wrap(flatArr)); + if (PlatformDetection.isWindows()) shma.close(); + } + + /** + * Builds a {@link RandomAccessibleInterval} from a unsigned int32-typed {@link TInt32} tensor. + * + * @param tensor + * The {@link TInt32} tensor data is read from. + * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link IntType}. + * @throws IOException + */ + private static void buildFromTensorInt(Tensor tensor, String memoryName) throws IOException + { + long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 4)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4); + + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true); + ByteBuffer buff = shma.getDataBuffer(); + int totalSize = 4; + for (long i : arrayShape) {totalSize *= i;} + byte[] flatArr = new byte[buff.capacity()]; + buff.get(flatArr); + tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize); + shma.setBuffer(ByteBuffer.wrap(flatArr)); + if (PlatformDetection.isWindows()) shma.close(); + } + + /** + * Builds a {@link RandomAccessibleInterval} from a unsigned float32-typed {@link TFloat32} tensor. + * + * @param tensor + * The {@link TFloat32} tensor data is read from. + * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link FloatType}. + * @throws IOException + */ + private static void buildFromTensorFloat(Tensor tensor, String memoryName) throws IOException + { + long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 4)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4); + + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true); + ByteBuffer buff = shma.getDataBuffer(); + int totalSize = 4; + for (long i : arrayShape) {totalSize *= i;} + byte[] flatArr = new byte[buff.capacity()]; + buff.get(flatArr); + tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize); + shma.setBuffer(ByteBuffer.wrap(flatArr)); + if (PlatformDetection.isWindows()) shma.close(); + } + + /** + * Builds a {@link RandomAccessibleInterval} from a unsigned float64-typed {@link TFloat64} tensor. + * + * @param tensor + * The {@link TFloat64} tensor data is read from. + * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link DoubleType}. + * @throws IOException + */ + private static void buildFromTensorDouble(Tensor tensor, String memoryName) throws IOException + { + long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 8)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8); + + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true); + ByteBuffer buff = shma.getDataBuffer(); + int totalSize = 8; + for (long i : arrayShape) {totalSize *= i;} + byte[] flatArr = new byte[buff.capacity()]; + buff.get(flatArr); + tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize); + shma.setBuffer(ByteBuffer.wrap(flatArr)); + if (PlatformDetection.isWindows()) shma.close(); + } + + /** + * Builds a {@link RandomAccessibleInterval} from a unsigned int64-typed {@link TInt64} tensor. + * + * @param tensor + * The {@link TInt64} tensor data is read from. + * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link LongType}. + * @throws IOException + */ + private static void buildFromTensorLong(Tensor tensor, String memoryName) throws IOException + { + long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 8)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8); + + + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true); + ByteBuffer buff = shma.getDataBuffer(); + int totalSize = 8; + for (long i : arrayShape) {totalSize *= i;} + byte[] flatArr = new byte[buff.capacity()]; + buff.get(flatArr); + tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize); + shma.setBuffer(ByteBuffer.wrap(flatArr)); + if (PlatformDetection.isWindows()) shma.close(); + } +} diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/TensorBuilder.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/TensorBuilder.java new file mode 100644 index 0000000..7b7e84b --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/TensorBuilder.java @@ -0,0 +1,242 @@ +/*- + * #%L + * This project complements the DL-model runner acting as the engine that works loading models + * and making inference with Java 0.3.0 and newer API for Tensorflow 2. + * %% + * Copyright (C) 2022 - 2023 Institut Pasteur and BioImage.IO developers. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +package io.bioimage.modelrunner.tensorflow.v2.api020.shm; + +import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; +import io.bioimage.modelrunner.utils.CommonUtils; +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.img.Img; +import net.imglib2.type.numeric.integer.IntType; +import net.imglib2.type.numeric.integer.LongType; +import net.imglib2.type.numeric.integer.UnsignedByteType; +import net.imglib2.type.numeric.real.DoubleType; +import net.imglib2.type.numeric.real.FloatType; +import net.imglib2.util.Cast; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import java.util.Arrays; + +import org.tensorflow.Tensor; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.impl.buffer.raw.RawDataBufferFactory; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TUint8; +import org.tensorflow.types.family.TType; + +/** + * A TensorFlow 2 {@link Tensor} builder from {@link Img} and + * {@link io.bioimage.modelrunner.tensor.Tensor} objects. + * + * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando + */ +public final class TensorBuilder { + + /** + * Utility class. + */ + private TensorBuilder() {} + + /** + * Creates {@link TType} instance with the same size and information as the + * given {@link RandomAccessibleInterval}. + * + * @param + * the ImgLib2 data types the {@link RandomAccessibleInterval} can be + * @param array + * the {@link RandomAccessibleInterval} that is going to be converted into + * a {@link TType} tensor + * @return a {@link TType} tensor + * @throws IllegalArgumentException if the type of the {@link RandomAccessibleInterval} + * is not supported + */ + public static Tensor build(SharedMemoryArray array) throws IllegalArgumentException + { + // Create an Icy sequence of the same type of the tensor + if (array.getOriginalDataType().equals("uint8")) { + return buildUByte(Cast.unchecked(array)); + } + else if (array.getOriginalDataType().equals("int32")) { + return buildInt(Cast.unchecked(array)); + } + else if (array.getOriginalDataType().equals("float32")) { + return buildFloat(Cast.unchecked(array)); + } + else if (array.getOriginalDataType().equals("float64")) { + return buildDouble(Cast.unchecked(array)); + } + else if (array.getOriginalDataType().equals("int64")) { + return buildLong(Cast.unchecked(array)); + } + else { + throw new IllegalArgumentException("Unsupported tensor type: " + array.getOriginalDataType()); + } + } + + /** + * Creates a {@link TType} tensor of type {@link TUint8} from an + * {@link RandomAccessibleInterval} of type {@link UnsignedByteType} + * + * @param tensor + * The {@link RandomAccessibleInterval} to fill the tensor with. + * @return The {@link TType} tensor filled with the {@link RandomAccessibleInterval} data. + * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is + * not compatible + */ + public static Tensor buildUByte(SharedMemoryArray tensor) + throws IllegalArgumentException + { + long[] ogShape = tensor.getOriginalShape(); + if (CommonUtils.int32Overflows(ogShape, 1)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE); + if (!tensor.isNumpyFormat()) + throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); + ByteBuffer buff = tensor.getDataBufferNoHeader(); + ByteDataBuffer dataBuffer = RawDataBufferFactory.create(buff.array(), false); + Tensor ndarray = Tensor.of(TUint8.DTYPE, Shape.of(ogShape), dataBuffer); + return ndarray; + } + + /** + * Creates a {@link TInt32} tensor of type {@link TInt32} from an + * {@link RandomAccessibleInterval} of type {@link IntType} + * + * @param tensor + * The {@link RandomAccessibleInterval} to fill the tensor with. + * @return The {@link TInt32} tensor filled with the {@link RandomAccessibleInterval} data. + * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is + * not compatible + */ + public static Tensor buildInt(SharedMemoryArray tensor) + throws IllegalArgumentException + { + long[] ogShape = tensor.getOriginalShape(); + if (CommonUtils.int32Overflows(ogShape, 1)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE); + if (!tensor.isNumpyFormat()) + throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); + ByteBuffer buff = tensor.getDataBufferNoHeader(); + IntBuffer intBuff = buff.asIntBuffer(); + int[] intArray = new int[intBuff.capacity()]; + intBuff.get(intArray); + IntDataBuffer dataBuffer = RawDataBufferFactory.create(intArray, false); + Tensor ndarray = TInt32.tensorOf(Shape.of(ogShape), dataBuffer); + return ndarray; + } + + /** + * Creates a {@link TInt64} tensor of type {@link TInt64} from an + * {@link RandomAccessibleInterval} of type {@link LongType} + * + * @param tensor + * The {@link RandomAccessibleInterval} to fill the tensor with. + * @return The {@link TInt64} tensor filled with the {@link RandomAccessibleInterval} data. + * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is + * not compatible + */ + private static Tensor buildLong(SharedMemoryArray tensor) + throws IllegalArgumentException + { + long[] ogShape = tensor.getOriginalShape(); + if (CommonUtils.int32Overflows(ogShape, 1)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE); + if (!tensor.isNumpyFormat()) + throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); + ByteBuffer buff = tensor.getDataBufferNoHeader(); + LongBuffer longBuff = buff.asLongBuffer(); + long[] longArray = new long[longBuff.capacity()]; + longBuff.get(longArray); + LongDataBuffer dataBuffer = RawDataBufferFactory.create(longArray, false); + Tensor ndarray = TInt64.tensorOf(Shape.of(ogShape), dataBuffer); + return ndarray; + } + + /** + * Creates a {@link TFloat32} tensor of type {@link TFloat32} from an + * {@link RandomAccessibleInterval} of type {@link FloatType} + * + * @param tensor + * The {@link RandomAccessibleInterval} to fill the tensor with. + * @return The {@link TFloat32} tensor filled with the {@link RandomAccessibleInterval} data. + * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is + * not compatible + */ + public static Tensor buildFloat(SharedMemoryArray tensor) + throws IllegalArgumentException + { + long[] ogShape = tensor.getOriginalShape(); + if (CommonUtils.int32Overflows(ogShape, 1)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE); + if (!tensor.isNumpyFormat()) + throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); + ByteBuffer buff = tensor.getDataBufferNoHeader(); + FloatBuffer floatBuff = buff.asFloatBuffer(); + float[] floatArray = new float[floatBuff.capacity()]; + floatBuff.get(floatArray); + FloatDataBuffer dataBuffer = RawDataBufferFactory.create(floatArray, false); + Tensor ndarray = TFloat32.tensorOf(Shape.of(ogShape), dataBuffer); + return ndarray; + } + + /** + * Creates a {@link TFloat64} tensor of type {@link TFloat64} from an + * {@link RandomAccessibleInterval} of type {@link DoubleType} + * + * @param tensor + * The {@link RandomAccessibleInterval} to fill the tensor with. + * @return The {@link TFloat64} tensor filled with the {@link RandomAccessibleInterval} data. + * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is + * not compatible + */ + private static Tensor buildDouble(SharedMemoryArray tensor) + throws IllegalArgumentException + { + long[] ogShape = tensor.getOriginalShape(); + if (CommonUtils.int32Overflows(ogShape, 1)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE); + if (!tensor.isNumpyFormat()) + throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format."); + ByteBuffer buff = tensor.getDataBufferNoHeader(); + DoubleBuffer doubleBuff = buff.asDoubleBuffer(); + double[] doubleArray = new double[doubleBuff.capacity()]; + doubleBuff.get(doubleArray); + DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(doubleArray, false); + Tensor ndarray = TFloat64.tensorOf(Shape.of(ogShape), dataBuffer); + return ndarray; + } +} diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/TensorBuilder.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/TensorBuilder.java index 680b290..688cf23 100644 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/TensorBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/TensorBuilder.java @@ -26,7 +26,9 @@ import net.imglib2.Cursor; import net.imglib2.RandomAccessibleInterval; import net.imglib2.img.Img; +import net.imglib2.type.NativeType; import net.imglib2.type.Type; +import net.imglib2.type.numeric.RealType; import net.imglib2.type.numeric.integer.IntType; import net.imglib2.type.numeric.integer.LongType; import net.imglib2.type.numeric.integer.UnsignedByteType; @@ -76,8 +78,9 @@ private TensorBuilder() {} * @throws IllegalArgumentException If the type of the {@link io.bioimage.modelrunner.tensor.Tensor} * is not supported */ - public static Tensor build( - io.bioimage.modelrunner.tensor.Tensor tensor) + public static & NativeType> + Tensor build( + io.bioimage.modelrunner.tensor.Tensor tensor) throws IllegalArgumentException { return build(tensor.getData()); @@ -133,7 +136,7 @@ else if (Util.getTypeFromInterval(array) instanceof LongType) { private static Tensor buildUByte( RandomAccessibleInterval tensor) throws IllegalArgumentException - { + { long[] ogShape = tensor.dimensionsAsLongArray(); if (CommonUtils.int32Overflows(ogShape, 1)) throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/mappedbuffer/ImgLib2ToMappedBuffer.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/mappedbuffer/ImgLib2ToMappedBuffer.java deleted file mode 100644 index 8ad57b2..0000000 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/mappedbuffer/ImgLib2ToMappedBuffer.java +++ /dev/null @@ -1,281 +0,0 @@ -/*- - * #%L - * This project complements the DL-model runner acting as the engine that works loading models - * and making inference with Java 0.2.0 API for Tensorflow 2. - * %% - * Copyright (C) 2022 - 2023 Institut Pasteur and BioImage.IO developers. - * %% - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * #L% - */ -package io.bioimage.modelrunner.tensorflow.v2.api020.tensor.mappedbuffer; - -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; - -import io.bioimage.modelrunner.tensor.Tensor; -import net.imglib2.Cursor; -import net.imglib2.RandomAccessibleInterval; -import net.imglib2.type.NativeType; -import net.imglib2.type.Type; -import net.imglib2.type.numeric.RealType; -import net.imglib2.type.numeric.integer.ByteType; -import net.imglib2.type.numeric.integer.IntType; -import net.imglib2.type.numeric.integer.LongType; -import net.imglib2.type.numeric.integer.UnsignedByteType; -import net.imglib2.type.numeric.real.DoubleType; -import net.imglib2.type.numeric.real.FloatType; -import net.imglib2.util.Util; -import net.imglib2.view.Views; - -/** - * Class that maps {@link Tensor} objects to {@link ByteBuffer} objects. - * This is done to modify the files that are used to communicate between process - * to avoid the TF2-TF1/Pytorch incompatibility that happens in these systems - * - * @author Carlos Garcia Lopez de Haro - */ -public final class ImgLib2ToMappedBuffer -{ - /** - * Header used to identify files for interprocessing communication - */ - final public static byte[] MODEL_RUNNER_HEADER = - {(byte) 0x93, 'M', 'O', 'D', 'E', 'L', '-', 'R', 'U', 'N', 'N', 'E', 'R'}; - - /** - * Not used (Utility class). - */ - private ImgLib2ToMappedBuffer() - { - } - - /** - * Maps a {@link Tensor} to the provided {@link ByteBuffer} with all the information - * needed to reconstruct the tensor again - * - * @param - * the type of the tensor - * @param tensor - * tensor to be mapped into byte buffer - * @param byteBuffer - * target byte bufer - * @throws IllegalArgumentException - * If the {@link Tensor} ImgLib2 type is not supported. - */ - public static < T extends RealType< T > & NativeType< T > > void build(Tensor tensor, ByteBuffer byteBuffer) - { - byteBuffer.put(ImgLib2ToMappedBuffer.createFileHeader(tensor)); - if (tensor.isEmpty()) - return; - build(tensor.getData(), byteBuffer); - } - - /** - * Adds the {@link RandomAccessibleInterval} data to the {@link ByteBuffer} provided. - * The position of the ByteBuffer is kept in the same place as it was received. - * - * @param - * the type of the {@link RandomAccessibleInterval} - * @param rai - * {@link RandomAccessibleInterval} to be mapped into byte buffer - * @param byteBuffer - * target bytebuffer - * @throws IllegalArgumentException If the {@link RandomAccessibleInterval} type is not supported. - */ - private static > void build(RandomAccessibleInterval rai, ByteBuffer byteBuffer) - { - if (Util.getTypeFromInterval(rai) instanceof ByteType) { - buildByte((RandomAccessibleInterval) rai, byteBuffer); - } else if (Util.getTypeFromInterval(rai) instanceof IntType) { - buildInt((RandomAccessibleInterval) rai, byteBuffer); - } else if (Util.getTypeFromInterval(rai) instanceof FloatType) { - buildFloat((RandomAccessibleInterval) rai, byteBuffer); - } else if (Util.getTypeFromInterval(rai) instanceof DoubleType) { - buildDouble((RandomAccessibleInterval) rai, byteBuffer); - } else { - throw new IllegalArgumentException("The image has an unsupported type: " + Util.getTypeFromInterval(rai).getClass().toString()); - } - } - - /** - * Adds the ByteType {@link RandomAccessibleInterval} data to the {@link ByteBuffer} provided. - * The position of the ByteBuffer is kept in the same place as it was received. - * - * @param imgTensor - * {@link RandomAccessibleInterval} to be mapped into byte buffer - * @param byteBuffer - * target bytebuffer - */ - private static void buildByte(RandomAccessibleInterval imgTensor, ByteBuffer byteBuffer) - { - Cursor tensorCursor = Views.flatIterable(imgTensor).cursor(); - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - byteBuffer.put(tensorCursor.get().getByte()); - } - } - - /** - * Adds the IntType {@link RandomAccessibleInterval} data to the {@link ByteBuffer} provided. - * The position of the ByteBuffer is kept in the same place as it was received. - * - * @param imgTensor - * {@link RandomAccessibleInterval} to be mapped into byte buffer - * @param byteBuffer - * target bytebuffer - */ - private static void buildInt(RandomAccessibleInterval imgTensor, ByteBuffer byteBuffer) - { - Cursor tensorCursor = Views.flatIterable(imgTensor).cursor(); - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - byteBuffer.putInt(tensorCursor.get().getInt()); - } - } - - /** - * Adds the FloatType {@link RandomAccessibleInterval} data to the {@link ByteBuffer} provided. - * The position of the ByteBuffer is kept in the same place as it was received. - * - * @param imgTensor - * {@link RandomAccessibleInterval} to be mapped into byte buffer - * @param byteBuffer - * target bytebuffer - */ - private static void buildFloat(RandomAccessibleInterval imgTensor, ByteBuffer byteBuffer) - { - Cursor tensorCursor = Views.flatIterable(imgTensor).cursor(); - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - byteBuffer.putFloat(tensorCursor.get().getRealFloat()); - } - } - - /** - * Adds the DoubleType {@link RandomAccessibleInterval} data to the {@link ByteBuffer} provided. - * The position of the ByteBuffer is kept in the same place as it was received. - * - * @param imgTensor - * {@link RandomAccessibleInterval} to be mapped into byte buffer - * @param byteBuffer - * target bytebuffer - */ - private static void buildDouble(RandomAccessibleInterval imgTensor, ByteBuffer byteBuffer) - { - Cursor tensorCursor = Views.flatIterable(imgTensor).cursor(); - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - byteBuffer.putDouble(tensorCursor.get().getRealDouble()); - } - } - - /** - * Create header for the temp file that is used for interprocess communication. - * The header should contain the first key word as an array of bytes (MODEl-RUNNER) - * @param - * type of the tensor - * @param tensor - * tensor whose info is recorded - * @return byte array containing the header info for the file - */ - public static < T extends RealType< T > & NativeType< T > > byte[] - createFileHeader(io.bioimage.modelrunner.tensor.Tensor tensor) { - String dimsStr = - !tensor.isEmpty() ? Arrays.toString(tensor.getData().dimensionsAsLongArray()) : "[]"; - T dtype = !tensor.isEmpty() ? Util.getTypeFromInterval(tensor.getData()): (T) new FloatType(); - String descriptionStr = "{'dtype':'" - + getDataTypeString(dtype) + "','axes':'" - + tensor.getAxesOrderString() + "','name':'" + tensor.getName() + "','shape':'" - + dimsStr + "'}"; - - byte[] descriptionBytes = descriptionStr.getBytes(StandardCharsets.UTF_8); - int lenDescriptionBytes = descriptionBytes.length; - byte[] intAsBytes = ByteBuffer.allocate(4).putInt(lenDescriptionBytes).array(); - int totalHeaderLen = MODEL_RUNNER_HEADER.length + intAsBytes.length + lenDescriptionBytes; - byte[] byteHeader = new byte[totalHeaderLen]; - for (int i = 0; i < MODEL_RUNNER_HEADER.length; i ++) - byteHeader[i] = MODEL_RUNNER_HEADER[i]; - for (int i = MODEL_RUNNER_HEADER.length; i < MODEL_RUNNER_HEADER.length + intAsBytes.length; i ++) - byteHeader[i] = intAsBytes[i - MODEL_RUNNER_HEADER.length]; - for (int i = MODEL_RUNNER_HEADER.length + intAsBytes.length; i < totalHeaderLen; i ++) - byteHeader[i] = descriptionBytes[i - MODEL_RUNNER_HEADER.length - intAsBytes.length]; - - return byteHeader; - } - - /** - * Method that returns a Sting representing the datatype of T - * @param - * type of the tensor - * @param type - * pixel of an imglib2 object to get the info of teh data type - * @return String representation of the datatype - */ - public static< T extends RealType< T > & NativeType< T > > String getDataTypeString(T type) { - if (type instanceof ByteType) { - return "byte"; - } else if (type instanceof IntType) { - return "int32"; - } else if (type instanceof FloatType) { - return "float32"; - } else if (type instanceof DoubleType) { - return "float64"; - } else if (type instanceof LongType) { - return "int64"; - } else if (type instanceof UnsignedByteType) { - return "ubyte"; - } else { - throw new IllegalArgumentException("Unsupported data type. At the moment the only " - + "supported dtypes are: " + IntType.class + ", " + FloatType.class + ", " - + DoubleType.class + ", " + LongType.class + " and " + UnsignedByteType.class); - } - } - - /** - * Get the total byte size of the temp file that is going to be created to be - * able to reconstruct a {@link Tensor} to in the separate process in MacOS Intel - * systems - * - * @param - * type of the imglib2 object - * @param tensor - * tensor of interest - * @return number of bytes needed to create a file with the info of the tensor - */ - public static < T extends RealType< T > & NativeType< T > > long - findTotalLengthFile(io.bioimage.modelrunner.tensor.Tensor tensor) { - long startLen = createFileHeader(tensor).length; - long[] dimsArr = !tensor.isEmpty() ? tensor.getData().dimensionsAsLongArray() : null; - if (dimsArr == null) - return startLen; - long totSizeFlat = 1; - for (long i : dimsArr) {totSizeFlat *= i;} - long nBytesDt = 1; - Type dtype = !tensor.isEmpty() ? - Util.getTypeFromInterval(tensor.getData()) : (Type) new FloatType(); - if (dtype instanceof IntType) { - nBytesDt = 4; - } else if (dtype instanceof ByteType) { - nBytesDt = 1; - } else if (dtype instanceof FloatType) { - nBytesDt = 4; - } else if (dtype instanceof DoubleType) { - nBytesDt = 8; - } else { - throw new IllegalArgumentException("Unsupported tensor type: " + dtype.getClass()); - } - return startLen + nBytesDt * totSizeFlat; - } -} diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/mappedbuffer/MappedBufferToImgLib2.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/mappedbuffer/MappedBufferToImgLib2.java deleted file mode 100644 index 28b8c4e..0000000 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/mappedbuffer/MappedBufferToImgLib2.java +++ /dev/null @@ -1,332 +0,0 @@ -/*- - * #%L - * This project complements the DL-model runner acting as the engine that works loading models - * and making inference with Java 0.2.0 API for Tensorflow 2. - * %% - * Copyright (C) 2022 - 2023 Institut Pasteur and BioImage.IO developers. - * %% - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * #L% - */ -package io.bioimage.modelrunner.tensorflow.v2.api020.tensor.mappedbuffer; - -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.HashMap; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -import io.bioimage.modelrunner.tensor.Tensor; -import net.imglib2.Cursor; -import net.imglib2.RandomAccessibleInterval; -import net.imglib2.img.Img; -import net.imglib2.img.array.ArrayImgFactory; -import net.imglib2.type.NativeType; -import net.imglib2.type.Type; -import net.imglib2.type.numeric.RealType; -import net.imglib2.type.numeric.integer.ByteType; -import net.imglib2.type.numeric.integer.IntType; -import net.imglib2.type.numeric.real.DoubleType; -import net.imglib2.type.numeric.real.FloatType; - -/** - * Class that maps {@link ByteBuffer} objects to {@link Img} objects. - * This is done to modify the files that are used to communicate between process - * to avoid the TF2-TF1/Pytorch incompatibility that happens in these systems - * - * A {@link Img} builder from {@link ByteBuffer} objects - * - * @author Carlos Garcia Lopez de Haro - */ -public final class MappedBufferToImgLib2 -{ - /** - * Pattern that matches the header of the temporal file for interprocess communication - * and retrieves data type, shape, name and axes - */ - private static final Pattern HEADER_PATTERN = Pattern.compile("'dtype':'([a-zA-Z0-9]+)'" - + ",'axes':'([a-zA-Z]+)'" - + ",'name':'([^']*)'" - + ",'shape':'(\\[\\s*(?:(?:[1-9]\\d*|0)\\s*,\\s*)*(?:[1-9]\\d*|0)?\\s*\\])'"); - /** - * Key for data type info - */ - private static final String DATA_TYPE_KEY = "dtype"; - /** - * Key for shape info - */ - private static final String SHAPE_KEY = "shape"; - /** - * Key for axes info - */ - private static final String AXES_KEY = "axes"; - /** - * Key for axes info - */ - private static final String NAME_KEY = "name"; - - /** - * Not used (Utility class). - */ - private MappedBufferToImgLib2() - { - } - - /** - * Creates a {@link Tensor} from the information stored in a {@link ByteBuffer} - * - * @param - * the type of the generated tensor - * @param buff - * byte buffer to get the tensor info from - * @return the tensor generated from the bytebuffer - * @throws IllegalArgumentException if the data type of the tensor saved in the bytebuffer is - * not supported - */ - @SuppressWarnings("unchecked") - public static < T extends RealType< T > & NativeType< T > > Tensor buildTensor(ByteBuffer buff) throws IllegalArgumentException - { - String infoStr = getTensorInfoFromBuffer(buff); - HashMap map = getInfoFromHeaderString(infoStr); - String dtype = (String) map.get(DATA_TYPE_KEY); - String axes = (String) map.get(AXES_KEY); - String name = (String) map.get(NAME_KEY); - long[] shape = (long[]) map.get(SHAPE_KEY); - if (shape.length == 0) - return Tensor.buildEmptyTensor(name, axes); - - Img data; - switch (dtype) - { - case "byte": - data = (Img) buildFromTensorByte(buff, shape); - break; - case "int32": - data = (Img) buildFromTensorInt(buff, shape); - break; - case "float32": - data = (Img) buildFromTensorFloat(buff, shape); - break; - case "float64": - data = (Img) buildFromTensorDouble(buff, shape); - break; - default: - throw new IllegalArgumentException("Unsupported tensor type: " + dtype); - } - return Tensor.build(name, axes, (RandomAccessibleInterval) data); - } - - /** - * Creates a {@link Img} from the information stored in a {@link ByteBuffer} - * - * @param - * data type of the image - * @param byteBuff - * The bytebyuffer that contains info to create a tenosr or a {@link Img} - * @return The imglib2 image {@link Img} built from the bytebuffer info. - * @throws IllegalArgumentException if the data type of the tensor saved in the bytebuffer is - * not supported - */ - @SuppressWarnings("unchecked") - public static > Img build(ByteBuffer byteBuff) throws IllegalArgumentException - { - String infoStr = getTensorInfoFromBuffer(byteBuff); - HashMap map = getInfoFromHeaderString(infoStr); - String dtype = (String) map.get(DATA_TYPE_KEY); - long[] shape = (long[]) map.get(SHAPE_KEY); - if (shape.length == 0) - return null; - - // Create an INDArray of the same type of the tensor - switch (dtype) - { - case "byte": - return (Img) buildFromTensorByte(byteBuff, shape); - case "int32": - return (Img) buildFromTensorInt(byteBuff, shape); - case "float32": - return (Img) buildFromTensorFloat(byteBuff, shape); - case "float64": - return (Img) buildFromTensorDouble(byteBuff, shape); - default: - throw new IllegalArgumentException("Unsupported tensor type: " + dtype); - } - } - - /** - * Builds a ByteType {@link Img} from the information stored in a byte buffer. - * The shape of the image that was previously retrieved from the buffer - * @param tensor - * byte buffer containing the information of the a tenosr, the position in the buffer - * should not be at zero but right after the header. - * @param tensorShape - * shape of the image to generate, it has been retrieved from the byte buffer - * @return image specified in the bytebuffer - */ - private static Img buildFromTensorByte(ByteBuffer tensor, long[] tensorShape) - { - final ArrayImgFactory< ByteType > factory = new ArrayImgFactory<>( new ByteType() ); - final Img< ByteType > outputImg = (Img) factory.create(tensorShape); - Cursor tensorCursor= outputImg.cursor(); - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - tensorCursor.get().set(tensor.get()); - } - return outputImg; - } - - /** - * Builds a IntType {@link Img} from the information stored in a byte buffer. - * The shape of the image that was previously retrieved from the buffer - * @param tensor - * byte buffer containing the information of the a tenosr, the position in the buffer - * should not be at zero but right after the header. - * @param tensorShape - * shape of the image to generate, it has been retrieved from the byte buffer - * @return image specified in the bytebuffer - */ - private static Img buildFromTensorInt(ByteBuffer tensor, long[] tensorShape) - { - final ArrayImgFactory< IntType > factory = new ArrayImgFactory<>( new IntType() ); - final Img< IntType > outputImg = (Img) factory.create(tensorShape); - Cursor tensorCursor= outputImg.cursor(); - byte[] bytes = new byte[4]; - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - tensor.get(bytes); - int val = ((int) (bytes[0] << 24)) + ((int) (bytes[1] << 16)) - + ((int) (bytes[2] << 8)) + ((int) (bytes[3])); - tensorCursor.get().set(val); - } - return outputImg; - } - - /** - * Builds a FloatType {@link Img} from the information stored in a byte buffer. - * The shape of the image that was previously retrieved from the buffer - * @param tensor - * byte buffer containing the information of the a tenosr, the position in the buffer - * should not be at zero but right after the header. - * @param tensorShape - * shape of the image to generate, it has been retrieved from the byte buffer - * @return image specified in the bytebuffer - */ - private static Img buildFromTensorFloat(ByteBuffer tensor, long[] tensorShape) - { - final ArrayImgFactory< FloatType > factory = new ArrayImgFactory<>( new FloatType() ); - final Img< FloatType > outputImg = (Img) factory.create(tensorShape); - Cursor tensorCursor= outputImg.cursor(); - byte[] bytes = new byte[4]; - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - tensor.get(bytes); - float val = ByteBuffer.wrap(bytes).getFloat(); - tensorCursor.get().set(val); - } - return outputImg; - } - - /** - * Builds a DoubleType {@link Img} from the information stored in a byte buffer. - * The shape of the image that was previously retrieved from the buffer - * @param tensor - * byte buffer containing the information of the a tenosr, the position in the buffer - * should not be at zero but right after the header. - * @param tensorShape - * shape of the image to generate, it has been retrieved from the byte buffer - * @return image specified in the bytebuffer - */ - private static Img buildFromTensorDouble(ByteBuffer tensor, long[] tensorShape) - { - final ArrayImgFactory< DoubleType > factory = new ArrayImgFactory<>( new DoubleType() ); - final Img< DoubleType > outputImg = (Img) factory.create(tensorShape); - Cursor tensorCursor= outputImg.cursor(); - byte[] bytes = new byte[8]; - while (tensorCursor.hasNext()) { - tensorCursor.fwd(); - tensor.get(bytes); - double val = ByteBuffer.wrap(bytes).getDouble(); - tensorCursor.get().set(val); - } - return outputImg; - } - - /** - * Method that returns the information about the tensor specified at the - * beginning of the {@link ByteBuffer} object created - * with {@link ImgLib2ToMappedBuffer#build(Tensor,ByteBuffer)}. - * This method reads the buffer from the beginning - * @param buff - * ByteBuffer containing the information about the tensor - * @return map containing the name, axes order, datatype and shape of the tensor - * stored in teh buffer - */ - public static HashMap readHeaderAndGetInfo(ByteBuffer buff) { - buff.clear(); - return getInfoFromHeaderString(getTensorInfoFromBuffer(buff)); - } - - /** - * GEt the String info stored at the beginning of the buffer that contains - * the data type, name of tensor, axes and shape info. - * @param buff - * buffer containing all the data to generate a tensor - * @return the String header of teh bytebuffer that contains the data about - * the tensor (name, data type, shape and axes) - */ - private static String getTensorInfoFromBuffer(ByteBuffer buff) { - byte[] arr = new byte[ImgLib2ToMappedBuffer.MODEL_RUNNER_HEADER.length]; - buff.get(arr); - if (!Arrays.equals(arr, ImgLib2ToMappedBuffer.MODEL_RUNNER_HEADER)) - throw new IllegalArgumentException("Error sending tensors between processes."); - byte[] lenInfoInBytes = new byte[4]; - buff.get(lenInfoInBytes); - int lenInfo = ByteBuffer.wrap(lenInfoInBytes).getInt(); - byte[] stringInfoBytes = new byte[lenInfo]; - buff.get(stringInfoBytes); - return new String(stringInfoBytes, StandardCharsets.UTF_8); - } - - /** - * MEthod that retrieves the data type string and shape long array representing - * the data type and dimensions of the tensor saved in the temp file - * @param infoStr - * string header of the file that contains the info about the tensor - * @return dictionary containins the name, dtype, shape and axes of the tensor - */ - private static HashMap getInfoFromHeaderString(String infoStr) { - Matcher matcher = HEADER_PATTERN.matcher(infoStr); - if (!matcher.find()) { - throw new IllegalArgumentException("Cannot find datatype, name, axes and dimensions " - + "info in file header: " + infoStr); - } - String typeStr = matcher.group(1); - String axesStr = matcher.group(2); - String nameStr = matcher.group(3); - String shapeStr = matcher.group(4); - - long[] shape = new long[0]; - if (!shapeStr.isEmpty() && !shapeStr.equals("[]")) { - shapeStr = shapeStr.substring(1, shapeStr.length() - 1); - String[] tokens = shapeStr.split(", ?"); - shape = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray(); - } - HashMap map = new HashMap(); - map.put(DATA_TYPE_KEY, typeStr); - map.put(AXES_KEY, axesStr); - map.put(SHAPE_KEY, shape); - map.put(NAME_KEY, nameStr); - return map; - } -}