XI. JDLL tensors 0 (Summary)
JDLL implements its own agnostic tensors that act as a vehicle to communicate between the main Java software and the Java Deep Learning framework.
Thanks to the agnostic tensors the main program does not have to deal with the creation of different tensors depending on the DL framework, unifying the task.
Agnostic tensors use ImgLib2 RandomAccessibleIntervals as the backend to store the data. ImgLib2 provides an all-in-Java, fast and lightweight framework to handle the data and communicate with particular Deep Learning frameworks.
On the main program side, creating tensors is reduced to creating ImgLib2 RandomAccessibleIntervals (or objects that extend them).
Once the ImgLib2 object is created, creating a JDLL tensor is simple. In addition to the data as an ImgLib2 object, it requires the name of the tensor and the axes order of the tensor (as defined in the rdf.yaml).
An example would be:
RandomAccessibleInterval<FloatType> data = ...;
Tensor tensor = Tensor.build("name", "bcyx", data);
Note that it is also necessary to generate the agnostic tensors that correspond to the output of the model.
These tensors are going to host the results of the inference.
Output tensors can be created as empty tensors and only contain the name and axes order of the output tensor:
// Without allocation of memory
Tensor.buildEmptyTensor("outputName", "bcyx");
// Allocating memory
Tensor<FloatType> outTensor = Tensor.buildBlankTensor("output0", "bcyx",
new long[] {1, 2, 512, 512}, new FloatType());
Alternatively, an output tensor can be built from an ImgLib2 object with the expected shape and data type of the output, which allocates the memory prior to execution:
RandomAccessibleInterval<FloatType> expectedData = ...;
Tensor output = Tensor.build("outputName", "bcyx", expectedData);
Once the ImgLib2 RandomAccessibleInterval, Img or ArrayImg has been built by the main software, Tensor.build creates a JDLL tensor from three arguments: the tensorName, which some DL frameworks use to identify the tensor in the model; the axes, which are used to convert the JDLL tensor into the corresponding DL framework tensor; and the data, passed in the data argument.
- tensorName: name of the tensor as given in the model (among the currently supported DL frameworks it is only important for TensorFlow).
- axes: String representing the order of the dimensions (axes) of the tensor. For example, axes = "bcyx" means that the tensor has four dimensions: b is batch_size, c is channels, y is height and x is width. For more information about the axes, look up axes in the rdf.yaml specification.
- data: data of the tensor as an ImgLib2 Img, ArrayImg, RandomAccessibleInterval or similar.
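To make the axes convention concrete, here is a minimal sketch that pairs each character of an axes String with the dimension at the same position in a shape array. AxesSketch and describe are hypothetical names used only for illustration; they are not part of the JDLL API, and the sketch uses plain Java so it runs without ImgLib2. It also assumes every axis character appears only once.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper for illustration only; not part of JDLL.
final class AxesSketch {

    // Pairs each axis character with the size of the dimension at the same position.
    static Map<Character, Long> describe(String axes, long[] shape) {
        if (axes.length() != shape.length) {
            throw new IllegalArgumentException(
                "axes \"" + axes + "\" has " + axes.length()
                + " characters but the shape has " + shape.length + " dimensions");
        }
        // LinkedHashMap keeps the axes in the order given by the String.
        Map<Character, Long> sizes = new LinkedHashMap<>();
        for (int i = 0; i < shape.length; i++) {
            sizes.put(axes.charAt(i), shape[i]);
        }
        return sizes;
    }
}
```

Calling AxesSketch.describe("bcyx", new long[] {1, 2, 512, 512}) maps b to 1, c to 2, y to 512 and x to 512, while a shape with three entries is rejected because it does not match the four characters of the axes String.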
JDLL requires that the information of the output tensors is provided before inference in order to correctly retrieve JDLL tensors from the corresponding DL framework tensors.
The output tensors should be created empty (blank) with or without allocation of memory.
Tensor.buildEmptyTensor creates a completely empty tensor. The backend object that would contain the tensor data is set to null, and the only information the created Tensor holds is the name of the tensor and its axes. These two pieces of information are the minimum needed to build the output JDLL tensor from the DL framework tensor. This method does not allocate the memory that the output tensor will require.
Such an output tensor can adapt to and hold an output tensor of any size and any data type, as long as it fulfils the axes constraint on the number of dimensions (it must equal the number of characters in the axes String).
- tensorName: name of the tensor as given in the model (among the currently supported DL frameworks it is only important for TensorFlow).
- axes: String representing the order of the dimensions (axes) of the tensor. For example, axes = "bcyx" means that the tensor has four dimensions: b is batch_size, c is channels, y is height and x is width. For more information about the axes, look up axes in the rdf.yaml specification.
Tensor.buildBlankTensor(final String tensorName, final String axes, final long[] shape, final T dtype)
This method creates a blank JDLL tensor. The backend contains an ImgLib2 RandomAccessibleInterval in which all the values are 0 and which has the dimensions specified by the shape argument. This method does allocate the memory needed for the JDLL output tensor.
The dimensions of the actual output tensor have to be equal to shape and its data type has to be the same as dtype.
- tensorName: name of the tensor as given in the model (among the currently supported DL frameworks it is only important for TensorFlow).
- axes: String representing the order of the dimensions (axes) of the tensor. For example, axes = "bcyx" means that the tensor has four dimensions: b is batch_size, c is channels, y is height and x is width. For more information about the axes, look up axes in the rdf.yaml specification.
- shape: the dimensions of the output tensor. Its length must be the same as the number of characters in the axes String argument.
- dtype: ImgLib2 data type that will be the data type of the tensor.
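Because a blank tensor allocates memory up front, it can help to estimate how much a given shape will need. The sketch below is a rough, plain-Java calculation: BlankTensorSize, elementCount and byteCount are hypothetical names, not part of JDLL or ImgLib2, and the result counts only the raw values, not any object overhead.

```java
// Hypothetical helper for illustration only; not part of JDLL.
final class BlankTensorSize {

    // Number of elements in a tensor: the product of all its dimensions.
    static long elementCount(long[] shape) {
        long count = 1;
        for (long dim : shape) {
            if (dim < 0) {
                throw new IllegalArgumentException("negative dimension: " + dim);
            }
            // multiplyExact throws ArithmeticException instead of silently overflowing.
            count = Math.multiplyExact(count, dim);
        }
        return count;
    }

    // Bytes needed for the raw values, given the size of one element in bytes.
    static long byteCount(long[] shape, int bytesPerElement) {
        return Math.multiplyExact(elementCount(shape), (long) bytesPerElement);
    }
}
```

For the shape {1, 2, 512, 512} used in the examples, elementCount gives 524288 elements. With 32-bit floats (Float.BYTES, matching the float values stored by FloatType), byteCount gives 2097152 bytes, that is 2 MiB.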
The last option for building an output tensor with allocated memory is to create a JDLL tensor with Tensor.build, as long as the data argument has the same dimensions and data type as the actual output tensor.
JDLL tensors can be modified.
The data contained in the ImgLib2 object inside a JDLL tensor can be replaced, as long as the new data has the same data type and dimensions as the original.
Here is one example that will work:
long[] tensorShape = new long[] {1, 2, 512, 512};
// Create the original tensor
Tensor<FloatType> outTensor = Tensor.buildBlankTensor("output0", "bcyx",
tensorShape, new FloatType());
// Create the new ImgLib2 object with the same dims and data type
final ArrayImgFactory<FloatType> factory = new ArrayImgFactory<>(new FloatType());
final Img<FloatType> newData = factory.create(tensorShape);
outTensor.setData(newData);
Here is one example that will not work, because the new data has different dimensions and a different data type:
long[] tensorShape = new long[] {1, 2, 512, 512};
// Create the original tensor
Tensor<FloatType> outTensor = Tensor.buildBlankTensor("output0", "bcyx",
tensorShape, new FloatType());
// Create the new ImgLib2 object with new dims and/or a new data type
long[] newTensorShape = new long[] {1, 3, 256, 256};
final ArrayImgFactory<IntType> factory = new ArrayImgFactory<>(new IntType());
final Img<IntType> newData = factory.create(newTensorShape);
outTensor.setData(newData);
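The rule behind these two examples can be summarised as a small check: the replacement data is acceptable only when both the dimensions and the data type match the original. The sketch below is plain Java under stated assumptions: SetDataCheck and isCompatible are hypothetical names, and Float.class and Integer.class stand in for FloatType and IntType so that it runs without ImgLib2. It does not show how JDLL itself performs the check.

```java
import java.util.Arrays;

// Hypothetical helper for illustration only; not part of JDLL.
final class SetDataCheck {

    // True when the replacement data has the same dimensions and element type as the original.
    static boolean isCompatible(long[] originalShape, Class<?> originalType,
                                long[] newShape, Class<?> newType) {
        return Arrays.equals(originalShape, newShape) && originalType.equals(newType);
    }
}
```

With the original shape {1, 2, 512, 512} and Float.class, the check passes for new data with the same shape and type, and fails for the shape {1, 3, 256, 256} with Integer.class.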