Skip to content

Commit

Permalink
Merge pull request #1205 from Wanglongzhi2001/fix_boolean_mask
Browse files Browse the repository at this point in the history
fix: fix the bug of boolean_mask
  • Loading branch information
Oceania2018 authored Nov 4, 2023
2 parents 079b9a3 + 4e42d7f commit 44bdddc
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,9 @@ public static Tensor _transpose_batch_time(Tensor x)
return x;

var x_rank = array_ops.rank(x);
var con1 = new object[]
var con1 = new Tensor[]
{
new []{1, 0 },
new Tensor(new int[]{0, 2}),
math_ops.range(2, x_rank)
};
var x_t = array_ops.transpose(x, array_ops.concat(con1, 0));
Expand Down
13 changes: 9 additions & 4 deletions src/TensorFlowNET.Core/Operations/array_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,11 @@ public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boo
throw new ValueError("mask cannot be scalar.");

var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], ops.convert_to_tensor(new[] { 0 }));
if (leading_size.rank == 0)
{
leading_size = expand_dims(leading_size, 0);
}

var shape1 = concat(new[]
{
shape(tensor_tensor)[$":{axis}"],
Expand All @@ -185,7 +190,7 @@ public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boo

private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0)
{
var indices = squeeze(where(mask), axis: new[] { 1 });
var indices = squeeze(where_v2(mask), axis: new[] { 1 });
return gather(reshaped_tensor, indices, axis: ops.convert_to_tensor(axis));
}

Expand Down Expand Up @@ -940,12 +945,12 @@ public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y)
/// <returns></returns>
public static Tensor concat(Tensor[] values, Tensor axis, string name = "concat")
{
return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis));
return gen_array_ops.concat_v2(values, axis, name: name);
}

public static Tensor concat(object[] values, int axis, string name = "concat")
public static Tensor concat(Tensor[] values, Axis axis, string name = "concat")
{
return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis));
return gen_array_ops.concat_v2(values, axis, name: name);
}

/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Operations/nn_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ private static Tensor _flatten_outer_dims(Tensor logits)
new[] { math_ops.subtract(rank, 1) },
new[] { constant_op.constant(1) });

var ops = array_ops.concat(new[] { new[] { -1 }, (object)last_dim_size }, 0);
var ops = array_ops.concat(new Tensor[] { new Tensor(new int[] {1}), last_dim_size }, 0);
var output = array_ops.reshape(logits, ops);

// Set output shape if known.
Expand Down
7 changes: 4 additions & 3 deletions test/TensorFlowNET.Graph.UnitTest/Basics/TensorTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Linq;
using static Tensorflow.Binding;
using Tensorflow;

namespace TensorFlowNET.UnitTest.Basics
{
Expand Down Expand Up @@ -60,14 +61,14 @@ public void batch_to_space_nd()
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>()));
}

[TestMethod, Ignore]
[TestMethod]
public void boolean_mask()
{
if (!tf.executing_eagerly())
tf.enable_eager_execution();
var tensor = new[] { 0, 1, 2, 3 };
var mask = np.array(new[] { true, false, true, false });
var masked = tf.boolean_mask(tensor, mask);
var sess = tf.Session();
var result = sess.run(masked);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>()));
}
}
Expand Down

0 comments on commit 44bdddc

Please sign in to comment.