diff --git a/TensorFlowSharp/AdditionalAssemblyInfo.cs b/TensorFlowSharp/AdditionalAssemblyInfo.cs new file mode 100644 index 00000000..77f662b5 --- /dev/null +++ b/TensorFlowSharp/AdditionalAssemblyInfo.cs @@ -0,0 +1,3 @@ +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("TensorFlowSharp.Tests.CSharp")] \ No newline at end of file diff --git a/TensorFlowSharp/Foundation/RangeChecks.cs b/TensorFlowSharp/Foundation/RangeChecks.cs new file mode 100644 index 00000000..f47cb2ec --- /dev/null +++ b/TensorFlowSharp/Foundation/RangeChecks.cs @@ -0,0 +1,25 @@ +using System; +using System.Runtime.CompilerServices; + +namespace TensorFlowSharp.Foundation +{ + internal static class RangeChecks + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool IsWithinRange(int value, int minInclusive, int maxExclusive) => minInclusive <= value && value < maxExclusive; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool IsWithinRange(T obj, T minInclusive, T maxExcluded) + where T : IComparable + { + return minInclusive.CompareTo(obj) <= 0 && 0 < maxExcluded.CompareTo(obj); + } + + public static int Assert(int value, string name, int minInclusive, int maxExclusive) + { + return IsWithinRange(value, minInclusive, maxExclusive) + ? value + : throw new ArgumentOutOfRangeException(FormattableString.Invariant($"Expected '{name}' to be within [{minInclusive}, {maxExclusive}), but was {value}")); + } + } +} diff --git a/TensorFlowSharp/Foundation/TFSharpDebug.cs b/TensorFlowSharp/Foundation/TFSharpDebug.cs new file mode 100644 index 00000000..18205ae0 --- /dev/null +++ b/TensorFlowSharp/Foundation/TFSharpDebug.cs @@ -0,0 +1,38 @@ +using System; +using System.Runtime.CompilerServices; + +namespace TensorFlowSharp.Foundation +{ + internal delegate void DebugAction(); + internal delegate T DebugFunc(); + + internal static class TFSharpDebug + { + private const bool IsDebug = +#if DEBUG + true; +#else + false; +#endif + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static T AssertNotNull(T obj, string name) + where T : class + { + return IsDebug + ? obj ?? throw new ArgumentNullException(name) + : obj; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static T AssertWithinLimits(T obj, string name, T minInclusive, T maxExcluded) + where T : IComparable + { + + if (!IsDebug || RangeChecks.IsWithinRange(obj, minInclusive, maxExcluded)) + return obj; + else + throw new ArgumentOutOfRangeException(FormattableString.Invariant($"Expected '{name}' to be between [{minInclusive}, {maxExcluded})")); + } + } +} diff --git a/TensorFlowSharp/Tensorflow.cs b/TensorFlowSharp/Tensorflow.cs index ec14906e..f92daa9c 100644 --- a/TensorFlowSharp/Tensorflow.cs +++ b/TensorFlowSharp/Tensorflow.cs @@ -37,6 +37,7 @@ using System.Numerics; using System.Collections.Generic; using System.Linq.Expressions; +using TensorFlowSharp.Foundation; namespace TensorFlow { @@ -48,10 +49,15 @@ static partial class NativeBinding internal static string GetStr (this IntPtr x) => Marshal.PtrToStringAnsi (x); } - /// - /// Contains TensorFlow fundamental methods and utility functions. - /// - public static class TFCore + internal static class TFNative + { + public static IntPtr Nullptr => IntPtr.Zero; + } + + /// + /// Contains TensorFlow fundamental methods and utility functions. + /// + public static class TFCore { internal static bool UseCPU = true; @@ -178,6 +184,7 @@ public abstract class TFDisposable : IDisposable /// /// The handle. public IntPtr Handle => handle; + internal bool HandleIsDisposed => handle == IntPtr.Zero; static TFDisposable () { @@ -235,6 +242,12 @@ public virtual void Dispose (bool disposing) } } + internal void AssertNotDisposed() + { + if (HandleIsDisposed) + ObjectDisposedException(); + } + internal static void ObjectDisposedException () { throw new ObjectDisposedException ("The object was disposed"); @@ -501,9 +514,7 @@ internal override void NativeDispose (IntPtr handle) /// public void SetTarget (string target) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - + AssertNotDisposed(); TF_SetTarget (handle, target); } @@ -522,8 +533,7 @@ public void SetTarget (string target) /// public void SetConfig (IntPtr protoData, int length, TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); + AssertNotDisposed(); var cstatus = TFStatus.Setup (status); @@ -578,9 +588,9 @@ internal override void NativeDispose (IntPtr handle) // extern void TF_GraphSetTensorShape (TF_Graph *graph, TF_Output output, const int64_t *dims, const int num_dims, TF_Status *status); [DllImport (NativeBinding.TensorFlowLibrary)] - static extern unsafe void TF_GraphSetTensorShape (TF_Graph graph, TFOutput output, long [] dims, int num_dims, TF_Status status); + static extern unsafe void TF_GraphSetTensorShape (TF_Graph graph, TF_Output output, long [] dims, int num_dims, TF_Status status); [DllImport (NativeBinding.TensorFlowLibrary)] - static extern unsafe void TF_GraphSetTensorShape (TF_Graph graph, TFOutput output, IntPtr dims, int num_dims, TF_Status status); + static extern unsafe void TF_GraphSetTensorShape (TF_Graph graph, TF_Output output, IntPtr dims, int num_dims, TF_Status status); /// /// Sets the tensor shape of the tensor referenced by to the shape described by . @@ -590,20 +600,19 @@ internal override void NativeDispose (IntPtr handle) /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. public void SetTensorShape (TFOutput output, long [] dims, TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); + AssertNotDisposed(); - var cstatus = TFStatus.Setup (status); + var cstatus = TFStatus.Setup (status); if (dims == null) - TF_GraphSetTensorShape (handle, output, IntPtr.Zero, 0, cstatus.handle); + TF_GraphSetTensorShape(handle, output.NativeRepresentation, TFNative.Nullptr, 0, cstatus.handle); else - TF_GraphSetTensorShape (handle, output, dims, dims.Length, cstatus.handle); + TF_GraphSetTensorShape (handle, output.NativeRepresentation, dims, dims.Length, cstatus.handle); cstatus.CheckMaybeRaise (status); } // extern int TF_GraphGetTensorNumDims (TF_Graph *graph, TF_Output output, TF_Status *status); [DllImport (NativeBinding.TensorFlowLibrary)] - static extern unsafe int TF_GraphGetTensorNumDims (TF_Graph graph, TFOutput output, TF_Status status); + static extern unsafe int TF_GraphGetTensorNumDims (TF_Graph graph, TF_Output output, TF_Status status); /// /// Returns the number of dimensions of the Tensor referenced by output @@ -613,17 +622,17 @@ public void SetTensorShape (TFOutput output, long [] dims, TFStatus status = nul /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. public int GetTensorNumDims (TFOutput output, TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - var cstatus = TFStatus.Setup (status); - var code = TF_GraphGetTensorNumDims (handle, output, cstatus.handle); + AssertNotDisposed(); + + var cstatus = TFStatus.Setup (status); + var code = TF_GraphGetTensorNumDims (handle, output.NativeRepresentation, cstatus.handle); cstatus.CheckMaybeRaise (status); return code; } // extern void TF_GraphGetTensorShape (TF_Graph *graph, TF_Output output, int64_t *dims, int num_dims, TF_Status *status); [DllImport (NativeBinding.TensorFlowLibrary)] - static extern unsafe void TF_GraphGetTensorShape (TF_Graph graph, TFOutput output, long [] dims, int num_dims, TF_Status status); + static extern unsafe void TF_GraphGetTensorShape (TF_Graph graph, TF_Output output, long [] dims, int num_dims, TF_Status status); /// /// Returns the shape of a tensor specified in . @@ -634,17 +643,17 @@ public int GetTensorNumDims (TFOutput output, TFStatus status = null) /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. public TFShape GetTensorShape (TFOutput output, TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - var cstatus = TFStatus.Setup (status); - var n = TF_GraphGetTensorNumDims (handle, output, cstatus.handle); + AssertNotDisposed(); + + var cstatus = TFStatus.Setup (status); + var n = TF_GraphGetTensorNumDims (handle, output.NativeRepresentation, cstatus.handle); if (!cstatus.CheckMaybeRaise (status, last: false)) return TFShape.Unknown; if (n == -1) return TFShape.Unknown; var dims = new long [n]; - TF_GraphGetTensorShape (handle, output, dims, dims.Length, cstatus.handle); + TF_GraphGetTensorShape (handle, output.NativeRepresentation, dims, dims.Length, cstatus.handle); cstatus.CheckMaybeRaise (status); return new TFShape (dims); } @@ -660,9 +669,9 @@ public TFShape GetTensorShape (TFOutput output, TFStatus status = null) /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. public void ToGraphDef (TFBuffer outputGraphDef, TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (outputGraphDef == null) + AssertNotDisposed(); + + if (outputGraphDef == null) throw new ArgumentNullException (nameof (outputGraphDef)); var cstatus = TFStatus.Setup (status); @@ -686,9 +695,9 @@ public void ToGraphDef (TFBuffer outputGraphDef, TFStatus status = null) /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. public void Import (TFBuffer graphDef, string prefix = "", TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (graphDef == null) + AssertNotDisposed(); + + if (graphDef == null) throw new ArgumentNullException (nameof (graphDef)); if (prefix == null) throw new ArgumentNullException (nameof (prefix)); @@ -708,9 +717,9 @@ public void Import (TFBuffer graphDef, string prefix = "", TFStatus status = nul /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. public void Import (TFBuffer graphDef, TFImportGraphDefOptions options, TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (graphDef == null) + AssertNotDisposed(); + + if (graphDef == null) throw new ArgumentNullException (nameof (graphDef)); if (options == null) throw new ArgumentNullException (nameof (options)); @@ -732,9 +741,9 @@ public void Import (TFBuffer graphDef, TFImportGraphDefOptions options, TFStatus /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. public void Import (byte [] buffer, string prefix = "", TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (buffer == null) + AssertNotDisposed(); + + if (buffer == null) throw new ArgumentNullException (nameof (buffer)); if (prefix == null) throw new ArgumentNullException (nameof (prefix)); @@ -756,9 +765,9 @@ public void Import (byte [] buffer, string prefix = "", TFStatus status = null) /// public void Import (byte [] buffer, TFImportGraphDefOptions options, TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (buffer == null) + AssertNotDisposed(); + + if (buffer == null) throw new ArgumentNullException (nameof (buffer)); if (options == null) throw new ArgumentNullException (nameof (options)); @@ -777,16 +786,18 @@ public void Import (byte [] buffer, TFImportGraphDefOptions options, TFStatus st /// Gets the with the specified name, or null if the named operation does not exist in the graph. /// /// Name to lookup. - public TFOperation this [string name] { - get { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - var h = TF_GraphOperationByName (handle, name); - if (h == IntPtr.Zero) - return null; - return new TFOperation (this, h); - } - } + public TFOperation this[string name] + { + get + { + AssertNotDisposed(); + + var h = TF_GraphOperationByName(handle, name); + if (h == IntPtr.Zero) + return null; + return new TFOperation(this, h); + } + } // extern TF_Operation * TF_GraphNextOperation (TF_Graph *graph, size_t *pos); [DllImport (NativeBinding.TensorFlowLibrary)] @@ -798,9 +809,9 @@ public void Import (byte [] buffer, TFImportGraphDefOptions options, TFStatus st /// The enumerator. public IEnumerable GetEnumerator () { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - IntPtr token = IntPtr.Zero; + AssertNotDisposed(); + + IntPtr token = IntPtr.Zero; IntPtr operll; while ((operll = TF_GraphNextOperation (handle, ref token)) != IntPtr.Zero) @@ -815,17 +826,17 @@ public IEnumerable GetEnumerator () /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. public long [] GetShape (TFOutput output, TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - var cstatus = TFStatus.Setup (status); - var ndims = TF_GraphGetTensorNumDims (handle, output, cstatus.handle); + AssertNotDisposed(); + + var cstatus = TFStatus.Setup (status); + var ndims = TF_GraphGetTensorNumDims (handle, output.NativeRepresentation, cstatus.handle); if (!cstatus.CheckMaybeRaise (status, last: false)) return null; if (ndims == 0) return null; var ret = new long [ndims]; - TF_GraphGetTensorShape (handle, output, ret, ndims, cstatus.handle); + TF_GraphGetTensorShape (handle, output.NativeRepresentation, ret, ndims, cstatus.handle); cstatus.CheckMaybeRaise (status); return ret; } @@ -923,7 +934,7 @@ internal int GetNextId () } [DllImport (NativeBinding.TensorFlowLibrary)] - unsafe static extern void TF_GraphImportGraphDefWithReturnOutputs (TF_Graph graph, LLBuffer *graph_def, TF_ImportGraphDefOptions options, TFOutput *return_outputs, int num_return_outputs, TF_Status status); + unsafe static extern void TF_GraphImportGraphDefWithReturnOutputs (TF_Graph graph, LLBuffer *graph_def, TF_ImportGraphDefOptions options, TF_Output *return_outputs, int num_return_outputs, TF_Status status); /// /// Imports a graph serialized into the graph @@ -935,11 +946,11 @@ internal int GetNextId () /// /// If you are tryig to load a file stored using the SavedModel file format, you should use the API instead. /// - public void ImportGraphDef (TFBuffer graphDef, TFImportGraphDefOptions options, TFOutput [] returnOutputs, TFStatus status = null) + public void ImportGraphDef (TFBuffer graphDef, TFImportGraphDefOptions options, TFOutput[] returnOutputs, TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (graphDef == null) + AssertNotDisposed(); + + if (graphDef == null) throw new ArgumentNullException (nameof (graphDef)); if (options == null) throw new ArgumentNullException (nameof (options)); @@ -947,13 +958,19 @@ public void ImportGraphDef (TFBuffer graphDef, TFImportGraphDefOptions options, unsafe { - if (returnOutputs == null) { + if (returnOutputs == null) + { TF_GraphImportGraphDefWithReturnOutputs (handle, graphDef.LLBuffer, options.handle, null, 0, cstatus.handle); - } else { - fixed (TFOutput* first = &returnOutputs [0]) + } + else + { + var nativeReturnOutputs = new TF_Output[returnOutputs.Length]; + fixed (TF_Output* outputs = nativeReturnOutputs) { - TF_GraphImportGraphDefWithReturnOutputs (handle, graphDef.LLBuffer, options.handle, first, returnOutputs.Length, cstatus.handle); + TF_GraphImportGraphDefWithReturnOutputs (handle, graphDef.LLBuffer, options.handle, outputs, nativeReturnOutputs.Length, cstatus.handle); } + + TFOutput.FromNative(nativeReturnOutputs, returnOutputs, this); } } } @@ -963,28 +980,28 @@ unsafe struct TFWhileParams { public int ninputs; public TF_Graph cond_graph; - public TFOutput* cond_inputs; - public TFOutput cond_output; + public TF_Output* cond_inputs; + public TF_Output cond_output; public TF_Graph body_graph; - public TFOutput* body_inputs; - public TFOutput* body_outputs; + public TF_Output* body_inputs; + public TF_Output* body_outputs; public IntPtr charPtrName; } [DllImport (NativeBinding.TensorFlowLibrary)] - static extern unsafe TFWhileParams TF_NewWhile (TF_Graph g, TFOutput [] inputs, int ninputs, TF_Status status); + static extern unsafe TFWhileParams TF_NewWhile (TF_Graph g, TF_Output[] inputs, int ninputs, TF_Status status); [DllImport (NativeBinding.TensorFlowLibrary)] static extern void TF_AbortWhile (ref TFWhileParams pars); [DllImport (NativeBinding.TensorFlowLibrary)] - static extern unsafe void TF_FinishWhile (ref TFWhileParams pars, TF_Status status, TFOutput *outputs); + static extern unsafe void TF_FinishWhile (ref TFWhileParams pars, TF_Status status, TF_Output* outputs); - static unsafe TFOutput [] CopyFrom (TFOutput* ptr, int n) + static unsafe TFOutput[] CopyFrom (TF_Output* ptr, int n, TFGraph graph) { - var r = new TFOutput [n]; - for (int i = 0; i < n; i++) - r [i] = ptr [i]; + var r = new TFOutput[n]; + for (int i = 0; i < n; i++) + r[i] = new TFOutput(ptr[i], graph); return r; } @@ -1025,14 +1042,16 @@ unsafe struct TFWhileParams /// public TFOutput [] While (TFOutput [] inputs, WhileConstructor constructor, TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (inputs == null) + AssertNotDisposed(); + + if (inputs == null) throw new ArgumentNullException (nameof (inputs)); if (constructor == null) throw new ArgumentNullException (nameof (constructor)); var cstatus = TFStatus.Setup (status); - TFWhileParams result = TF_NewWhile (handle, inputs, inputs.Length, cstatus.handle); + + var nativeInputs = TFOutput.ToNative(inputs); + TFWhileParams result = TF_NewWhile (handle, nativeInputs, nativeInputs.Length, cstatus.handle); if (cstatus.Error) return null; @@ -1051,7 +1070,9 @@ unsafe struct TFWhileParams { var condGraph = new TFGraphUnowned (result.cond_graph); var bodyGraph = new TFGraphUnowned (result.body_graph); - constructor (condGraph, CopyFrom (result.cond_inputs, n), out result.cond_output, bodyGraph, CopyFrom (result.body_inputs, n), bodyOutputs, out name); + constructor (condGraph, CopyFrom (result.cond_inputs, n, this), out var conditionalOutput, bodyGraph, CopyFrom (result.body_inputs, n, this), bodyOutputs, out name); + //TODO: + result.cond_output = conditionalOutput.NativeRepresentation; } if (name == null || name == "") name = MakeUnique ("while"); @@ -1064,16 +1085,15 @@ unsafe struct TFWhileParams unsafe { - for (int i = 0; i < n; i++) - result.body_outputs [i] = bodyOutputs [i]; - var ret = new TFOutput [inputs.Length]; - fixed (TFOutput* first = &ret [0]) + for (int i = 0; i < n; i++) + result.body_outputs[i] = bodyOutputs[i].NativeRepresentation; + var nativeResults = new TF_Output[inputs.Length]; + fixed (TF_Output* first = nativeResults) TF_FinishWhile (ref result, cstatus.handle, first); - - if (cstatus.CheckMaybeRaise (status)) - return ret; - } + if (cstatus.CheckMaybeRaise (status)) + return TFOutput.FromNative(nativeResults, this); + } return null; } catch { TF_AbortWhile (ref result); @@ -1082,57 +1102,61 @@ unsafe struct TFWhileParams } [DllImport (NativeBinding.TensorFlowLibrary)] - static extern unsafe void TF_AddGradients (TF_Graph graph, TFOutput* ys, int ny, TFOutput* xs, int nx, TFOutput* dx, TF_Status status, TFOutput* dy); - - /// - /// Adds a gradient: the operations needed to compute the partial derivatives of sum of ` wrt to . - /// - /// The partial derivatives, the size of the array is the same as the length of the array. - /// The y elements. - /// The x elements. - /// Initial gradients, which represent the symbolic partial derivatives of some loss function `L` w.r.t. ). - /// If the parameter is null, the implementation will use dx for 'OnesLike' for all shapes in - /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. - /// - /// d(y[0] + y[1]+ ...)/dx[0], d(y[0] + y[1] + ...)/dx[1]z... - /// - public TFOutput [] AddGradients (TFOutput [] y, TFOutput [] x, TFOutput [] dx = null, TFStatus status = null) - { - if (y == null) - throw new ArgumentNullException (nameof (y)); - if (x == null) - throw new ArgumentNullException (nameof (x)); - if (dx != null) { - if (dx.Length != y.Length) - throw new ArgumentException ("If dx is not null, the size of the gradients must match the size of y", nameof (dx)); - } - - var cstatus = TFStatus.Setup (status); - - var ret = new TFOutput [x.Length]; - unsafe - { - fixed (TFOutput* pret = &ret [0]) { - fixed (TFOutput* py = &y [0]) { - fixed (TFOutput* px = &x [0]) { - if (dx == null) { - TF_AddGradients (handle, py, y.Length, px, x.Length, (TFOutput*)null, cstatus.Handle, pret); - } else { - fixed (TFOutput* pdx = &dx [0]) { - TF_AddGradients (handle, py, y.Length, px, x.Length, pdx, cstatus.Handle, pret); - } - } - } - } - } - } - if (!cstatus.CheckMaybeRaise (status, last: false)) - return null; - return ret; - } - - [DllImport (NativeBinding.TensorFlowLibrary)] - static extern unsafe void TF_AddGradientsWithPrefix (TF_Graph graph, string prefix, TFOutput* ys, int ny, TFOutput* xs, int nx, TFOutput* dx, TF_Status status, TFOutput* dy); + static extern unsafe void TF_AddGradients (TF_Graph graph, TF_Output* ys, int ny, TF_Output* xs, int nx, TF_Output* dx, TF_Status status, TF_Output* dy); + + /// + /// Adds a gradient: the operations needed to compute the partial derivatives of sum of ` wrt to . + /// + /// The partial derivatives, the size of the array is the same as the length of the array. + /// The y elements. + /// The x elements. + /// Initial gradients, which represent the symbolic partial derivatives of some loss function `L` w.r.t. ). + /// If the parameter is null, the implementation will use dx for 'OnesLike' for all shapes in + /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. + /// + /// d(y[0] + y[1]+ ...)/dx[0], d(y[0] + y[1] + ...)/dx[1]z... + /// + public TFOutput[] AddGradients(TFOutput[] y, TFOutput[] x, TFOutput[] dx = null, TFStatus status = null) + { + AssertNotDisposed(); + + if (y == null) + throw new ArgumentNullException(nameof(y)); + if (x == null) + throw new ArgumentNullException(nameof(x)); + if (dx != null) + { + if (dx.Length != y.Length) + throw new ArgumentException("If dx is not null, the size of the gradients must match the size of y", nameof(dx)); + } + + var cstatus = TFStatus.Setup(status); + + var nresult = new TF_Output[x.Length]; + unsafe + { + fixed (TF_Output* pret = nresult, py = TFOutput.ToNative(y), px = TFOutput.ToNative(x)) + { + if (dx == null) + { + TF_AddGradients(handle, py, y.Length, px, x.Length, (TF_Output*)null, cstatus.Handle, pret); + } + else + { + fixed (TF_Output* pdx = TFOutput.ToNative(dx)) + { + TF_AddGradients(handle, py, y.Length, px, x.Length, pdx, cstatus.Handle, pret); + } + } + } + } + if (!cstatus.CheckMaybeRaise(status, last: false)) + return null; + return TFOutput.FromNative(nresult, this); + } + + [DllImport (NativeBinding.TensorFlowLibrary)] + static extern unsafe void TF_AddGradientsWithPrefix (TF_Graph graph, string prefix, TF_Output* ys, int ny, TF_Output* xs, int nx, TF_Output* dx, TF_Status status, TF_Output* dy); /// /// Adds a gradient: the operations needed to compute the partial derivatives of sum of ` wrt to . /// @@ -1151,7 +1175,9 @@ unsafe struct TFWhileParams /// public TFOutput [] AddGradients (string prefix, TFOutput [] y, TFOutput [] x, TFOutput [] dx = null, TFStatus status = null) { - if (y == null) + AssertNotDisposed(); + + if (y == null) throw new ArgumentNullException (nameof (y)); if (x == null) throw new ArgumentNullException (nameof (x)); @@ -1162,146 +1188,164 @@ unsafe struct TFWhileParams var cstatus = TFStatus.Setup (status); - var ret = new TFOutput [x.Length]; - unsafe { - fixed (TFOutput* pret = &ret [0]) { - fixed (TFOutput* py = &y [0]) { - fixed (TFOutput* px = &x [0]) { - if (dx == null) { - TF_AddGradientsWithPrefix (handle, prefix, py, y.Length, px, x.Length, (TFOutput*)null, cstatus.Handle, pret); - } else { - fixed (TFOutput* pdx = &dx [0]) { - TF_AddGradientsWithPrefix (handle, prefix, py, y.Length, px, x.Length, pdx, cstatus.Handle, pret); - } - } - } - } - } - } + var nresult = new TF_Output [x.Length]; + unsafe + { + fixed (TF_Output* pret = nresult, py = TFOutput.ToNative(y), px = TFOutput.ToNative(x)) + { + if (dx == null) + { + TF_AddGradientsWithPrefix(handle, prefix, py, y.Length, px, x.Length, (TF_Output*)null, cstatus.Handle, pret); + } + else + { + fixed (TF_Output* pdx = TFOutput.ToNative(dx)) + { + TF_AddGradientsWithPrefix(handle, prefix, py, y.Length, px, x.Length, pdx, cstatus.Handle, pret); + } + } + } + } if (!cstatus.CheckMaybeRaise (status, last: false)) return null; - return ret; + return TFOutput.FromNative(nresult, this); } [DllImport (NativeBinding.TensorFlowLibrary)] static extern unsafe void TF_GraphCopyFunction (TF_Graph graph, TF_Function func, TF_Function grad, TF_Status status); [DllImport (NativeBinding.TensorFlowLibrary)] - static extern unsafe IntPtr TF_GraphToFunction (TF_Graph body, string fn_name, byte append_hash_to_fn_name, int num_opers, IntPtr opers, int ninputs, TFOutput [] inputs, int noutputs, TFOutput [] ouputs, string [] output_names, IntPtr options, string description, TF_Status status); - - /// - /// Creates a TFFunction from a TFGraph - /// - /// The function. - /// Name of the new function. Should match the operation name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]*. If appendHashToFunctioName is false, the name must be unique (at least those registered in graphs where this function will be used). - /// Optional, human readable description of this function. - /// Array of operations to become the body of the function or null. - /// If no array is given , all the - /// operations in function body will become part of the function - /// except operations referenced in inputs. These operations - /// must have a single output (these operations are typically - /// placeholders created for the sole purpose of representing - /// an input). - /// - /// If an array is given, all operations - /// in it will become part of the function. In particular, no - /// automatic skipping of dummy input operations is performed. - /// - /// Array that specify the inputs to the function, or null. The names used for function inputs are normalized - /// names of the operations (usually placeholders) pointed to by - /// inputs. These operation names should start with a letter. - /// Normalization will convert all letters to lowercase and - /// non-alphanumeric characters to '_' to make resulting names match - /// the "[a-z][a-z0-9_]*" pattern for operation argument names. - /// `inputs` cannot contain the same tensor twice. - /// rray that specify the inputs to the function, or null. This can contain the same tensor twice. - /// The names of the function's outputs. The array either has the same elements of outputs, or be null. Names must match "[a-z][a-z0-9_]*" regexp, if null is passed, the names are generated automatically. - /// If set to true appends hash to functionName, otherwise it will use the specified name in functionName. - /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. - /// - /// - /// This method converts the graph whose operations (or a subset of its operations) will be converted - /// into a TFFunction. - /// - /// - /// Note that when the same TF_Output is listed as both an input and an output, - /// the corresponding function's output will equal to this input, - /// instead of the original node's output. - /// - /// - /// Callers must also satisfy the following constraints: - /// - /// - /// cannot refer to TFOutputs within a control flow context. For - /// example, one cannot use the output of "switch" node as input. - /// - /// - /// and cannot have reference types. Reference types are - /// not exposed through C API and are being replaced with Resources. We support - /// reference types inside function's body to support legacy code. Do not - /// use them in new code. - /// - /// - /// Every node in the function's body must have all of its inputs (including - /// control inputs). In other words, for every node in the body, each input - /// must be either listed in or must come from another node in - /// the body. In particular, it is an error to have a control edge going from - /// a node outside of the body into a node in the body. This applies to control - /// edges going from nodes referenced in to nodes in the body when - /// the former nodes are not in the body (automatically skipped or not - /// included in explicitly specified body). - /// - /// - public TFFunction ToFunction (string functionName, - string description, - TFOperation [] operations, - TFOutput [] inputs, - TFOutput [] outputs, - string [] outputNames, - bool appendHashToFunctionName = false, - TFStatus status = null) - { - if (functionName == null) - throw new ArgumentNullException (nameof (functionName)); - if (outputs == null) { - if (outputNames != null) - throw new ArgumentException ("outputs is null, but outputNames is not", nameof (outputNames)); - } else { - if (outputNames != null && outputs.Length != outputNames.Length) - throw new ArgumentException ("the outputs and outputNames array are specified, but have different lenghts"); - } - var cstatus = TFStatus.Setup (status); - - unsafe { - IntPtr functionOptions = IntPtr.Zero; - IntPtr ops = IntPtr.Zero; - int nops; - if (operations == null) { - nops = 0; - ops = IntPtr.Zero; - } else { - nops = operations.Length; - ops = Marshal.AllocHGlobal (sizeof (IntPtr) * operations.Length); - for (int i = 0; i < nops; i++) - Marshal.WriteIntPtr (ops, i * sizeof (IntPtr), operations [i].handle); - } - - var fnHandle = TF_GraphToFunction (handle, functionName, (byte) (appendHashToFunctionName ? 1 : 0), - nops, ops, - inputs == null ? 0 : inputs.Length, inputs, - outputs == null ? 0 : outputs.Length, outputs, - outputNames, - functionOptions, - description, - cstatus.Handle); - if (ops != IntPtr.Zero) - Marshal.FreeHGlobal (ops); - - if (!cstatus.CheckMaybeRaise (status, last: false)) - return null; - return new TFFunction (fnHandle); - } - } + static extern unsafe IntPtr TF_GraphToFunction (TF_Graph body, string fn_name, byte append_hash_to_fn_name, int num_opers, IntPtr opers, int ninputs, TF_Output [] inputs, int noutputs, TF_Output[] ouputs, string [] output_names, IntPtr options, string description, TF_Status status); + + /// + /// Creates a TFFunction from a TFGraph + /// + /// The function. + /// Name of the new function. Should match the operation name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]*. If appendHashToFunctioName is false, the name must be unique (at least those registered in graphs where this function will be used). + /// Optional, human readable description of this function. + /// Array of operations to become the body of the function or null. + /// If no array is given , all the + /// operations in function body will become part of the function + /// except operations referenced in inputs. These operations + /// must have a single output (these operations are typically + /// placeholders created for the sole purpose of representing + /// an input). + /// + /// If an array is given, all operations + /// in it will become part of the function. In particular, no + /// automatic skipping of dummy input operations is performed. + /// + /// Array that specify the inputs to the function, or null. The names used for function inputs are normalized + /// names of the operations (usually placeholders) pointed to by + /// inputs. These operation names should start with a letter. + /// Normalization will convert all letters to lowercase and + /// non-alphanumeric characters to '_' to make resulting names match + /// the "[a-z][a-z0-9_]*" pattern for operation argument names. + /// `inputs` cannot contain the same tensor twice. + /// rray that specify the inputs to the function, or null. This can contain the same tensor twice. + /// The names of the function's outputs. The array either has the same elements of outputs, or be null. Names must match "[a-z][a-z0-9_]*" regexp, if null is passed, the names are generated automatically. + /// If set to true appends hash to functionName, otherwise it will use the specified name in functionName. + /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. + /// + /// + /// This method converts the graph whose operations (or a subset of its operations) will be converted + /// into a TFFunction. + /// + /// + /// Note that when the same TF_Output is listed as both an input and an output, + /// the corresponding function's output will equal to this input, + /// instead of the original node's output. + /// + /// + /// Callers must also satisfy the following constraints: + /// + /// + /// cannot refer to TFOutputs within a control flow context. For + /// example, one cannot use the output of "switch" node as input. + /// + /// + /// and cannot have reference types. Reference types are + /// not exposed through C API and are being replaced with Resources. We support + /// reference types inside function's body to support legacy code. Do not + /// use them in new code. + /// + /// + /// Every node in the function's body must have all of its inputs (including + /// control inputs). In other words, for every node in the body, each input + /// must be either listed in or must come from another node in + /// the body. In particular, it is an error to have a control edge going from + /// a node outside of the body into a node in the body. This applies to control + /// edges going from nodes referenced in to nodes in the body when + /// the former nodes are not in the body (automatically skipped or not + /// included in explicitly specified body). + /// + /// + public TFFunction ToFunction(string functionName, + string description, + TFOperation[] operations, + TFOutput[] inputs, + TFOutput[] outputs, + string[] outputNames, + bool appendHashToFunctionName = false, + TFStatus status = null) + { + AssertNotDisposed(); + + if (functionName == null) + throw new ArgumentNullException(nameof(functionName)); + if (outputs == null) + { + if (outputNames != null) + throw new ArgumentException("outputs is null, but outputNames is not", nameof(outputNames)); + } + else + { + if (outputNames != null && outputs.Length != outputNames.Length) + throw new ArgumentException("the outputs and outputNames array are specified, but have different lenghts"); + } + var cstatus = TFStatus.Setup(status); + + unsafe + { + var functionOptions = TFNative.Nullptr; + var ops = TFNative.Nullptr; + int nops; + if (operations == null) + { + nops = 0; + ops = TFNative.Nullptr; + } + else + { + nops = operations.Length; + ops = Marshal.AllocHGlobal(sizeof(IntPtr) * operations.Length); + for (int i = 0; i < nops; i++) + Marshal.WriteIntPtr(ops, i * sizeof(IntPtr), operations[i].handle); + } + + var fnHandle = TF_GraphToFunction( + handle, + functionName, + (byte)(appendHashToFunctionName ? 1 : 0), + nops, + ops, + inputs == null ? 0 : inputs.Length, + TFOutput.ToNative(inputs), + outputs == null ? 0 : outputs.Length, + TFOutput.ToNative(outputs), + outputNames, + functionOptions, + description, + cstatus.Handle); + + if (ops != TFNative.Nullptr) + Marshal.FreeHGlobal(ops); + + if (!cstatus.CheckMaybeRaise(status, last: false)) + return null; + return new TFFunction(fnHandle); + } + } [DllImport (NativeBinding.TensorFlowLibrary)] unsafe static extern void TF_GraphVersions (TF_Graph graph, LLBuffer *output_version_def, TF_Status status); @@ -1314,9 +1358,9 @@ unsafe struct TFWhileParams /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. public void Versions (TFBuffer outputVersionDef, TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (outputVersionDef == null) + AssertNotDisposed(); + + if (outputVersionDef == null) throw new ArgumentNullException (nameof (outputVersionDef)); var cstatus = TFStatus.Setup (status); @@ -1343,10 +1387,11 @@ public void Versions (TFBuffer outputVersionDef, TFStatus status = null) /// /// The functions. public TFFunction [] Functions { - get { - if (handle == null) - ObjectDisposedException (); - var n = NumFunctions; + get + { + AssertNotDisposed(); + + var n = NumFunctions; unsafe { TFFunction [] ret = null; var size = sizeof (IntPtr); @@ -1370,7 +1415,7 @@ public void Versions (TFBuffer outputVersionDef, TFStatus status = null) } [DllImport (NativeBinding.TensorFlowLibrary)] - static extern bool TF_TryEvaluateConstant (TF_Graph graph, TFOutput output, ref IntPtr result, TF_Status status); + static extern bool TF_TryEvaluateConstant (TF_Graph graph, TF_Output output, ref TF_Tensor result, TF_Status status); /// /// Attempts to evaluate the . This is only possible if does not @@ -1381,9 +1426,11 @@ public void Versions (TFBuffer outputVersionDef, TFStatus status = null) /// Tensor. public bool TryEvaluateConstant (TFOutput output, ref TFTensor tensor) { - var cstatus = new TFStatus (); - IntPtr ptr = IntPtr.Zero; - var ret = TF_TryEvaluateConstant (handle, output, ref ptr, cstatus.handle); + AssertNotDisposed(); + + var cstatus = new TFStatus (); + TF_Tensor ptr = TFNative.Nullptr; + var ret = TF_TryEvaluateConstant (handle, output.NativeRepresentation, ref ptr, cstatus.handle); cstatus.Dispose (); if (ret) @@ -1617,9 +1664,9 @@ internal override void NativeDispose (IntPtr handle) /// The device to constraint to in this operation. public TFOperationDesc SetDevice (string device) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (device == null) + AssertNotDisposed(); + + if (device == null) throw new ArgumentNullException ("device"); TF_SetDevice (handle, device); return this; @@ -1627,7 +1674,7 @@ public TFOperationDesc SetDevice (string device) // extern void TF_AddInput (TF_OperationDescription *desc, TF_Output input); [DllImport (NativeBinding.TensorFlowLibrary)] - static extern unsafe void TF_AddInput (TF_OperationDescription desc, TFOutput input); + static extern unsafe void TF_AddInput (TF_OperationDescription desc, TF_Output input); /// /// Adds the specified input to the operation @@ -1636,15 +1683,15 @@ public TFOperationDesc SetDevice (string device) /// Input. public TFOperationDesc AddInput (TFOutput input) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - TF_AddInput (handle, input); + AssertNotDisposed(); + + TF_AddInput(handle, input.NativeRepresentation); return this; } // extern void TF_AddInputList (TF_OperationDescription *desc, const TF_Output *inputs, int num_inputs); [DllImport (NativeBinding.TensorFlowLibrary)] - static extern unsafe void TF_AddInputList (TF_OperationDescription desc, TFOutput [] inputs, int num_inputs); + static extern unsafe void TF_AddInputList (TF_OperationDescription desc, TF_Output [] inputs, int num_inputs); /// /// Adds a series of inputs to the operation. @@ -1652,12 +1699,12 @@ public TFOperationDesc AddInput (TFOutput input) /// Inputs, this is a params array for your convenience. public TFOperationDesc AddInputs (params TFOutput [] inputs) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (inputs == null || inputs.Length == 0) + AssertNotDisposed(); + + if (inputs == null || inputs.Length == 0) return this; - TF_AddInputList (handle, inputs, inputs.Length); + TF_AddInputList(handle, TFOutput.ToNative(inputs), inputs.Length); return this; } @@ -1682,9 +1729,9 @@ public TFOperationDesc AddInputs (params TFOutput [] inputs) /// public TFOperationDesc AddControlInput (TFOperation control) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (control == null) + AssertNotDisposed(); + + if (control == null) throw new ArgumentNullException ("input"); TF_AddControlInput (handle, control.handle); @@ -1697,9 +1744,9 @@ public TFOperationDesc AddControlInput (TFOperation control) public TFOperationDesc ColocateWith (TFOperation op) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (op == null) + AssertNotDisposed(); + + if (op == null) throw new ArgumentNullException ("op"); TF_ColocateWith (handle, op.handle); return this; @@ -1711,9 +1758,9 @@ public TFOperationDesc ColocateWith (TFOperation op) public TFOperationDesc SetAttr (string attrName, string value) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (attrName == null) + AssertNotDisposed(); + + if (attrName == null) throw new ArgumentNullException (nameof (attrName)); var bytes = Encoding.UTF8.GetBytes (value); var buf = Marshal.AllocHGlobal (bytes.Length + 1); @@ -1728,9 +1775,9 @@ public TFOperationDesc SetAttr (string attrName, string value) static extern unsafe void TF_SetAttrStringList (TF_OperationDescription desc, string attr_name, IntPtr [] values, UIntPtr [] lengths, int num_values); public TFOperationDesc SetAttr (string attrName, string [] values) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (attrName == null) + AssertNotDisposed(); + + if (attrName == null) throw new ArgumentNullException (nameof (attrName)); if (values == null) throw new ArgumentNullException (nameof (values)); @@ -1758,9 +1805,9 @@ public TFOperationDesc SetAttr (string attrName, string [] values) public TFOperationDesc SetAttr (string attrName, long value) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (attrName == null) + AssertNotDisposed(); + + if (attrName == null) throw new ArgumentNullException (nameof (attrName)); TF_SetAttrInt (handle, attrName, value); return this; @@ -1772,9 +1819,9 @@ public TFOperationDesc SetAttr (string attrName, long value) public TFOperationDesc SetAttr (string attrName, long [] values) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (attrName == null) + AssertNotDisposed(); + + if (attrName == null) throw new ArgumentNullException (nameof (attrName)); if (values == null) throw new ArgumentNullException (nameof (values)); @@ -1790,9 +1837,9 @@ public TFOperationDesc SetAttr (string attrName, long [] values) public TFOperationDesc SetAttr (string attrName, float value) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (attrName == null) + AssertNotDisposed(); + + if (attrName == null) throw new ArgumentNullException (nameof (attrName)); TF_SetAttrFloat (handle, attrName, value); return this; @@ -1804,9 +1851,9 @@ public TFOperationDesc SetAttr (string attrName, float value) public TFOperationDesc SetAttr (string attrName, float [] values) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (attrName == null) + AssertNotDisposed(); + + if (attrName == null) throw new ArgumentNullException (nameof (attrName)); if (values == null) throw new ArgumentNullException (nameof (values)); @@ -1821,9 +1868,9 @@ public TFOperationDesc SetAttr (string attrName, float [] values) public TFOperationDesc SetAttr (string attrName, bool value) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (attrName == null) + AssertNotDisposed(); + + if (attrName == null) throw new ArgumentNullException (nameof (attrName)); TF_SetAttrBool (handle, attrName, (byte)(value ? 1 : 0)); return this; @@ -1835,9 +1882,9 @@ public TFOperationDesc SetAttr (string attrName, bool value) public TFOperationDesc SetAttr (string attrName, bool [] values) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (attrName == null) + AssertNotDisposed(); + + if (attrName == null) throw new ArgumentNullException (nameof (attrName)); if (values == null) throw new ArgumentNullException (nameof (values)); @@ -1852,9 +1899,9 @@ public TFOperationDesc SetAttr (string attrName, bool [] values) public TFOperationDesc SetAttrType (string attrName, TFDataType dataType) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (attrName == null) + AssertNotDisposed(); + + if (attrName == null) throw new ArgumentNullException (nameof (attrName)); TF_SetAttrType (handle, attrName, dataType); return this; @@ -1866,9 +1913,9 @@ public TFOperationDesc SetAttrType (string attrName, TFDataType dataType) public TFOperationDesc SetAttrType (string attrName, params TFDataType [] dataType) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (attrName == null) + AssertNotDisposed(); + + if (attrName == null) throw new ArgumentNullException (nameof (attrName)); if (dataType == null) throw new ArgumentNullException (nameof (dataType)); @@ -1884,9 +1931,9 @@ public TFOperationDesc SetAttrType (string attrName, params TFDataType [] dataTy public TFOperationDesc SetAttrShape (string attrName, TFShape shape) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (attrName == null) + AssertNotDisposed(); + + if (attrName == null) throw new ArgumentNullException (nameof (attrName)); if (shape == null || shape.dims == null) TF_SetAttrShape (handle, attrName, null, -1); @@ -1901,9 +1948,9 @@ public TFOperationDesc SetAttrShape (string attrName, TFShape shape) public TFOperationDesc SetAttrShape (string attrName, TFShape [] shapeList) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (attrName == null) + AssertNotDisposed(); + + if (attrName == null) throw new ArgumentNullException (nameof (attrName)); if (shapeList == null) throw new ArgumentNullException (nameof (shapeList)); @@ -1936,9 +1983,9 @@ public TFOperationDesc SetAttrShape (string attrName, TFShape [] shapeList) static extern unsafe void TF_SetAttrTensorShapeProto (TF_OperationDescription desc, string attr_name, IntPtr proto, size_t proto_len, TF_Status status); public TFOperationDesc SetAttrTensorShapeProto (string attrName, IntPtr proto, size_t protoLen, TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - var cstatus = TFStatus.Setup (status); + AssertNotDisposed(); + + var cstatus = TFStatus.Setup (status); TF_SetAttrTensorShapeProto (handle, attrName, proto, protoLen, cstatus.handle); cstatus.CheckMaybeRaise (status); return this; @@ -1955,9 +2002,9 @@ public TFOperationDesc SetAttrTensorShapeProto (string attrName, IntPtr proto, s public TFOperationDesc SetAttr (string attrName, TFTensor tensor, TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (attrName == null) + AssertNotDisposed(); + + if (attrName == null) throw new ArgumentNullException (nameof (attrName)); if (tensor == null) throw new ArgumentNullException ("tensor"); @@ -1973,9 +2020,9 @@ public TFOperationDesc SetAttr (string attrName, TFTensor tensor, TFStatus statu static extern unsafe void TF_SetAttrTensorList (TF_OperationDescription desc, string attr_name, IntPtr [] values, int num_values, TF_Status status); public TFOperationDesc SetAttr (string attrName, TFTensor [] tensor, TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (attrName == null) + AssertNotDisposed(); + + if (attrName == null) throw new ArgumentNullException (nameof (attrName)); if (tensor == null) throw new ArgumentNullException (nameof (tensor)); @@ -2006,9 +2053,9 @@ public TFOperationDesc SetAttr (string attrName, TFTensor [] tensor, TFStatus st /// Optional status, on failure the operation is not added to the graph. If you pass null (the default), this operation throws on error conditions. public TFOperation FinishOperation (TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - var cstatus = TFStatus.Setup (status); + AssertNotDisposed(); + + var cstatus = TFStatus.Setup (status); var h = TF_FinishOperation (handle, cstatus.handle); cstatus.CheckMaybeRaise (status); handle = IntPtr.Zero; @@ -2029,9 +2076,9 @@ public TFOperation FinishOperation (TFStatus status = null) /// The value for the attribute. public void SetAttribute (string attrName, string value) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (attrName == null) + AssertNotDisposed(); + + if (attrName == null) throw new ArgumentNullException (nameof (attrName)); if (value == null) throw new ArgumentNullException (nameof (value)); @@ -2053,11 +2100,11 @@ public partial class TFOperation { internal IntPtr handle; - /// - /// Gets the handle to the unmanaged TF_Operation object. - /// - /// The handle. - public IntPtr Handle => handle; + /// + /// Gets the handle to the unmanaged TF_Operation object. + /// + /// The handle. + public IntPtr Handle => handle; // Pointer to the graph, to keep it from collecting if there are TFOperations alive. internal TFGraph graph; @@ -2065,24 +2112,35 @@ public partial class TFOperation internal TFOperation (TFGraph graph, IntPtr handle) { this.handle = handle; - this.graph = graph; + this.graph = TFSharpDebug.AssertNotNull(graph, nameof(graph)); } + //TODO: Add check for Graph as well (if we want to propagate the graph with inputs/outputs). + private bool IsValid => handle != IntPtr.Zero && !graph.HandleIsDisposed; + + internal void AssertIsValid() + { + if (graph.HandleIsDisposed) + handle = TFNative.Nullptr; + if (handle == TFNative.Nullptr) + TFDisposable.ObjectDisposedException(); + } + // extern const char * TF_OperationName (TF_Operation *oper); [DllImport (NativeBinding.TensorFlowLibrary)] static extern unsafe IntPtr TF_OperationName (TF_Operation oper); - /// - /// The name for this operation/ - /// - /// The name. - public string Name => handle == IntPtr.Zero ? "" : TF_OperationName (handle).GetStr (); + /// + /// The name for this operation/ + /// + /// The name. + public string Name => IsValid ? TF_OperationName(handle).GetStr() : ""; // extern const char * TF_OperationOpType (TF_Operation *oper); [DllImport (NativeBinding.TensorFlowLibrary)] static extern unsafe IntPtr TF_OperationOpType (TF_Operation oper); - public string OpType => handle == IntPtr.Zero ? "" : TF_OperationOpType (handle).GetStr (); + public string OpType => IsValid ? TF_OperationOpType(handle).GetStr() : ""; // extern const char * TF_OperationDevice (TF_Operation *oper); [DllImport (NativeBinding.TensorFlowLibrary)] @@ -2094,11 +2152,11 @@ internal TFOperation (TFGraph graph, IntPtr handle) [DllImport (NativeBinding.TensorFlowLibrary)] static extern unsafe int TF_OperationNumOutputs (TF_Operation oper); - /// - /// Gets the number of outputs on this operation. - /// - /// The number outputs. - public int NumOutputs => handle == IntPtr.Zero ? -1 : TF_OperationNumOutputs (handle); + /// + /// Gets the number of outputs on this operation. + /// + /// The number outputs. + public int NumOutputs => IsValid ? TF_OperationNumOutputs(handle) : -1; // extern int TF_OperationOutputListLength (TF_Operation *oper, const char *arg_name, TF_Status *status); @@ -2107,8 +2165,8 @@ internal TFOperation (TFGraph graph, IntPtr handle) public int OutputListLength (string argName, TFStatus status = null) { - if (handle == IntPtr.Zero) - TFDisposable.ObjectDisposedException (); + AssertIsValid(); + var cstatus = TFStatus.Setup (status); var res = TF_OperationOutputListLength (handle, argName, cstatus.handle); cstatus.CheckMaybeRaise (status); @@ -2132,8 +2190,8 @@ public int OutputListLength (string argName, TFStatus status = null) public int InputListLength (string argName, TFStatus status = null) { - if (handle == IntPtr.Zero) - TFDisposable.ObjectDisposedException (); + AssertIsValid(); + var cstatus = TFStatus.Setup (status); var res = TF_OperationInputListLength (handle, argName, cstatus.handle); cstatus.CheckMaybeRaise (status); @@ -2170,27 +2228,53 @@ public int InputListLength (string argName, TFStatus status = null) /// Get the list of operations that have this operation as a control input. /// /// The control outputs. - public TFOperation [] ControlOutputs { - get { - var n = NumControlOutputs; - var arr = new IntPtr [n]; - TF_OperationGetControlOutputs (handle, arr, n); - var ret = new TFOperation [n]; - for (int i = 0; i < n; i++) - ret [i] = new TFOperation (graph, arr [i]); - return ret; - } - } - - // extern TF_AttrMetadata TF_OperationGetAttrMetadata (TF_Operation *oper, const char *attr_name, TF_Status *status); - [DllImport (NativeBinding.TensorFlowLibrary)] + public TFOperation[] ControlOutputs + { + get + { + AssertIsValid(); + var n = NumControlOutputs; + var arr = new IntPtr[n]; + TF_OperationGetControlOutputs(handle, arr, n); + var ret = new TFOperation[n]; + for (int i = 0; i < n; i++) + ret[i] = new TFOperation(graph, arr[i]); + return ret; + } + } + + /// + /// Gets the inputs of this operation. + /// + public TFInput[] Inputs + { + get + { + AssertIsValid(); + return Enumerable.Range(0, NumInputs).Select(index => new TFInput(this, index)).ToArray(); + } + } + + /// + /// Gets the specified input for this oepration. + /// + /// The input index. + /// The input. + /// If the index specified is not within the range [0, NumInputs). + public TFInput GetInput(int index) + { + AssertIsValid(); + return new TFInput(this, RangeChecks.Assert(index, nameof(index), 0, NumInputs)); + } + + // extern TF_AttrMetadata TF_OperationGetAttrMetadata (TF_Operation *oper, const char *attr_name, TF_Status *status); + [DllImport (NativeBinding.TensorFlowLibrary)] static extern unsafe TFAttributeMetadata TF_OperationGetAttrMetadata (TF_Operation oper, string attr_name, TF_Status status); public TFAttributeMetadata GetAttributeMetadata (string attrName, TFStatus status = null) { - if (handle == IntPtr.Zero) - TFDisposable.ObjectDisposedException (); - var cstatus = TFStatus.Setup (status); + AssertIsValid(); + var cstatus = TFStatus.Setup (status); var x = TF_OperationGetAttrMetadata (handle, attrName, cstatus.handle); cstatus.CheckMaybeRaise (status); return x; @@ -2246,9 +2330,8 @@ public TFAttributeMetadata GetAttributeMetadata (string attrName, TFStatus statu /// The value (type) of the attribute. public unsafe TFDataType GetAttributeType (string attrName, TFStatus status = null) { - if (handle == IntPtr.Zero) - TFDisposable.ObjectDisposedException (); - var cstatus = TFStatus.Setup (status); + AssertIsValid(); + var cstatus = TFStatus.Setup (status); TFDataType type_value = new TFDataType (); TF_OperationGetAttrType (handle, attrName, &type_value, cstatus.Handle); cstatus.CheckMaybeRaise (status); @@ -2273,9 +2356,8 @@ public unsafe TFDataType GetAttributeType (string attrName, TFStatus status = nu /// Status of the operations public unsafe TFShape GetAttributeShape(string attr_name, int num_dims, TFStatus status = null) { - if (handle == IntPtr.Zero) - TFDisposable.ObjectDisposedException (); - var cstatus = TFStatus.Setup (status); + AssertIsValid(); + var cstatus = TFStatus.Setup (status); long [] shape_value = new long [num_dims]; TF_OperationGetAttrShape (handle, attr_name, shape_value, num_dims, cstatus.handle); cstatus.CheckMaybeRaise (status); @@ -2326,9 +2408,8 @@ public unsafe TFShape GetAttributeShape(string attr_name, int num_dims, TFStatus /// public TFBuffer ToNodeDef (TFStatus status = null) { - if (handle == IntPtr.Zero) - TFDisposable.ObjectDisposedException (); - var cstatus = TFStatus.Setup (status); + AssertIsValid(); + var cstatus = TFStatus.Setup (status); var r = new TFBuffer (); unsafe { @@ -2346,11 +2427,14 @@ public TFBuffer ToNodeDef (TFStatus status = null) /// Returns the handle to the idx-th output of the operation. /// /// Index of the output in the operation. - public TFOutput this [int idx] { - get { - return new TFOutput (this, idx); - } - } + public TFOutput this[int idx] + { + get + { + AssertIsValid(); + return new TFOutput(this, idx); + } + } } @@ -2433,14 +2517,13 @@ internal override void NativeDispose (IntPtr handle) public void SetPrefix (string prefix) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); + AssertNotDisposed(); TF_ImportGraphDefOptionsSetPrefix (handle, prefix); } // extern void TF_ImportGraphDefOptionsAddInputMapping (TF_ImportGraphDefOptions *opts, const char* src_name, int src_index, TF_Output dst); [DllImport (NativeBinding.TensorFlowLibrary)] - static extern unsafe void TF_ImportGraphDefOptionsAddInputMapping (TF_ImportGraphDefOptions opts, string src_name, int src_index, TFOutput dst); + static extern unsafe void TF_ImportGraphDefOptionsAddInputMapping (TF_ImportGraphDefOptions opts, string src_name, int src_index, TF_Output dst); /// @@ -2456,9 +2539,8 @@ public void SetPrefix (string prefix) /// public void AddInputMapping (string srcName, int srcIndex, TFOutput dst) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - TF_ImportGraphDefOptionsAddInputMapping (handle, srcName, srcIndex, dst); + AssertNotDisposed(); + TF_ImportGraphDefOptionsAddInputMapping(handle, srcName, srcIndex, dst.NativeRepresentation); } [DllImport (NativeBinding.TensorFlowLibrary)] @@ -2470,10 +2552,10 @@ public void AddInputMapping (string srcName, int srcIndex, TFOutput dst) /// This operation should exist in the graph being imported to. public void AddControlDependency (TFOperation operation) { - if (operation == null) + AssertNotDisposed(); + + if (operation == null) throw new ArgumentNullException (nameof (operation)); - if (handle == IntPtr.Zero) - ObjectDisposedException (); TF_ImportGraphDefOptionsAddControlDependency (handle, operation.handle); } @@ -2492,10 +2574,10 @@ public void AddControlDependency (TFOperation operation) /// public void AddReturnOutput (string operName, int index) { - if (operName == null) + AssertNotDisposed(); + + if (operName == null) throw new ArgumentNullException (nameof (operName)); - if (handle == IntPtr.Zero) - ObjectDisposedException (); TF_ImportGraphDefOptionsAddReturnOutput (handle, operName, index); } @@ -2507,10 +2589,11 @@ public void AddReturnOutput (string operName, int index) /// /// The number return outputs. public int NumReturnOutputs { - get { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - return TF_ImportGraphDefOptionsNumReturnOutputs (handle); + get + { + AssertNotDisposed(); + + return TF_ImportGraphDefOptionsNumReturnOutputs (handle); } } @@ -2528,9 +2611,9 @@ public void AddReturnOutput (string operName, int index) /// public void RemapControlDependency (string srcName, TFOperation destination) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (srcName == null) + AssertNotDisposed(); + + if (srcName == null) throw new ArgumentNullException (nameof (srcName)); if (destination == null) throw new ArgumentNullException (nameof (destination)); @@ -2554,10 +2637,9 @@ public void RemapControlDependency (string srcName, TFOperation destination) /// public void SetUniquifyNames (bool uniquifyNames) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - - TF_ImportGraphDefOptionsSetUniquifyNames (handle, uniquifyNames ? (byte) 1 : (byte) 0); + AssertNotDisposed(); + + TF_ImportGraphDefOptionsSetUniquifyNames(handle, uniquifyNames ? (byte) 1 : (byte) 0); } [DllImport (NativeBinding.TensorFlowLibrary)] @@ -2572,9 +2654,9 @@ public void SetUniquifyNames (bool uniquifyNames) /// public void SetUniquifyPrefix (bool uniquifyPrefix) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - TF_ImportGraphDefOptionsSetUniquifyPrefix (handle, uniquifyPrefix ? (byte)1 : (byte)0); + AssertNotDisposed(); + + TF_ImportGraphDefOptionsSetUniquifyPrefix(handle, uniquifyPrefix ? (byte)1 : (byte)0); } } @@ -2760,9 +2842,9 @@ public static TFSession FromSavedModel (TFSessionOptions sessionOptions, TFBuffe /// public void CloseSession (TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - var cstatus = TFStatus.Setup (status); + AssertNotDisposed(); + + var cstatus = TFStatus.Setup (status); TF_CloseSession (handle, cstatus.handle); cstatus.CheckMaybeRaise (status); } @@ -2777,9 +2859,9 @@ public void CloseSession (TFStatus status = null) /// Status. public void DeleteSession (TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - var cstatus = TFStatus.Setup (status); + AssertNotDisposed(); + + var cstatus = TFStatus.Setup (status); TF_DeleteSession (handle, cstatus.handle); cstatus.CheckMaybeRaise (status); } @@ -2793,7 +2875,7 @@ internal override void NativeDispose (IntPtr handle) // extern void TF_SessionRun (TF_Session *session, const TF_Buffer *run_options, const TF_Output *inputs, TF_Tensor *const *input_values, int ninputs, const TF_Output *outputs, TF_Tensor **output_values, int noutputs, const TF_Operation *const *target_opers, int ntargets, TF_Buffer *run_metadata, TF_Status *); [DllImport (NativeBinding.TensorFlowLibrary)] - static extern unsafe void TF_SessionRun (TF_Session session, LLBuffer* run_options, TFOutput [] inputs, TF_Tensor [] input_values, int ninputs, TFOutput [] outputs, TF_Tensor [] output_values, int noutputs, TF_Operation [] target_opers, int ntargets, LLBuffer* run_metadata, TF_Status status); + static extern unsafe void TF_SessionRun (TF_Session session, LLBuffer* run_options, TF_Output [] inputs, TF_Tensor [] input_values, int ninputs, TF_Output [] outputs, TF_Tensor [] output_values, int noutputs, TF_Operation [] target_opers, int ntargets, LLBuffer* run_metadata, TF_Status status); /// @@ -3016,8 +3098,9 @@ public TFTensor Run (TFOutput operation, TFStatus status = null) /// public Runner GetRunner () { + AssertNotDisposed(); return new Runner (this); - } + } /// /// Executes a pipeline given the specified inputs, inputValues, outputs, targetOpers, runMetadata and runOptions. @@ -3034,9 +3117,9 @@ public Runner GetRunner () /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. public TFTensor [] Run (TFOutput [] inputs, TFTensor [] inputValues, TFOutput [] outputs, TFOperation [] targetOpers = null, TFBuffer runMetadata = null, TFBuffer runOptions = null, TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (inputs == null) + AssertNotDisposed(); + + if (inputs == null) throw new ArgumentNullException (nameof (inputs)); if (inputValues == null) throw new ArgumentNullException (nameof (inputValues)); @@ -3068,7 +3151,19 @@ public Runner GetRunner () unsafe { - TF_SessionRun (handle, runOptions == null ? null : runOptions.LLBuffer, inputs, ivals, iLen, outputs, ovals, oLen, topers, tLen, runMetadata == null ? null : runMetadata.LLBuffer, cstatus.handle); + TF_SessionRun ( + handle, + runOptions == null ? null : runOptions.LLBuffer, + TFOutput.ToNative(inputs), + ivals, + iLen, + TFOutput.ToNative(outputs), + ovals, + oLen, + topers, + tLen, + runMetadata == null ? null : runMetadata.LLBuffer, + cstatus.handle); } cstatus.CheckMaybeRaise (status); @@ -3084,7 +3179,7 @@ public Runner GetRunner () // extern void TF_SessionPRunSetup (TF_Session, const TF_Output *inputs, int ninputs, const TF_Output *outputs, int noutputs, const TF_Operation *const *target_opers, int ntargets, const char **handle, TF_Status *); [DllImport (NativeBinding.TensorFlowLibrary)] - static extern unsafe void TF_SessionPRunSetup (TF_Session session, TFOutput [] inputs, int ninputs, TFOutput [] outputs, int noutputs, TF_Operation [] target_opers, int ntargets, out IntPtr returnHandle, TF_Status status); + static extern unsafe void TF_SessionPRunSetup (TF_Session session, TF_Output [] inputs, int ninputs, TF_Output [] outputs, int noutputs, TF_Operation [] target_opers, int ntargets, out IntPtr returnHandle, TF_Status status); [DllImport (NativeBinding.TensorFlowLibrary)] static extern unsafe void TF_DeletePRunHandle (IntPtr partialRunHandle); @@ -3120,9 +3215,9 @@ void IDisposable.Dispose () /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. public PartialRunToken PartialRunSetup (TFOutput [] inputs, TFOutput [] outputs, TFOperation [] targetOpers, TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (inputs == null) + AssertNotDisposed(); + + if (inputs == null) throw new ArgumentNullException (nameof (inputs)); if (outputs == null) throw new ArgumentNullException (nameof (outputs)); @@ -3136,19 +3231,19 @@ public PartialRunToken PartialRunSetup (TFOutput [] inputs, TFOutput [] outputs, for (int i = 0; i < tLen; i++) topers [i] = targetOpers [i].handle; - TF_SessionPRunSetup (handle, inputs, inputs.Length, outputs, outputs.Length, topers, tLen, out returnHandle, cstatus.handle); + TF_SessionPRunSetup (handle, TFOutput.ToNative(inputs), inputs.Length, TFOutput.ToNative(outputs), outputs.Length, topers, tLen, out returnHandle, cstatus.handle); cstatus.CheckMaybeRaise (status); return new PartialRunToken () { token = returnHandle }; } // extern void TF_SessionPRun (TF_Session *, const char *handle, const TF_Output *inputs, TF_Tensor *const *input_values, int ninputs, const TF_Output *outputs, TF_Tensor **output_values, int noutputs, const TF_Operation *const *target_opers, int ntargets, TF_Status *); [DllImport (NativeBinding.TensorFlowLibrary)] - static extern unsafe void TF_SessionPRun (TF_Session session, IntPtr partialHandle, TFOutput [] inputs, TF_Tensor [] input_values, int ninputs, TFOutput [] outputs, TF_Tensor [] output_values, int noutputs, TF_Operation [] target_opers, int ntargets, TF_Status status); + static extern unsafe void TF_SessionPRun (TF_Session session, IntPtr partialHandle, TF_Output [] inputs, TF_Tensor [] input_values, int ninputs, TF_Output [] outputs, TF_Tensor [] output_values, int noutputs, TF_Operation [] target_opers, int ntargets, TF_Status status); public TFTensor [] PartialRun (PartialRunToken token, TFOutput [] inputs, TFTensor [] inputValues, TFOutput [] outputs, TFOperation [] targetOpers, TFStatus status = null) { - if (handle == IntPtr.Zero) - ObjectDisposedException (); - if (inputs == null) + AssertNotDisposed(); + + if (inputs == null) throw new ArgumentNullException (nameof (inputs)); if (inputValues == null) throw new ArgumentNullException (nameof (inputValues)); @@ -3176,7 +3271,18 @@ public PartialRunToken PartialRunSetup (TFOutput [] inputs, TFOutput [] outputs, unsafe { - TF_SessionPRun (handle, token.token, inputs, ivals, iLen, outputs, ovals, oLen, topers, tLen, cstatus.handle); + TF_SessionPRun( + handle, + token.token, + TFOutput.ToNative(inputs), + ivals, + iLen, + TFOutput.ToNative(outputs), + ovals, + oLen, + topers, + tLen, + cstatus.handle); } cstatus.CheckMaybeRaise (status); @@ -3209,9 +3315,10 @@ public PartialRunToken PartialRunSetup (TFOutput [] inputs, TFOutput [] outputs, /// public TFOutput RestoreTensor (string filename, string tensor, TFDataType type) { - return this.Graph.Restore (this.Graph.Const (TFTensor.CreateString (Encoding.UTF8.GetBytes (filename))), - this.Graph.Const (TFTensor.CreateString (Encoding.UTF8.GetBytes (tensor))), - type); + return this.Graph.Restore( + this.Graph.Const(TFTensor.CreateString(Encoding.UTF8.GetBytes(filename))), + this.Graph.Const(TFTensor.CreateString(Encoding.UTF8.GetBytes(tensor))), + type); } /// @@ -3295,8 +3402,9 @@ public static TFLibrary FromFile (string libraryFile, TFStatus status = null) /// The buffer contains a ProtocolBuffer encoded payload, you need a ProtocolBuffer reader to process the contents. TFBuffer GetOpList () { + AssertNotDisposed(); return new TFBuffer (TF_GetOpList (handle).data); - } + } // extern void TF_DeleteLibraryHandle (TF_Library *lib_handle); [DllImport (NativeBinding.TensorFlowLibrary)] @@ -3578,86 +3686,165 @@ public enum TFCode : uint DataLoss = 15 } - /// - /// Represents a specific input of an operation. - /// [StructLayout (LayoutKind.Sequential)] + internal struct TF_Input + { + public unsafe TF_Operation Operation; + public int Index; + + } + + /// + /// Represents a specific input of an operation. + /// public struct TFInput { - /// + private TFGraph Graph { get; } + internal TF_Input NativeRepresentation { get; } + + /// /// The operation that this input is for /// - public unsafe TF_Operation Operation; - - /// + public unsafe TF_Operation Operation => NativeRepresentation.Operation; + /// /// The index of the output within the Operation /// - public int Index; + public int Index => NativeRepresentation.Index; - // extern TF_Output TF_OperationInput (TF_Input oper_in); - [DllImport (NativeBinding.TensorFlowLibrary)] - static extern TFOutput TF_OperationInput (TFInput oper_in); + internal TFInput(TF_Input native, TFGraph graph) + { + NativeRepresentation = native; + Graph = TFSharpDebug.AssertNotNull(graph, nameof(graph)); + } - public TFOutput GetOutput (TFInput operIn) - { - return TF_OperationInput (operIn); - } + internal TFInput(TFOperation operation, int index) + { + NativeRepresentation = new TF_Input + { + Operation = TFSharpDebug.AssertNotNull(operation, nameof(operation)).Handle, + Index = TFSharpDebug.AssertWithinLimits(index, nameof(index), 0, operation.NumInputs) + }; + Graph = operation.graph; + } - // extern TF_DataType TF_OperationInputType (TF_Input oper_in); + // extern TF_Output TF_OperationInput (TF_Input oper_in); [DllImport (NativeBinding.TensorFlowLibrary)] - static extern TFDataType TF_OperationInputType (TFInput oper_in); - - public TFDataType InputType => TF_OperationInputType (this); - - } - - /// - /// Represents a specific output of an operation on a tensor. - /// - /// - /// - /// TFOutput objects represent one of the outputs of an operation in the graph - /// (TFGraph). Outputs have a data type, and eventually a shape that you can - /// retrieve by calling the method. - /// - /// - /// These can be passed as an input argument to a function for adding operations - /// to a graph, or to the TFSession's Run and GetRunner method as values to be - /// fetched. - /// - /// - [StructLayout (LayoutKind.Sequential)] + static extern TFOutput TF_OperationInput (TF_Input oper_in); + + /// + /// Gets the output connected to this TFInput. + /// + /// The TFOutput connected to this TFInput. + public TFOutput Output + { + get + { + AssertValid(); + return TF_OperationInput(NativeRepresentation); + } + } + + // extern TF_DataType TF_OperationInputType (TF_Input oper_in); + [DllImport (NativeBinding.TensorFlowLibrary)] + static extern TFDataType TF_OperationInputType (TF_Input oper_in); + + /// + /// The type of this input. + /// + public TFDataType InputType + { + get + { + AssertValid(); + return TF_OperationInputType(NativeRepresentation); + } + } + + private void AssertValid() => Graph.AssertNotDisposed(); + + internal static TFInput[] FromNative(TF_Input[] native, TFGraph graph) + { + var managed = new TFInput[native.Length]; + for (int i = 0; i < native.Length; ++i) + managed[i] = new TFInput(native[i], graph); + return managed; + } + } + + [StructLayout(LayoutKind.Sequential)] + internal struct TF_Output + { + public unsafe TF_Operation LLOperation; + + /// + /// The index of the output within the operation. + /// + public int Index; + } + + /// + /// Represents a specific output of an operation on a tensor. + /// + /// + /// + /// TFOutput objects represent one of the outputs of an operation in the graph + /// (TFGraph). Outputs have a data type, and eventually a shape that you can + /// retrieve by calling the method. + /// + /// + /// These can be passed as an input argument to a function for adding operations + /// to a graph, or to the TFSession's Run and GetRunner method as values to be + /// fetched. + /// + /// public struct TFOutput { - unsafe TF_Operation LLOperation; - - /// - /// The index of the output within the operation. - /// - public int Index; - - // extern int TF_OperationOutputNumConsumers (TF_Output oper_out); - [DllImport (NativeBinding.TensorFlowLibrary)] - static extern int TF_OperationOutputNumConsumers (TFOutput oper_out); - - /// - /// Gets the number consumers. - /// - /// The number consumers. - /// - /// This number can change when new operations are added to the graph. - /// - public int NumConsumers => TF_OperationOutputNumConsumers (this); - - // extern TF_DataType TF_OperationOutputType (TF_Output oper_out); - [DllImport (NativeBinding.TensorFlowLibrary)] - static extern TFDataType TF_OperationOutputType (TFOutput oper_out); - - /// - /// Gets the type of the output. - /// - /// The type of the output. - public TFDataType OutputType => LLOperation == IntPtr.Zero ? TFDataType.Unknown : TF_OperationOutputType (this); + internal TF_Output NativeRepresentation { get; private set; } + private TFGraph Graph { get; set; } + + /// + /// The index of the output within the operation. + /// + public int Index => NativeRepresentation.Index; + internal unsafe TF_Operation LLOperation => NativeRepresentation.LLOperation; + + // extern int TF_OperationOutputNumConsumers (TF_Output oper_out); + [DllImport (NativeBinding.TensorFlowLibrary)] + static extern int TF_OperationOutputNumConsumers (TF_Output oper_out); + + /// + /// Gets the number consumers. + /// + /// The number consumers. + /// + /// This number can change when new operations are added to the graph. + /// + public int NumConsumers + { + get + { + AssertValid(); + return TF_OperationOutputNumConsumers(NativeRepresentation); + } + } + + // extern TF_DataType TF_OperationOutputType (TF_Output oper_out); + [DllImport (NativeBinding.TensorFlowLibrary)] + static extern TFDataType TF_OperationOutputType (TF_Output oper_out); + + /// + /// Gets the type of the output. + /// + /// The type of the output. + public TFDataType OutputType => IsValid ? TF_OperationOutputType(NativeRepresentation) : TFDataType.Unknown; + + internal TFOutput(TF_Output native, TFGraph graph) + { + if (native.LLOperation == TFNative.Nullptr) + throw new ArgumentNullException("Outputs does not have a valid operation pointer"); + NativeRepresentation = native; + Graph = graph; + } /// /// Initializes a new TFOutput instance. @@ -3668,8 +3855,8 @@ public TFOutput (TFOperation operation, int index = 0) { if (operation == null) throw new ArgumentNullException (nameof (operation)); - LLOperation = operation.handle; - Index = index; + NativeRepresentation = new TF_Output { LLOperation = operation.handle, Index = index }; + Graph = operation.graph; } /// @@ -3679,15 +3866,15 @@ public TFOutput (TFOperation operation, int index = 0) /// The index of the output within the operation, if not specified, it defaults to zero. public TFOutput (TFOutput output, int index = 0) { - if (output.LLOperation == null) - throw new ArgumentNullException ("Outputs does not have a valid operation pointer"); - LLOperation = output.LLOperation; - Index = index; + if (output.LLOperation == TFNative.Nullptr) + throw new ArgumentNullException("Outputs does not have a valid operation pointer"); + NativeRepresentation = new TF_Output { LLOperation = output.LLOperation, Index = index }; + Graph = output.Graph; } // extern int TF_OperationOutputConsumers (TF_Output oper_out, TF_Input *consumers, int max_consumers); [DllImport (NativeBinding.TensorFlowLibrary)] - static extern unsafe int TF_OperationOutputConsumers (TFOutput oper_out, TFInput* consumers, int max_consumers); + static extern unsafe int TF_OperationOutputConsumers (TF_Output oper_out, TF_Input* consumers, int max_consumers); /// /// Get list of all current consumers of a specific output of an operation @@ -3698,32 +3885,192 @@ public TFOutput (TFOutput output, int index = 0) /// an operation. /// This can return null if the TFOutput does not point to a valid object. /// - public TFInput [] OutputConsumers { - get { - var result = new TFInput [NumConsumers]; - unsafe - { - fixed (TFInput* first = &result [0]) - TF_OperationOutputConsumers (this, first, result.Length); - } - return result; - } - } - - /// - /// The associated operation. - /// - /// The operation. - public TFOperation Operation => new TFOperation (null, LLOperation); - - /// - /// Returns a that represents the current . - /// - /// A that represents the current . - public override string ToString () + public TFInput[] OutputConsumers + { + get + { + AssertValid(); + var result = new TF_Input[NumConsumers]; + unsafe + { + fixed (TF_Input* first = result) + TF_OperationOutputConsumers(NativeRepresentation, first, result.Length); + } + return TFInput.FromNative(result, Graph); + } + } + + /// + /// The associated operation. + /// + /// The operation. + public TFOperation Operation + { + get + { + AssertValid(); + return new TFOperation(Graph, LLOperation); + } + } + + /// + /// Returns a that represents the current . + /// + /// A that represents the current . + public override string ToString () { return string.Format ("[{3} Index={1} Operation={2} (0x{0:X})]", (long) LLOperation, Index, Operation, OutputType); } + + private bool IsValid => LLOperation != TFNative.Nullptr && !Graph.HandleIsDisposed; + private void AssertValid() => Graph.AssertNotDisposed(); + + /// + /// Performs element-wise addition a + b. + /// + /// The first operand. + /// The second operand. + /// The resulting operation. + public static TFOutput operator+(TFOutput a, TFOutput b) + { + a.AssertValid(); + return a.Graph.Add(a, b); + } + /// + /// Performs element-wise subtraction a - b. + /// + /// The first operand. + /// The second operand. + /// The resulting operation. + public static TFOutput operator -(TFOutput a, TFOutput b) + { + a.AssertValid(); + return a.Graph.Sub(a, b); + } + /// + /// Performs element-wise division a / b. + /// + /// The first operand. + /// The second operand. + /// The resulting operation. + public static TFOutput operator /(TFOutput a, TFOutput b) + { + a.AssertValid(); + return a.Graph.Div(a, b); + } + /// + /// Performs multiplication of a * b. + /// Attempts to perform matrix multiplication where applicable. + /// + /// The first operand. + /// The second operand. + /// The resulting operation. + public static TFOutput operator *(TFOutput a, TFOutput b) + { + a.AssertValid(); + var aShape = a.Graph.GetTensorShape(a); + var bShape = a.Graph.GetTensorShape(b); + + int adims = aShape.NumDimensions; + int bdims = bShape.NumDimensions; + // --- Check for matrix multiplication (TODO: transpose check?) + if (adims > 1 && bdims > 1) + return a.Graph.MatMul(a, b); + else if (adims > 1 && bdims == 1) + return a.Graph.MatMul(a, b); + else if (adims == 1 && bdims > 1) + return a.Graph.MatMul(a, b); + else + return a.Graph.Mul(a, b); + } + + /// + /// Performs element-wise addition a + b. + /// + /// The first operand. + /// The second operand. + /// The resulting operation. + public static TFOutput operator +(TFOutput output, TFTensor value) => output + CreateConst(output.Graph, value); + /// + /// Performs element-wise subtraction a - b. + /// + /// The first operand. + /// The second operand. + /// The resulting operation. + public static TFOutput operator -(TFOutput output, TFTensor value) => output - CreateConst(output.Graph, value); + /// + /// Performs multiplication of a * b. + /// Attempts to perform matrix multiplication where applicable. + /// + /// The first operand. + /// The second operand. + /// The resulting operation. + public static TFOutput operator *(TFOutput output, TFTensor value) => output * CreateConst(output.Graph, value); + /// + /// Performs element-wise division a / b. + /// + /// The first operand. + /// The second operand. + /// The resulting operation. + public static TFOutput operator /(TFOutput output, TFTensor value) => output / CreateConst(output.Graph, value); + + /// + /// Performs element-wise addition a + b. + /// + /// The first operand. + /// The second operand. + /// The resulting operation. + public static TFOutput operator +(TFTensor value, TFOutput output) => CreateConst(output.Graph, value) + output; + /// + /// Performs element-wise subtraction a - b. + /// + /// The first operand. + /// The second operand. + /// The resulting operation. + public static TFOutput operator -(TFTensor value, TFOutput output) => CreateConst(output.Graph, value) - output; + /// + /// Performs multiplication of a * b. + /// Attempts to perform matrix multiplication where applicable. + /// + /// The first operand. + /// The second operand. + /// The resulting operation. + public static TFOutput operator *(TFTensor value, TFOutput output) => CreateConst(output.Graph, value) * output; + /// + /// Performs element-wise division a / b. + /// + /// The first operand. + /// The second operand. + /// The resulting operation. + public static TFOutput operator /(TFTensor value, TFOutput output) => CreateConst(output.Graph, value) / output; + + private static TFOutput CreateConst(TFGraph graph, TFTensor tensor) + { + graph.AssertNotDisposed(); + return graph.Const(tensor); + } + + internal static TFOutput[] FromNative(TF_Output[] native, TFGraph graph) => FromNative(native, new TFOutput[native.Length], graph); + internal static TFOutput[] FromNative(TF_Output[] native, TFOutput[] dst, TFGraph graph) + { + if (native.Length != dst.Length) + throw new ArgumentException(); + + for (int i = 0; i < native.Length; ++i) + dst[i] = new TFOutput(native[i], graph); + return dst; + } + + internal static TF_Output[] ToNative(TFOutput[] managed) => ToNative(managed, new TF_Output[managed.Length]); + internal static TF_Output[] ToNative(TFOutput[] managed, TF_Output[] dst) + { + if (managed.Length != dst.Length) + throw new ArgumentException(); + + for (int i = 0; i < managed.Length; ++i) + dst[i] = managed[i].NativeRepresentation; + return dst; + } } /// @@ -3838,6 +4185,10 @@ public override string ToString () /// public class TFShape { + /// + /// Represents an unknown dimmension size or number of dimmensions. + /// + public const long UnknownSize = -1; /// /// Represents an unknown number of dimensions in the tensor. /// @@ -3850,7 +4201,9 @@ public class TFShape /// The scalar. public static TFShape Scalar => new TFShape (new long [0]); - internal long [] dims; + internal readonly long [] dims; + + private bool HasUnknownNumberOfDimensions => dims == null; /// /// Initializes a new instance of the class. @@ -3864,36 +4217,39 @@ public class TFShape /// public TFShape (params long [] args) { - this.dims = args; + dims = args; + AssertValidShape(); } - /// - /// Gets the length of the specified dimension in the tensor - /// - /// The length, -1 for shapes that have an unknown dimension. - /// Dimension. - public int GetLength (int dimension) => dims == null ? -1 : dims.GetLength (dimension); + /// + /// Gets the length of the specified dimension in the tensor + /// + /// The length, -1 for shapes that have an unknown dimension. + /// Dimension. + public int GetLength(int dimension) => (int)(HasUnknownNumberOfDimensions ? UnknownSize : dims[dimension]); /// /// Number of dimensions represented by this shape. /// /// The number dimensions, -1 if the number of dimensions is unknown, 0 if the shape represent a scalar, 1 for a vector, 2 for a matrix and so on.. - public int NumDimensions => dims == null ? -1 : dims.Length; + public int NumDimensions => HasUnknownNumberOfDimensions ? (int)UnknownSize : dims.Length; /// /// Gets a value indicating whether all the dimensions in the are fully specified. /// /// true if is fully specified; otherwise, false. - public bool IsFullySpecified { - get { - if (dims == null) - return false; - foreach (var j in dims) - if (j == -1) - return false; - return true; - } - } + public bool IsFullySpecified + { + get + { + if (HasUnknownNumberOfDimensions) + return false; + foreach (var j in dims) + if (j == UnknownSize) + return false; + return true; + } + } /// /// Returns the shape as an array @@ -3901,44 +4257,48 @@ public TFShape (params long [] args) /// null if the shape represents an unknown shape, otherwise an array with N elements, one per dimension, and each element can be either -1 (if the dimension size is unspecified) or the size of the dimension. public long [] ToArray () { - if (dims == null) + if (HasUnknownNumberOfDimensions) return null; var ret = (long [])dims.Clone (); return ret; } - /// - /// Returns the shape as an array - /// - /// null if the shape represents an unknown shape, otherwise an array with N elements, one per dimension, and each element can be either -1 (if the dimension size is unspecified) or the size of the dimension. - public int [] ToIntArray () - { - if (dims == null) - return null; + /// + /// Returns the shape as an array + /// + /// null if the shape represents an unknown shape, otherwise an array with N elements, one per dimension, and each element can be either -1 (if the dimension size is unspecified) or the size of the dimension. + public int[] ToIntArray() + { + if (HasUnknownNumberOfDimensions) + return null; - var ret = new int [dims.Length]; - for (int i = 0; i < dims.Length; i++) { - checked { - ret [i] = (int) dims [i]; - } - } - return ret; - } + var ret = new int[dims.Length]; + for (int i = 0; i < dims.Length; i++) + { + checked + { + ret[i] = (int)dims[i]; + } + } + return ret; + } /// /// Gets a value indicating whether one of the dimensions in the shape is larger than Int32.MaxValue. /// /// true if is long array; otherwise, false. - public bool IsLongArray { - get { - foreach (var l in dims) - if (l > Int32.MaxValue) - return true; + public bool IsLongArray + { + get + { + foreach (var l in dims) + if (l > Int32.MaxValue) + return true; - return false; - } - } + return false; + } + } /// /// Returns a that represents the current . @@ -3946,9 +4306,9 @@ public TFShape (params long [] args) /// A that represents the current . public override string ToString () { - if (dims == null) + if (HasUnknownNumberOfDimensions) return "unknown"; - return "[" + String.Join (", ", dims.Select (x => x == -1 ? "?" : x.ToString ())) + "]"; + return "[" + String.Join (", ", dims.Select (x => x == UnknownSize ? "?" : x.ToString ())) + "]"; } /// @@ -3957,22 +4317,30 @@ public override string ToString () /// Index. public long this [int idx] => dims [idx]; - /// - /// Returns the shape as a 1-dimensional tensor with each element corresponding to the specified shape dimension. - /// - /// The tensor. - public TFTensor AsTensor () - { - return new TFTensor (ToIntArray ()); - } + /// + /// Returns the shape as a 1-dimensional tensor with each element corresponding to the specified shape dimension. + /// + /// The tensor. + public TFTensor AsTensor() => new TFTensor(ToIntArray()); - /// - /// Adds a to a , yielding a shape made up of the concatenation of the first and the second shapes. - /// - /// The first to add. - /// The second to add. - /// The that is the sum of the values of left and right. - public static TFShape operator + (TFShape left, TFShape right) + + private void AssertValidShape() + { + if (HasUnknownNumberOfDimensions) + return; + + foreach (var dim in dims) + if (dim < UnknownSize) + throw new ArgumentOutOfRangeException($"TFShape can only have non-negative or '{nameof(UnknownSize)}' dimmension sizes"); + } + + /// + /// Adds a to a , yielding a shape made up of the concatenation of the first and the second shapes. + /// + /// The first to add. + /// The second to add. + /// The that is the sum of the values of left and right. + public static TFShape operator + (TFShape left, TFShape right) { if (left == null) return right; @@ -3994,7 +4362,18 @@ public TFTensor AsTensor () { return shape.AsTensor (); } - } + + /// + /// Performs an implicit conversion from an array describing dimension sizes to a . + /// + /// The shape + public static implicit operator TFShape(long[] shape)=> new TFShape(shape); + /// + /// Performs an implicit conversion from an array to a . + /// + /// The shape + public static implicit operator TFShape(int[] shape)=> new TFShape(shape.Cast().ToArray()); + } diff --git a/tests/TensorFlowSharp.Tests.CSharp/LifetimeTests.cs b/tests/TensorFlowSharp.Tests.CSharp/LifetimeTests.cs new file mode 100644 index 00000000..cdca4204 --- /dev/null +++ b/tests/TensorFlowSharp.Tests.CSharp/LifetimeTests.cs @@ -0,0 +1,119 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using TensorFlow; +using Xunit; + +namespace TensorFlowSharp.Tests.CSharp +{ + public class LifetimeTests + { + + public class TFOperationLifetime + { + [Fact] + public void DisposingGraphMakesOperationInvalid() + { + TFOperation op; + using (var graph = new TFGraph()) + { + op = graph.Const(10).Operation; + } + + var e = Record.Exception(() => op.ControlOutputs); + + Assert.IsType(e); + } + + private static object[] BuildOpTestcase(Action op) => new[] { op }; + private static IEnumerable OperationsThatShouldThrow() + { + yield return BuildOpTestcase(op => { var foo = op.ControlOutputs; }); + yield return BuildOpTestcase(op => { var foo = op.GetAttributeMetadata(""); }); + yield return BuildOpTestcase(op => { var foo = op.GetAttributeShape("", 1); }); + yield return BuildOpTestcase(op => { var foo = op.GetAttributeType(""); }); + yield return BuildOpTestcase(op => { var foo = op.GetInput(0); }); + } + + + [Theory] + [MemberData(nameof(OperationsThatShouldThrow))] + public void AfterDisposingGraphCertainOperationsFail(Action operation) + { + TFOperation op; + using (var graph = new TFGraph()) + { + op = graph.Const(10).Operation; + } + + var e = Record.Exception(() => operation(op)); + + Assert.IsType(e); + } + + + + private static IEnumerable OperationsThatShouldNotThrow() + { + yield return BuildOpTestcase(op => { var foo = op.Name; }); + yield return BuildOpTestcase(op => { var foo = op.Handle; }); + yield return BuildOpTestcase(op => { var foo = op.NumControlInputs; }); + yield return BuildOpTestcase(op => { var foo = op.NumControlOutputs; }); + yield return BuildOpTestcase(op => { var foo = op.NumInputs; }); + yield return BuildOpTestcase(op => { var foo = op.NumOutputs; }); + } + + [Theory] + [MemberData(nameof(OperationsThatShouldNotThrow))] + public void AfterDisposingGraphCertainOperationsStillSucceeds(Action operation) + { + TFOperation op; + using (var graph = new TFGraph()) + { + op = graph.Const(10).Operation; + } + + var record = Record.Exception(() => operation(op)); + + Assert.Null(record); + } + } + + public class TFOutputLifetime + { + [Fact] + public void AfterDisposingGraphDatatypeBecomesUnknown() + { + var sut = CreateOutputAndDisposeGraph(); + + var datatype = sut.OutputType; + + Assert.Equal(TFDataType.Unknown, datatype); + } + + [Fact] + public void AfterDisposingGraphCertainOperationsAreNotValidOnOutput() + { + var sut = CreateOutputAndDisposeGraph(); + + var ex1 = Record.Exception(() => sut.NumConsumers); + var ex2 = Record.Exception(() => sut.Operation); + var ex3 = Record.Exception(() => sut.OutputConsumers); + + Assert.IsType(ex1); + Assert.IsType(ex2); + Assert.IsType(ex3); + } + + private static TFOutput CreateOutputAndDisposeGraph() + { + using(var graph = new TFGraph()) + { + return graph.Const(10); + } + } + } + } +} diff --git a/tests/TensorFlowSharp.Tests.CSharp/TFOperationTests.cs b/tests/TensorFlowSharp.Tests.CSharp/TFOperationTests.cs new file mode 100644 index 00000000..02988200 --- /dev/null +++ b/tests/TensorFlowSharp.Tests.CSharp/TFOperationTests.cs @@ -0,0 +1,43 @@ +using TensorFlow; +using Xunit; + +namespace TensorFlowSharp.Tests.CSharp +{ + + public class TFOperationTests + { + [Fact] + public void InputsEqualToTheIncommingNumberOfOutputsAreRetrievedInOrder() + { + using (var graph = new TFGraph()) + { + var a = graph.Const(0f); + var v2 = graph.Variable(graph.Const(0.6f)); + + var add = graph.Add(a, v2.Read); + + AssertExpectedNumberOfInputs(0, a.Operation); + AssertExpectedNumberOfInputs(1, v2.Read.Operation); + AssertExpectedNumberOfInputs(2, add.Operation); + } + } + + + private static void AssertExpectedNumberOfInputs(int expected, TFOperation sut) + { + Assert.Equal(expected, sut.NumInputs); + + var inputs = sut.Inputs; + + Assert.NotNull(inputs); + Assert.Equal(expected, inputs.Length); + + for(int i = 0; i < inputs.Length; ++i) + { + var input = inputs[i]; + Assert.Equal(i, input.Index); + Assert.Equal(sut.Handle, input.Operation); + } + } + } +} diff --git a/tests/TensorFlowSharp.Tests.CSharp/TFOutputTests.cs b/tests/TensorFlowSharp.Tests.CSharp/TFOutputTests.cs new file mode 100644 index 00000000..07c8fac2 --- /dev/null +++ b/tests/TensorFlowSharp.Tests.CSharp/TFOutputTests.cs @@ -0,0 +1,166 @@ +using System; +using TensorFlow; +using Xunit; + +namespace TensorFlowSharp.Tests.CSharp +{ + public class TFOutputTests + { + public class Construction + { + [Fact] + public void OutputConstructedFromOperationWillGetTheOperationsHandle() + { + using (var graph = new TFGraph()) + { + var op = graph.NoOp(); + var output = new TFOutput(op, 0); + Assert.Equal(op.Handle, output.LLOperation); + } + } + } + + public class Operators + { + public class TFOutputAndTFOutput + { + [Theory] + [InlineData(0, 0, 0)] + [InlineData(1, 1, 2)] + [InlineData(1.5, 2.75, 4.25)] + public void AdditionOperatorYieldsAddedResults(double a, double b, double expected) => RunOperation((x, y) => x + y, a, b, expected); + + [Theory] + [InlineData(0, 0, 0)] + [InlineData(1, 1, 0)] + [InlineData(1.5, 2.75, -1.25)] + public void SubtractionOperatorYieldsSubtractedResults(double a, double b, double expected) => RunOperation((x, y) => x - y, a, b, expected); + + [Theory] + [InlineData(0, 1, 0)] + [InlineData(1, 1, 1)] + [InlineData(6.5, -2, -3.25)] + public void DivisionOperatorYieldsDividedResults(double a, double b, double expected) => RunOperation((x, y) => x / y, a, b, expected); + + [Theory] + [InlineData(0, 0, 0)] + [InlineData(1, 1, 1)] + [InlineData(-1.5, 2.75, -4.125)] + public void MultiplicationOperatorWithScalarYieldsMultipliedResultsAsScalar(double a, double b, double expected) => RunOperation((x, y) => x * y, a, b, expected); + + [Fact] + public void MultiplicationOfMatricesYieldsMatrixMultiplicationResult() + { + using(var graph = new TFGraph()) + { + var a = graph.Const(new TFTensor(new[,] { { 1.0, 0.0 }, { 0.0, 2.0 } })); + var b = graph.Const(new TFTensor(new[,] { { 2.0, 1.0 }, { 3.0, 4.0 } })); + + var output = a * b; + using(var session = new TFSession(graph)) + { + var result = (double[,])session.GetRunner().Run(output).GetValue(); + Assert.Equal(2.0, result[0, 0]); + Assert.Equal(1.0, result[0, 1]); + Assert.Equal(6.0, result[1, 0]); + Assert.Equal(8.0, result[1, 1]); + } + } + } + + [Fact] + public void MultiplicationOfVectorsYieldsElementwiseMultiplicationResult() + { + using (var graph = new TFGraph()) + { + var a = graph.Const(new TFTensor(new[] { 2.0, 3.0 })); + var b = graph.Const(new TFTensor(new[] { 0.5, 4.0 })); + + var output = a * b; + using (var session = new TFSession(graph)) + { + var result = (double[])session.GetRunner().Run(output).GetValue(); + Assert.Equal(1.0, result[0]); + Assert.Equal(12.0, result[1]); + } + } + } + + private static void RunOperation(Func operation, double a, double b, double expected) + { + using (var graph = new TFGraph()) + { + var aop = graph.Const(a); + var bop = graph.Const(b); + var sut = operation(aop, bop); + AssertExpectedOutcome(graph, sut, expected); + } + } + } + + public class TFOutputAndValue + { + + [Theory] + [InlineData(0, 0, 0)] + [InlineData(1, 1, 2)] + [InlineData(1.5, 2.75, 4.25)] + public void AdditionWithFloatIsEquivalentToAdditionWithConstant(float a, float b, float expected) => + AssertSymmetricOperation((x, y) => x + y, (x, y) => x + y, val => new TFTensor(val), a, b, expected); + + [Theory] + [InlineData(0, 0, 0)] + [InlineData(1, 1, 0)] + [InlineData(1.5, 2.75, -1.25)] + public void SubtractionWithFloatIsEquivalentToSubtractionWithConstant(float a, float b, float expected) => + AssertSymmetricOperation((x, y) => x - y, (x, y) => x - y, val => new TFTensor(val), a, b, expected); + + [Theory] + [InlineData(0, 1, 0)] + [InlineData(1, 1, 1)] + [InlineData(6.5, -2, -3.25)] + public void DivisionWithFloatIsEquivalentToDivisionWithConstant(float a, float b, float expected) => + AssertSymmetricOperation((x, y) => x / y, (x, y) => x / y, val => new TFTensor(val), a, b, expected); + + [Theory] + [InlineData(0, 0, 0)] + [InlineData(1, 1, 1)] + [InlineData(-1.5, 2.75, -4.125)] + public void MultiplicationWithFloatIsEquivalentToMultiplicationWithConstant(float a, float b, float expected) => + AssertSymmetricOperation((x, y) => x * y, (x, y) => x * y, val => new TFTensor(val), a, b, expected); + + + + + private static void AssertSymmetricOperation( + Func op, + Func opRev, + Func tensorFactory, + T a, + T b, + T expected) + { + using (var graph = new TFGraph()) + { + var aop = graph.Const(tensorFactory(a)); + var bop = graph.Const(tensorFactory(b)); + + AssertExpectedOutcome(graph, op(aop, b), expected); + AssertExpectedOutcome(graph, opRev(a, bop), expected); + } + } + } + + public static void AssertExpectedOutcome(TFGraph graph, TFOutput output, T expected) + { + using (var session = new TFSession(graph)) + { + var result = session.GetRunner().Run(output); + + Assert.Equal(0, result.NumDims); + Assert.Equal(expected, (T)result.GetValue()); + } + } + } + } +} diff --git a/tests/TensorFlowSharp.Tests.CSharp/TensorFlowSharp.Tests.CSharp.csproj b/tests/TensorFlowSharp.Tests.CSharp/TensorFlowSharp.Tests.CSharp.csproj index 28cf5a5c..31539b79 100644 --- a/tests/TensorFlowSharp.Tests.CSharp/TensorFlowSharp.Tests.CSharp.csproj +++ b/tests/TensorFlowSharp.Tests.CSharp/TensorFlowSharp.Tests.CSharp.csproj @@ -71,6 +71,7 @@ + @@ -81,6 +82,8 @@ + +