diff --git a/ts/torch_handler/request_envelope/kservev2.py b/ts/torch_handler/request_envelope/kservev2.py index d975c1a946..1affd1949e 100644 --- a/ts/torch_handler/request_envelope/kservev2.py +++ b/ts/torch_handler/request_envelope/kservev2.py @@ -113,6 +113,14 @@ def _from_json(self, body_list): ) input_names.append(input["name"]) setattr(self.context, "input_names", input_names) + + output_names = [] + for index, output in enumerate(body_list[0].get("outputs", [])): + output_names.append(output["name"]) + # TODO: Add parameters support + # parameters = output.get("parameters") + setattr(self.context, "output_names", output_names) + logger.debug("Bytes array is %s", body_list) id = body_list[0].get("id") if id and id.strip(): @@ -167,10 +175,15 @@ def _batch_to_json(self, data): Splits batch output to json objects """ output = [] - input_names = getattr(self.context, "input_names") + + output_names = getattr(self.context, "output_names") + delattr(self.context, "output_names") + if len(output_names) == 0: + # Re-use input names in case no output is specified + output_names = getattr(self.context, "input_names") delattr(self.context, "input_names") for index, item in enumerate(data): - output.append(self._to_json(item, input_names[index])) + output.append(self._to_json(item, output_names[index])) return output def _to_json(self, data, input_name):