XII. JDLL tensors I (Tensor)
Tensors are one of the key components of Deep Learning: the network weights are tensors, and so are the inputs and outputs. A good, intuitive tensor API is therefore crucial for any Deep Learning (DL) library. Each DL framework has its own tensor representation, and even though some efforts are being made towards cross-compatibility (example 1, example 2), these representations are largely incompatible with each other.
JDLL aims to provide a common DL Java API that allows any framework to be used in a unified way, without introducing extra complexity for the developer. Because tensors are such an important part, creating and managing the input and output tensors for inference needs to be unified and intuitive too. Note that JDLL currently only supports inference and does not allow DL network development, so the only tensors that can be created are those representing inputs or outputs.
JDLL tensors are created by the end developer, and JDLL then converts them into tensors of the specific DL framework (engine) so that the DL model can run inference on them. After the model is run, the engine-specific tensors are converted back into general JDLL tensors. These JDLL output tensors are the ones the user "receives" after running the model.
JDLL tensors are based on ImgLib2 images. The ImgLib2 images are the ones that actually contain the ordered data of the tensor. JDLL tensors are just a wrapper around them, providing useful methods that make those ImgLib2 images usable by Deep Learning frameworks.
ImgLib2 was chosen after trying many other frameworks because it is fast and efficient, has great community support, is coded entirely in Java, and employs referencing rather than copying when possible, thus saving memory by avoiding duplication.
Tensor management is controlled by one class in JDLL: io.bioimage.modelrunner.tensor.Tensor. This class and its methods are detailed in this Wiki page.
This class provides the methods to create, close and manage JDLL tensors.
JDLL tensors are based on ImgLib2 images and have 3 key parts:
- Tensor name. The name given to each tensor. It has to be unique. It is important for TensorFlow tensors, because it is used to identify the inputs and outputs of the model, but not important for ONNX or PyTorch.
- Tensor axes order. The order of the dimensions of the tensor. The possible dimensions are: batch or b, depth or z, width or x, height or y, and channels or c. One possible tensor axes order for a 4D tensor is bcyx. This parameter is equivalent to the one described in the Bioimage.io rdf.yaml file; find another explanation for it here.
- Tensor data. The actual numeric values of each of the positions of the tensor. This information is stored as an ImgLib2 image.
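As an illustration of the axes order described above, the following minimal sketch checks whether a string is a plausible axes order. The class and method names are hypothetical and are not part of the JDLL API. The sketch also assumes that only the lowercase single-letter names are accepted and that no axis may appear twice, neither of which is stated on this page.

```java
import java.util.HashSet;
import java.util.Set;

// Hypothetical helper for illustration only; it is not part of the JDLL API.
class AxesOrderCheck {
    // Single-letter axis names listed in the text: batch, channels, depth, height, width.
    private static final String ALLOWED = "bczyx";

    // Returns true if every character is a known axis letter and none repeats
    // (rejecting repeats is an assumption of this sketch).
    static boolean isValidAxesOrder(String axes) {
        if (axes == null || axes.isEmpty()) {
            return false;
        }
        Set<Character> seen = new HashSet<>();
        for (char c : axes.toCharArray()) {
            if (ALLOWED.indexOf(c) < 0 || !seen.add(c)) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(isValidAxesOrder("bcyx")); // true
        System.out.println(isValidAxesOrder("bcyy")); // false, y is repeated
    }
}
```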
Tensors can be created with the data inside or empty.
Input tensors need to be created with the actual tensor data (ImgLib2 image) that is going to be processed by the DL model.
On the other hand, output tensors can be created "empty". JDLL requires that the input and output JDLL tensors have been declared in advance. But it is impossible to have the data (ImgLib2 image) for the output tensor before making inference with the model on the input tensors. This is why JDLL allows creating "empty tensors" where only the name and the axes order have already been defined.
Tensor Tensor.build( final String tensorName, final String axes, final RandomAccessibleInterval data )
Method that creates a JDLL tensor from an ImgLib2 RandomAccessibleInterval. The number of dimensions of the RandomAccessibleInterval has to be equal to the length of the axes String.
This method is used to create input tensors (the data parameter might represent an image that is to be processed), but it can also be used to declare an output Tensor in advance by passing a RandomAccessibleInterval with the dimensions and data type of the expected model output tensor.
- tensorName: name of the tensor, has to be unique.
- axes: String representing the axes order of the tensor. The possible dimensions are: batch or b, depth or z, width or x, height or y, and channels or c. One possible tensor axes order for a 4D tensor is bcyx. This parameter is equivalent to the one described in the Bioimage.io rdf.yaml file; find another explanation for it here.
- data: RandomAccessibleInterval containing the actual data of the tensor, that is, the numeric values of each of the positions of the tensor.
The example below creates a tensor. Note that, as Img extends RandomAccessibleInterval, it can be used as an argument for the method:
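final ImgFactory< FloatType > imgFactory = new ArrayImgFactory<>( new FloatType() );
// Dimensions follow the axes order "bcyx": batch 1, channels 1, height 512, width 512
final Img< FloatType > img1 = imgFactory.create( 1, 1, 512, 512 );
// Pass the image created above as the tensor data
Tensor tensor = Tensor.build("name", "bcyx", img1);
System.out.println("Great success!");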
Output:
Great success!
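The requirement that the number of dimensions of the data equals the length of the axes String can be sketched as a simple length comparison. The helper below is hypothetical and is not part of the JDLL API. How JDLL itself reports a mismatch is not described on this page, so throwing an IllegalArgumentException here is an assumption.

```java
// Hypothetical helper for illustration only; it is not part of the JDLL API.
class RankCheck {
    // Throws if the number of dimensions differs from the number of axis letters.
    // Using IllegalArgumentException is an assumption of this sketch.
    static void requireSameRank(String axes, long[] shape) {
        if (axes.length() != shape.length) {
            throw new IllegalArgumentException(
                "Axes '" + axes + "' has " + axes.length()
                + " letters but the data has " + shape.length + " dimensions");
        }
    }

    public static void main(String[] args) {
        // Four axis letters and four dimensions: passes silently.
        requireSameRank("bcyx", new long[] {1, 1, 512, 512});
        try {
            // Four axis letters but only two dimensions: rejected.
            requireSameRank("bcyx", new long[] {512, 512});
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```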
Tensor Tensor.buildEmptyTensor( final String tensorName, final String axes )
Method that creates an empty JDLL tensor. Empty means that it does not contain the ImgLib2 image; it holds just the name and axes of the tensor.
When running a DL model in JDLL, the information (name and axes order) required to reconstruct tensors needs to be defined in advance, both for inputs and outputs. This method was developed to allow the creation of JDLL tensors with the information required to retrieve the output but without containing the actual DL tensor output.
Once the model is executed, the JDLL tensor is "filled" with the model output using tensor.setData(RandomAccessibleInterval data).
- tensorName: name of the tensor, has to be unique.
- axes: String representing the axes order of the tensor. The possible dimensions are: batch or b, depth or z, width or x, height or y, and channels or c. One possible tensor axes order for a 4D tensor is bcyx. This parameter is equivalent to the one described in the Bioimage.io rdf.yaml file; find another explanation for it here.
The example below creates an empty tensor first and then fills it with the wanted ImgLib2 RandomAccessibleInterval. Note that a tensor created with Tensor Tensor.buildEmptyTensor( final String tensorName, final String axes ) can afterwards be filled with any ImgLib2 RandomAccessibleInterval without any restriction.
Tensor tensor = Tensor.buildEmptyTensor("name", "bcyx");
System.out.println("Great success!");
final ImgFactory< FloatType > imgFactory = new ArrayImgFactory<>( new FloatType() );
final Img< FloatType > img1 = imgFactory.create( 1, 512, 512, 1 );
tensor.setData(img1);
System.out.println("Great success 2!");
tensor.close();
Output:
Great success!
Great success 2!
Tensor Tensor.buildBlankTensor( final String tensorName, final String axes, final long[] shape, final T dtype)
Method that builds a blank tensor, pre-allocating the memory that the tensor is expected to need. A blank tensor is a tensor of the shape and size of the wanted tensor in which all the entries are 0.
The reasoning behind this method is the same as for Tensor Tensor.buildEmptyTensor( final String tensorName, final String axes ). It is useful to predefine the outputs of a DL model.
- tensorName: name of the tensor, has to be unique.
- axes: String representing the axes order of the tensor. The possible dimensions are: batch or b, depth or z, width or x, height or y, and channels or c. One possible tensor axes order for a 4D tensor is bcyx. This parameter is equivalent to the one described in the Bioimage.io rdf.yaml file; find another explanation for it here.
- shape: dimensions of the tensor. The length of the shape must be the same as the length of the axes string.
- dtype: data type of the tensor.
The example below creates a blank tensor. Note that any data later set on the created tensor must have the same data type and dimensions.
Tensor tensor = Tensor.buildBlankTensor("output0", "bcyx", new long[] {1, 1, 512, 512}, new FloatType());
System.out.println("Great success!");
tensor.close();
Output:
Great success!
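To get a feel for how much memory a blank tensor of this shape pre-allocates, the element count is simply the product of the dimensions. The helper below is hypothetical and is not part of the JDLL API. The byte estimate assumes dense storage at 4 bytes per float32 entry and ignores any container overhead.

```java
// Hypothetical helper for illustration only; it is not part of the JDLL API.
class BlankTensorSize {
    // The number of elements is the product of all dimension sizes.
    static long elementCount(long[] shape) {
        long n = 1;
        for (long d : shape) {
            n *= d;
        }
        return n;
    }

    public static void main(String[] args) {
        long[] shape = {1, 1, 512, 512};
        long elements = elementCount(shape);
        // Assumes 4 bytes per float32 entry and no storage overhead.
        long bytes = elements * Float.BYTES;
        System.out.println(elements + " elements, " + bytes + " bytes");
        // prints: 262144 elements, 1048576 bytes
    }
}
```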
Trying to set another image as the data of the tensor results in an error if the dimensions are not the same:
Tensor tensor = Tensor.buildBlankTensor("output0", "bcyx", new long[] {1, 1, 512, 512}, new FloatType());
final ImgFactory< FloatType > imgFactory = new ArrayImgFactory<>( new FloatType() );
final Img< FloatType > img1 = imgFactory.create( 1, 3, 512, 512 );
tensor.setData(img1);
Output:
IllegalArgumentException: Trying to set an array as the backend of the Tensor with a
different shape than the Tensor. Tensor shape is: [1, 1, 512, 512] and array shape is: [1, 3, 512, 512]
Creates a copy of the input Tensor tt in the data type of interest.
- tt: tensor to be copied into another data type.
- type: target data type.
RandomAccessibleInterval Tensor.createCopyOfRaiInWantedDataType( final RandomAccessibleInterval input, final R type )
Creates a copy of the input RandomAccessibleInterval in the data type of interest.
- input: ImgLib2 RandomAccessibleInterval to be copied into another data type.
- type: target data type.
Method to get a specific tensor by its unique name from a list of the input or output tensors of a model.
- lTensors: list of tensors that should be either the input or output list of tensors of a DL model.
- name: the name of the tensor that is wanted.
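Conceptually, this lookup is a search through the list for the first entry with a matching name. The sketch below uses hypothetical stand-in classes, not JDLL classes, and returning null when no tensor matches is an assumption, since the behaviour on a miss is not described on this page.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-ins for illustration; these are NOT JDLL classes.
class NamedTensorLookup {
    // Minimal object that only carries a name, standing in for a tensor.
    static class NamedTensor {
        final String name;

        NamedTensor(String name) {
            this.name = name;
        }
    }

    // Returns the first entry whose name matches, or null if none does
    // (returning null on a miss is an assumption of this sketch).
    static NamedTensor findByName(List<NamedTensor> tensors, String name) {
        for (NamedTensor t : tensors) {
            if (t.name.equals(name)) {
                return t;
            }
        }
        return null;
    }

    public static void main(String[] args) {
        List<NamedTensor> outputs =
            Arrays.asList(new NamedTensor("output0"), new NamedTensor("output1"));
        System.out.println(findByName(outputs, "output1").name); // output1
        System.out.println(findByName(outputs, "missing"));      // null
    }
}
```

Because names are required to be unique, the first match is also the only match.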
Sets the data of the tensor to the one provided by the RandomAccessibleInterval argument.
If the Tensor has been created using the method Tensor Tensor.buildEmptyTensor( final String tensorName, final String axes ), or if it returns true when calling tensor.isEmpty(), any RandomAccessibleInterval can be used.
On the other hand, if the method tensor.isEmpty() returns false, the data of the RandomAccessibleInterval has to be of the same data type as the tensor, and the dimensions of the Tensor have to coincide with the dimensions of the RandomAccessibleInterval.
- data: a RandomAccessibleInterval that will contain the numerical data of the tensor.
The example below shows how to use the method properly after the tensor is created with Tensor Tensor.buildEmptyTensor( final String tensorName, final String axes ):
Tensor tensor = Tensor.buildEmptyTensor("name", "bcyx");
System.out.println("Great success!");
final ImgFactory< FloatType > imgFactory = new ArrayImgFactory<>( new FloatType() );
final Img< FloatType > img1 = imgFactory.create( 1, 512, 512, 1 );
tensor.setData(img1);
System.out.println("Great success 2!");
tensor.close();
Output:
Great success!
Great success 2!
Another example showing how to use the method when the tensor is not empty:
final ImgFactory< FloatType > imgFactory = new ArrayImgFactory<>( new FloatType() );
final Img< FloatType > img1 = imgFactory.create( 1, 3, 512, 512 );
Tensor tensor = Tensor.build("output0", "bcyx", img1);
final Img< FloatType > img2 = imgFactory.create( 1, 3, 512, 512 );
tensor.setData(img2);
tensor.close();
System.out.println("Great success!");
Output:
Great success!
Example showing the error caused if the tensor is not empty and the dimensions are different. The same happens if the data type is not the same:
Tensor tensor = Tensor.buildBlankTensor("output0", "bcyx", new long[] {1, 1, 512, 512}, new FloatType());
final ImgFactory< FloatType > imgFactory = new ArrayImgFactory<>( new FloatType() );
final Img< FloatType > img1 = imgFactory.create( 1, 3, 512, 512 );
tensor.setData(img1);
Output:
IllegalArgumentException: Trying to set an array as the backend of the Tensor with a
different shape than the Tensor. Tensor shape is: [1, 1, 512, 512] and array shape is: [1, 3, 512, 512]
The equivalent example with a different data type:
final ImgFactory< FloatType > imgFactory = new ArrayImgFactory<>( new FloatType() );
final Img< FloatType > img1 = imgFactory.create( 1, 3, 512, 512 );
Tensor tensor = Tensor.build("output0", "bcyx", img1);
final ImgFactory< IntType > imgIntFactory = new ArrayImgFactory<>( new IntType() );
final Img< IntType > imgInt = imgIntFactory.create( 1, 3, 512, 512 );
tensor.setData(imgInt);
Output:
IllegalArgumentException: Trying to set an array as the backend of the Tensor with a different
data type than the Tensor. Tensor data type is: float32 and array data type is: int32
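The rules illustrated by these examples can be summarised in a small sketch: an empty tensor accepts any data, while a non-empty one only accepts data with the same data type and shape. The class below is a hypothetical stand-in that tracks just a data type name and a shape. It is not the JDLL Tensor class, and its error messages are not the ones JDLL produces.

```java
import java.util.Arrays;

// Hypothetical stand-in for illustration; it is NOT the JDLL Tensor class.
class SketchTensor {
    private String dtype; // null while the tensor is empty
    private long[] shape; // null while the tensor is empty

    boolean isEmpty() {
        return shape == null;
    }

    // Accepts any data when empty; otherwise requires matching dtype and shape.
    void setData(String newDtype, long[] newShape) {
        if (!isEmpty()) {
            if (!dtype.equals(newDtype)) {
                throw new IllegalArgumentException(
                    "Data type mismatch: " + dtype + " vs " + newDtype);
            }
            if (!Arrays.equals(shape, newShape)) {
                throw new IllegalArgumentException(
                    "Shape mismatch: " + Arrays.toString(shape)
                    + " vs " + Arrays.toString(newShape));
            }
        }
        dtype = newDtype;
        shape = newShape.clone();
    }

    public static void main(String[] args) {
        SketchTensor t = new SketchTensor();
        t.setData("float32", new long[] {1, 1, 512, 512}); // empty: accepted
        t.setData("float32", new long[] {1, 1, 512, 512}); // same type and shape: accepted
        try {
            t.setData("float32", new long[] {1, 3, 512, 512}); // different shape: rejected
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```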
Close the tensor and free all the resources attached to it.
It is always advisable to close all the tensors created once the model has been run.
Returns the name of the tensor.
Returns the shape of the tensor.
Gets the String that defines the axes order of the tensor. The possible dimensions are: batch or b, depth or z, width or x, height or y, and channels or c. One possible tensor axes order for a 4D tensor is bcyx. This parameter is equivalent to the one described in the Bioimage.io rdf.yaml file; find another explanation for it here.
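One common use of the axes order String together with the shape is to find the size of the tensor along a particular axis: the position of the letter in the String is the index into the shape. The helper below is hypothetical and is not part of the JDLL API. Returning -1 for an absent axis is a choice made for this sketch.

```java
// Hypothetical helper for illustration only; it is not part of the JDLL API.
class AxisSize {
    // Returns the size along the given axis letter, or -1 if the axis is absent.
    static long sizeOf(String axes, long[] shape, char axis) {
        int i = axes.indexOf(axis);
        return i < 0 ? -1 : shape[i];
    }

    public static void main(String[] args) {
        String axes = "bcyx";
        long[] shape = {1, 3, 512, 512};
        System.out.println(sizeOf(axes, shape, 'c')); // 3
        System.out.println(sizeOf(axes, shape, 'z')); // -1
    }
}
```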
Sets whether the tensor represents an image or another kind of object. When Tensors are created, it is assumed they are images. The main use is calling tensor.setImage(false) when the tensor does not represent an image.
- isImage: whether the tensor is an image or not.
Whether the tensor is empty, that is, whether no memory has yet been allocated for a tensor of given dimensions. Empty tensors are created with Tensor Tensor.buildEmptyTensor( final String tensorName, final String axes ).
The methods Tensor.build( final String tensorName, final String axes, final RandomAccessibleInterval data ) and Tensor.buildBlankTensor( final String tensorName, final String axes, final long[] shape, final T dtype) both create non-empty tensors and allocate the memory needed for a tensor of certain dimensions and a certain data type.
Here are a couple of examples of empty and non-empty tensors:
final ImgFactory< FloatType > imgFactory = new ArrayImgFactory<>( new FloatType() );
final Img< FloatType > img1 = imgFactory.create( 1, 3, 512, 512 );
Tensor inputTensor = Tensor.build("input0", "bcyx", img1);
Tensor emptyTensor = Tensor.buildEmptyTensor("output0", "bcyx");
Tensor blankTensor = Tensor.buildBlankTensor("output1", "bcyx", new long[] {1, 1, 512, 512}, new FloatType());
System.out.println("Input tensor is empty: " + inputTensor.isEmpty());
System.out.println("Empty tensor is empty: " + emptyTensor.isEmpty());
System.out.println("Blank tensor is empty: " + blankTensor.isEmpty());
inputTensor.close();
emptyTensor.close();
blankTensor.close();
Output:
Input tensor is empty: false
Empty tensor is empty: true
Blank tensor is empty: false