Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix the bug caused by concat_v2 #1208

Merged
merged 1 commit into from
Nov 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,9 @@ public static Tensor _transpose_batch_time(Tensor x)
return x;

var x_rank = array_ops.rank(x);
var con1 = new Tensor[]
var con1 = new object[]
{
new Tensor(new int[]{0, 2}),
new []{1, 0 },
math_ops.range(2, x_rank)
};
var x_t = array_ops.transpose(x, array_ops.concat(con1, 0));
Expand Down
6 changes: 3 additions & 3 deletions src/TensorFlowNET.Core/Operations/array_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -945,12 +945,12 @@ public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y)
/// <returns></returns>
public static Tensor concat(Tensor[] values, Tensor axis, string name = "concat")
{
return gen_array_ops.concat_v2(values, axis, name: name);
return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis));
}

public static Tensor concat(Tensor[] values, Axis axis, string name = "concat")
public static Tensor concat(object[] values, int axis, string name = "concat")
{
return gen_array_ops.concat_v2(values, axis, name: name);
return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis));
}

/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Operations/nn_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ private static Tensor _flatten_outer_dims(Tensor logits)
new[] { math_ops.subtract(rank, 1) },
new[] { constant_op.constant(1) });

var ops = array_ops.concat(new Tensor[] { new Tensor(new int[] {1}), last_dim_size }, 0);
var ops = array_ops.concat(new[] { new[] { -1 }, (object)last_dim_size }, 0);
var output = array_ops.reshape(logits, ops);

// Set output shape if known.
Expand Down
Loading