-
Notifications
You must be signed in to change notification settings - Fork 535
test: Gradient descent optimizer tests #1184
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2121079
c32d153
2a377e2
f7b8dba
c906f46
b1972a8
2a5d0f7
e6c7c79
149caae
2cb5fd6
09d466d
c5b4928
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
using System.Linq; | ||
using Tensorflow; | ||
using static Tensorflow.Binding; | ||
using System.Collections.Generic; | ||
|
||
namespace TensorFlowNET.UnitTest | ||
{ | ||
|
@@ -144,6 +145,40 @@ public void assertAllClose(double value, NDArray array2, double eps = 1e-5) | |
Assert.IsTrue(np.allclose(array1, array2, rtol: eps)); | ||
} | ||
|
||
private class CollectionComparer : IComparer | ||
{ | ||
private readonly double _epsilon; | ||
|
||
public CollectionComparer(double eps = 1e-06) | ||
{ | ||
_epsilon = eps; | ||
} | ||
public int Compare(object x, object y) | ||
{ | ||
var a = (double)x; | ||
var b = (double)y; | ||
|
||
double delta = Math.Abs(a - b); | ||
if (delta < _epsilon) | ||
{ | ||
return 0; | ||
} | ||
return a.CompareTo(b); | ||
} | ||
} | ||
|
||
public void assertAllCloseAccordingToType<T>( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Wanglongzhi2001 Does it make sense not to duplicate all assertion code but combine all in one assembly? |
||
ICollection expected, | ||
ICollection<T> given, | ||
double eps = 1e-6, | ||
float float_eps = 1e-6f) | ||
{ | ||
// TODO: check if any of arguments is not double and change toletance | ||
// remove givenAsDouble and cast expected instead | ||
var givenAsDouble = given.Select(x => Convert.ToDouble(x)).ToArray(); | ||
CollectionAssert.AreEqual(expected, givenAsDouble, new CollectionComparer(eps)); | ||
} | ||
|
||
public void assertProtoEquals(object toProto, object o) | ||
{ | ||
throw new NotImplementedException(); | ||
|
@@ -153,6 +188,20 @@ public void assertProtoEquals(object toProto, object o) | |
|
||
#region tensor evaluation and test session | ||
|
||
private Session _cached_session = null; | ||
private Graph _cached_graph = null; | ||
private object _cached_config = null; | ||
private bool _cached_force_gpu = false; | ||
|
||
private void _ClearCachedSession() | ||
{ | ||
if (self._cached_session != null) | ||
{ | ||
self._cached_session.Dispose(); | ||
self._cached_session = null; | ||
} | ||
} | ||
|
||
//protected object _eval_helper(Tensor[] tensors) | ||
//{ | ||
// if (tensors == null) | ||
|
@@ -196,17 +245,25 @@ public T evaluate<T>(Tensor tensor) | |
// return self._eval_helper(tensors) | ||
// else: | ||
{ | ||
var sess = tf.Session(); | ||
novikov-alexander marked this conversation as resolved.
Show resolved
Hide resolved
|
||
var sess = tf.get_default_session(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Wanglongzhi2001 The same functions exist in another assembly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think so, if you would like to. |
||
var ndarray = tensor.eval(sess); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Wanglongzhi2001 what is the reason There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's to map the tensor to a node in the session (graph). |
||
if (typeof(T) == typeof(double)) | ||
if (typeof(T) == typeof(double) | ||
|| typeof(T) == typeof(float) | ||
|| typeof(T) == typeof(int)) | ||
{ | ||
result = Convert.ChangeType(ndarray, typeof(T)); | ||
} | ||
else if (typeof(T) == typeof(double[])) | ||
{ | ||
result = ndarray.ToMultiDimArray<double>(); | ||
} | ||
else if (typeof(T) == typeof(float[])) | ||
{ | ||
double x = ndarray; | ||
result = x; | ||
result = ndarray.ToMultiDimArray<float>(); | ||
} | ||
else if (typeof(T) == typeof(int)) | ||
else if (typeof(T) == typeof(int[])) | ||
{ | ||
int x = ndarray; | ||
result = x; | ||
result = ndarray.ToMultiDimArray<int>(); | ||
} | ||
else | ||
{ | ||
|
@@ -218,9 +275,56 @@ public T evaluate<T>(Tensor tensor) | |
} | ||
|
||
|
||
public Session cached_session() | ||
///Returns a TensorFlow Session for use in executing tests. | ||
public Session cached_session( | ||
Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false) | ||
{ | ||
throw new NotImplementedException(); | ||
// This method behaves differently than self.session(): for performance reasons | ||
// `cached_session` will by default reuse the same session within the same | ||
// test.The session returned by this function will only be closed at the end | ||
// of the test(in the TearDown function). | ||
|
||
// Use the `use_gpu` and `force_gpu` options to control where ops are run.If | ||
// `force_gpu` is True, all ops are pinned to `/ device:GPU:0`. Otherwise, if | ||
// `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as | ||
// possible.If both `force_gpu and `use_gpu` are False, all ops are pinned to | ||
// the CPU. | ||
|
||
// Example: | ||
// python | ||
// class MyOperatorTest(test_util.TensorFlowTestCase) : | ||
// def testMyOperator(self): | ||
// with self.cached_session() as sess: | ||
// valid_input = [1.0, 2.0, 3.0, 4.0, 5.0] | ||
// result = MyOperator(valid_input).eval() | ||
// self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0] | ||
// invalid_input = [-1.0, 2.0, 7.0] | ||
// with self.assertRaisesOpError("negative input not supported"): | ||
// MyOperator(invalid_input).eval() | ||
|
||
|
||
// Args: | ||
// graph: Optional graph to use during the returned session. | ||
// config: An optional config_pb2.ConfigProto to use to configure the | ||
// session. | ||
// use_gpu: If True, attempt to run as many ops as possible on GPU. | ||
// force_gpu: If True, pin all ops to `/device:GPU:0`. | ||
|
||
// Yields: | ||
// A Session object that should be used as a context manager to surround | ||
// the graph building and execution code in a test case. | ||
|
||
|
||
// TODO: | ||
// if context.executing_eagerly(): | ||
// return self._eval_helper(tensors) | ||
// else: | ||
{ | ||
var sess = self._get_cached_session( | ||
graph, config, force_gpu, crash_if_inconsistent_args: true); | ||
using var cached = self._constrain_devices_and_set_default(sess, use_gpu, force_gpu); | ||
return cached; | ||
} | ||
} | ||
|
||
//Returns a TensorFlow Session for use in executing tests. | ||
|
@@ -268,6 +372,40 @@ public Session session(Graph graph = null, object config = null, bool use_gpu = | |
return s.as_default(); | ||
} | ||
|
||
private Session _constrain_devices_and_set_default(Session sess, bool use_gpu, bool force_gpu) | ||
{ | ||
// Set the session and its graph to global default and constrain devices.""" | ||
if (tf.executing_eagerly()) | ||
return null; | ||
else | ||
{ | ||
sess.graph.as_default(); | ||
sess.as_default(); | ||
{ | ||
if (force_gpu) | ||
{ | ||
// TODO: | ||
|
||
// Use the name of an actual device if one is detected, or | ||
// '/device:GPU:0' otherwise | ||
/* var gpu_name = gpu_device_name(); | ||
if (!gpu_name) | ||
gpu_name = "/device:GPU:0" | ||
using (sess.graph.device(gpu_name)) { | ||
yield return sess; | ||
}*/ | ||
return sess; | ||
} | ||
else if (use_gpu) | ||
return sess; | ||
else | ||
using (sess.graph.device("/device:CPU:0")) | ||
return sess; | ||
} | ||
|
||
} | ||
} | ||
|
||
// See session() for details. | ||
private Session _create_session(Graph graph, object cfg, bool forceGpu) | ||
{ | ||
|
@@ -312,6 +450,54 @@ private Session _create_session(Graph graph, object cfg, bool forceGpu) | |
return new Session(graph);//, config = prepare_config(config)) | ||
} | ||
|
||
private Session _get_cached_session( | ||
Graph graph = null, | ||
object config = null, | ||
bool force_gpu = false, | ||
bool crash_if_inconsistent_args = true) | ||
{ | ||
// See cached_session() for documentation. | ||
if (self._cached_session == null) | ||
{ | ||
var sess = self._create_session(graph, config, force_gpu); | ||
self._cached_session = sess; | ||
self._cached_graph = graph; | ||
self._cached_config = config; | ||
self._cached_force_gpu = force_gpu; | ||
return sess; | ||
} | ||
else | ||
{ | ||
|
||
if (crash_if_inconsistent_args && self._cached_graph != null && !self._cached_graph.Equals(graph)) | ||
throw new ValueError(@"The graph used to get the cached session is | ||
different than the one that was used to create the | ||
session. Maybe create a new session with | ||
self.session()"); | ||
if (crash_if_inconsistent_args && self._cached_config != null && !self._cached_config.Equals(config)) | ||
{ | ||
throw new ValueError(@"The config used to get the cached session is | ||
different than the one that was used to create the | ||
session. Maybe create a new session with | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Wanglongzhi2001 this code is also copy pasted from my previous PR because that's how test architecture implemented right now. Should I move it to one common assembly? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think moving to a common namespace is a good idea if it could be re-used by other modules. Please open an another PR if you would like to do that. |
||
self.session()"); | ||
} | ||
if (crash_if_inconsistent_args && !self._cached_force_gpu.Equals(force_gpu)) | ||
{ | ||
throw new ValueError(@"The force_gpu value used to get the cached session is | ||
different than the one that was used to create the | ||
session. Maybe create a new session with | ||
self.session()"); | ||
} | ||
return _cached_session; | ||
} | ||
} | ||
|
||
[TestCleanup] | ||
public void Cleanup() | ||
{ | ||
_ClearCachedSession(); | ||
} | ||
|
||
#endregion | ||
|
||
public void AssetSequenceEqual<T>(T[] a, T[] b) | ||
|
Uh oh!
There was an error while loading. Please reload this page.