Skip to content

Commit

Permalink
Update tensor_util.cs
Browse files Browse the repository at this point in the history
  • Loading branch information
novikov-alexander authored Jun 14, 2024
1 parent 43f43eb commit b3ce158
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions src/TensorFlowNET.Core/Tensors/tensor_util.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*****************************************************************************
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -135,6 +135,23 @@ T[] ExpandArrayToSize<T>(IList<T> src)
TF_DataType.TF_QINT32
};

private static TOut[,] ConvertArray2D<TIn, TOut>(TIn[,] inputArray, Func<TIn, TOut> converter)
{
var rows = inputArray.GetLength(0);
var cols = inputArray.GetLength(1);
var outputArray = new TOut[rows, cols];

for (var i = 0; i < rows; i++)
{
for (var j = 0; j < cols; j++)
{
outputArray[i, j] = converter(inputArray[i, j]);
}
}

return outputArray;
}

/// <summary>
/// Create a TensorProto, invoked in graph mode
/// </summary>
Expand All @@ -157,19 +174,16 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
else if(origin_dtype != dtype)
{
var new_system_dtype = dtype.as_system_dtype();
if (values is long[] long_values)
{
if (dtype == TF_DataType.TF_INT32)
values = long_values.Select(x => (int)Convert.ChangeType(x, new_system_dtype)).ToArray();
}
else if (values is double[] double_values)

values = values switch
{
if (dtype == TF_DataType.TF_FLOAT)
values = double_values.Select(x => (float)Convert.ChangeType(x, new_system_dtype)).ToArray();
}
else
values = Convert.ChangeType(values, new_system_dtype);

long[] longValues when dtype == TF_DataType.TF_INT32 => longValues.Select(x => (int)x).ToArray(),
float[] floatValues when dtype == TF_DataType.TF_DOUBLE => floatValues.Select(x => (double)x).ToArray(),
float[,] float2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(float2DValues, Convert.ToDouble),
double[] doubleValues when dtype == TF_DataType.TF_FLOAT => doubleValues.Select(x => (float)x).ToArray(),
double[,] double2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(double2DValues, Convert.ToSingle),
_ => Convert.ChangeType(values, new_system_dtype),
};
dtype = values.GetDataType();
}

Expand Down

0 comments on commit b3ce158

Please sign in to comment.