Skip to content

Commit

Permalink
Merge pull request #1212 from Wanglongzhi2001/fix_validation_data_pack
Browse files Browse the repository at this point in the history
fix: fix the validation_pack when multiple input
  • Loading branch information
AsakusaRinne authored Nov 7, 2023
2 parents 94c0bb8 + 47e9019 commit ee004c0
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 17 deletions.
30 changes: 21 additions & 9 deletions src/TensorFlowNET.Core/Util/Data.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Tensorflow.NumPy;
using OneOf;
using Tensorflow.NumPy;

namespace Tensorflow.Util
{
Expand All @@ -8,10 +9,10 @@ namespace Tensorflow.Util
/// </summary>
public class ValidationDataPack
{
public NDArray val_x;
public NDArray val_y;
public NDArray val_sample_weight = null;

internal OneOf<NDArray, NDArray[]> val_x;
internal NDArray val_y;
internal NDArray val_sample_weight = null;
public bool val_x_is_array = false;
public ValidationDataPack((NDArray, NDArray) validation_data)
{
this.val_x = validation_data.Item1;
Expand All @@ -27,15 +28,17 @@ public ValidationDataPack((NDArray, NDArray, NDArray) validation_data)

public ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
{
this.val_x = validation_data.Item1.ToArray()[0];
this.val_x = validation_data.Item1.ToArray();
this.val_y = validation_data.Item2;
val_x_is_array = true;
}

public ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
{
this.val_x = validation_data.Item1.ToArray()[0];
this.val_x = validation_data.Item1.ToArray();
this.val_y = validation_data.Item2;
this.val_sample_weight = validation_data.Item3;
val_x_is_array = true;
}

public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data)
Expand All @@ -52,15 +55,24 @@ public static implicit operator ValidationDataPack((IEnumerable<NDArray>, NDArra

public void Deconstruct(out NDArray val_x, out NDArray val_y)
{
val_x = this.val_x;
val_x = this.val_x.AsT0;
val_y = this.val_y;
}

public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight)
{
val_x = this.val_x;
val_x = this.val_x.AsT0;
val_y = this.val_y;
val_sample_weight = this.val_sample_weight;
}

// add a unuse parameter to make it different from Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight)
public void Deconstruct(out NDArray[] val_x_array, out NDArray val_y, out NDArray val_sample_weight, out NDArray unuse)
{
val_x_array = this.val_x.AsT1;
val_y = this.val_y;
val_sample_weight = this.val_sample_weight;
unuse = null;
}
}
}
14 changes: 11 additions & 3 deletions src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,17 @@ public static ((IEnumerable<NDArray>, NDArray, NDArray), ValidationDataPack) tra
var train_y = y[new Slice(0, train_count)];
var val_x = x.Select(x => x[new Slice(train_count)] as NDArray);
var val_y = y[new Slice(train_count)];
NDArray tmp_sample_weight = sample_weight;
sample_weight = sample_weight[new Slice(0, train_count)];
ValidationDataPack validation_data = (val_x, val_y, tmp_sample_weight[new Slice(train_count)]);

ValidationDataPack validation_data;
if (sample_weight != null)
{
validation_data = (val_x, val_y, sample_weight[new Slice(train_count)]);
sample_weight = sample_weight[new Slice(0, train_count)];
}
else
{
validation_data = (val_x, val_y);
}
return ((train_x, train_y, sample_weight), validation_data);
}
}
Expand Down
8 changes: 7 additions & 1 deletion src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,19 @@ public Dictionary<string, float> evaluate(NDArray x, NDArray y,
return evaluate(data_handler, callbacks, is_val, test_function);
}

public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false)
public Dictionary<string, float> evaluate(
IEnumerable<Tensor> x,
Tensor y,
int verbose = 1,
NDArray sample_weight = null,
bool is_val = false)
{
var data_handler = new DataHandler(new DataHandlerArgs
{
X = new Tensors(x.ToArray()),
Y = y,
Model = this,
SampleWeight = sample_weight,
StepsPerExecution = _steps_per_execution
});

Expand Down
23 changes: 19 additions & 4 deletions src/TensorFlowNET.Keras/Engine/Model.Fit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Diagnostics;
using Tensorflow.Keras.Callbacks;
using Tensorflow.Util;
using OneOf;

namespace Tensorflow.Keras.Engine
{
Expand Down Expand Up @@ -287,10 +288,24 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICal

if (validation_data != null)
{
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen
// so we need to pass a is_val parameter to stop on_test_batch_end
var (val_x, val_y, val_sample_weight) = validation_data;
var val_logs = evaluate(val_x, val_y, sample_weight:val_sample_weight, is_val:true);
NDArray val_x;
NDArray[] val_x_array;
NDArray val_y;
NDArray val_sample_weight;
Dictionary<string, float> val_logs;
if (!validation_data.val_x_is_array)
{
(val_x, val_y, val_sample_weight) = validation_data;
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen
// so we need to pass a is_val parameter to stop on_test_batch_end
val_logs = evaluate(val_x, val_y, sample_weight: val_sample_weight, is_val: true);

}
else
{
(val_x_array, val_y, val_sample_weight, _) = validation_data;
val_logs = evaluate(val_x_array, val_y, sample_weight: val_sample_weight, is_val: true);
}
foreach (var log in val_logs)
{
logs["val_" + log.Key] = log.Value;
Expand Down

0 comments on commit ee004c0

Please sign in to comment.