Skip to content

Commit

Permalink
add tools to move to persistent interprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 24, 2024
1 parent a55ce82 commit fbd3db9
Show file tree
Hide file tree
Showing 7 changed files with 483 additions and 1,511 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@
import io.bioimage.modelrunner.system.PlatformDetection;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.tensorflow.v2.api020.shm.ShmBuilder;
import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.ImgLib2Builder;
import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.TensorBuilder;
import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.mappedbuffer.ImgLib2ToMappedBuffer;
import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.mappedbuffer.MappedBufferToImgLib2;
import io.bioimage.modelrunner.utils.CommonUtils;
import io.bioimage.modelrunner.utils.Constants;
import io.bioimage.modelrunner.utils.ZipUtils;
Expand All @@ -50,24 +49,14 @@
import net.imglib2.util.Cast;
import net.imglib2.util.Util;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.RandomAccessFile;
import java.io.UnsupportedEncodingException;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLDecoder;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.ProtectionDomain;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
Expand All @@ -81,7 +70,6 @@
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.types.family.TType;

/**
* Class to that communicates with the dl-model runner, see
Expand Down Expand Up @@ -290,28 +278,28 @@ void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors)
Session session = model.session();
Session.Runner runner = session.runner();
List<String> inputListNames = new ArrayList<String>();
List<TType> inTensors = new ArrayList<TType>();
List<org.tensorflow.Tensor<?>> inTensors = new ArrayList<org.tensorflow.Tensor<?>>();
int c = 0;
for (Tensor<?> tt : inputTensors) {
for (Tensor<T> tt : inputTensors) {
inputListNames.add(tt.getName());
TType inT = TensorBuilder.build(tt);
org.tensorflow.Tensor<?> inT = TensorBuilder.build(tt);
inTensors.add(inT);
String inputName = getModelInputName(tt.getName(), c ++);
runner.feed(inputName, inT);
}
c = 0;
for (Tensor<?> tt : outputTensors)
for (Tensor<R> tt : outputTensors)
runner = runner.fetch(getModelOutputName(tt.getName(), c ++));
// Run runner
List<org.tensorflow.Tensor> resultPatchTensors = runner.run();
List<org.tensorflow.Tensor<?>> resultPatchTensors = runner.run();

// Fill the agnostic output tensors list with data from the inference result
fillOutputTensors(resultPatchTensors, outputTensors);
// Close the remaining resources
for (TType tt : inTensors) {
for (org.tensorflow.Tensor<?> tt : inTensors) {
tt.close();
}
for (org.tensorflow.Tensor tt : resultPatchTensors) {
for (org.tensorflow.Tensor<?> tt : resultPatchTensors) {
tt.close();
}
}
Expand All @@ -320,12 +308,12 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
Session session = model.session();
Session.Runner runner = session.runner();

List<TType> inTensors = new ArrayList<TType>();
List<org.tensorflow.Tensor<?>> inTensors = new ArrayList<org.tensorflow.Tensor<?>>();
int c = 0;
for (String ee : inputs) {
Map<String, Object> decoded = Types.decode(ee);
SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY));
TType inT = io.bioimage.modelrunner.tensorflow.v2.api030.shm.TensorBuilder.build(shma);
org.tensorflow.Tensor<?> inT = io.bioimage.modelrunner.tensorflow.v2.api020.shm.TensorBuilder.build(shma);
if (PlatformDetection.isWindows()) shma.close();
inTensors.add(inT);
String inputName = getModelInputName((String) decoded.get(NAME_KEY), c ++);
Expand All @@ -336,19 +324,19 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
for (String ee : outputs)
runner = runner.fetch(getModelOutputName((String) Types.decode(ee).get(NAME_KEY), c ++));
// Run runner
List<org.tensorflow.Tensor> resultPatchTensors = runner.run();
List<org.tensorflow.Tensor<?>> resultPatchTensors = runner.run();

// Fill the agnostic output tensors list with data from the inference result
c = 0;
for (String ee : outputs) {
Map<String, Object> decoded = Types.decode(ee);
ShmBuilder.build((TType) resultPatchTensors.get(c ++), (String) decoded.get(MEM_NAME_KEY));
ShmBuilder.build((org.tensorflow.Tensor<?>) resultPatchTensors.get(c ++), (String) decoded.get(MEM_NAME_KEY));
}
// Close the remaining resources
for (TType tt : inTensors) {
for (org.tensorflow.Tensor<?> tt : inTensors) {
tt.close();
}
for (org.tensorflow.Tensor tt : resultPatchTensors) {
for (org.tensorflow.Tensor<?> tt : resultPatchTensors) {
tt.close();
}
}
Expand Down
Loading

0 comments on commit fbd3db9

Please sign in to comment.