Skip to content

Commit

Permalink
support badly named inputs or outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Dec 16, 2023
1 parent 35e6590 commit 8b6ae83
Showing 1 changed file with 15 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -305,15 +305,17 @@ public void run(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors)
List<String> inputListNames = new ArrayList<String>();
List<org.tensorflow.Tensor<?>> inTensors =
new ArrayList<org.tensorflow.Tensor<?>>();
int c = 0;
for (Tensor tt : inputTensors) {
inputListNames.add(tt.getName());
org.tensorflow.Tensor<?> inT = TensorBuilder.build(tt);
inTensors.add(inT);
runner.feed(getModelInputName(tt.getName()), inT);
String inputName = getModelInputName(tt.getName(), c ++);
runner.feed(inputName, inT);
}

c = 0;
for (Tensor tt : outputTensors)
runner = runner.fetch(getModelOutputName(tt.getName()));
runner = runner.fetch(getModelOutputName(tt.getName(), c ++));
// Run runner
List<org.tensorflow.Tensor<?>> resultPatchTensors = runner.run();

Expand Down Expand Up @@ -419,10 +421,14 @@ public void closeModel() {
* the signature input name.
*
* @param inputName Signature input name.
* @param i position of the input of interest in the list of inputs
* @return The readable input name.
*/
public static String getModelInputName(String inputName) {
public static String getModelInputName(String inputName, int i) {
TensorInfo inputInfo = sig.getInputsMap().getOrDefault(inputName, null);
if (inputInfo == null) {
inputInfo = sig.getInputsMap().values().stream().collect(Collectors.toList()).get(i);
}
if (inputInfo != null) {
String modelInputName = inputInfo.getName();
if (modelInputName != null) {
Expand All @@ -445,10 +451,14 @@ public static String getModelInputName(String inputName) {
* given the signature output name.
*
* @param outputName Signature output name.
* @param i position of the input of interest in the list of inputs
* @return The readable output name.
*/
public static String getModelOutputName(String outputName) {
public static String getModelOutputName(String outputName, int i) {
TensorInfo outputInfo = sig.getOutputsMap().getOrDefault(outputName, null);
if (outputInfo == null) {
outputInfo = sig.getOutputsMap().values().stream().collect(Collectors.toList()).get(i);
}
if (outputInfo != null) {
String modelOutputName = outputInfo.getName();
if (modelOutputName.endsWith(":0")) {
Expand Down

0 comments on commit 8b6ae83

Please sign in to comment.