Skip to content

Commit 2a377e2

Browse files
tests are passing
1 parent c32d153 commit 2a377e2

File tree

3 files changed

+46
-35
lines changed

3 files changed

+46
-35
lines changed

src/TensorFlowNET.Core/Variables/variables.py.cs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,5 @@ public static Operation _safe_initial_value_from_op(string name, Operation op, D
154154

155155
return op;
156156
}
157-
158-
public static Tensor global_variables_initializer()
159-
{
160-
// if context.executing_eagerly():
161-
// return control_flow_ops.no_op(name = "global_variables_initializer")
162-
var group = variables_initializer(global_variables().ToArray());
163-
return group;
164-
}
165157
}
166158
}

test/TensorFlowNET.UnitTest/PythonTest.cs

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Linq;
77
using Tensorflow;
88
using static Tensorflow.Binding;
9+
using System.Collections.Generic;
910

1011
namespace TensorFlowNET.UnitTest
1112
{
@@ -144,11 +145,12 @@ public void assertAllClose(double value, NDArray array2, double eps = 1e-5)
144145
Assert.IsTrue(np.allclose(array1, array2, rtol: eps));
145146
}
146147

147-
private class CollectionComparer : System.Collections.IComparer
148+
private class CollectionComparer : IComparer
148149
{
149150
private readonly double _epsilon;
150151

151-
public CollectionComparer(double eps = 1e-06) {
152+
public CollectionComparer(double eps = 1e-06)
153+
{
152154
_epsilon = eps;
153155
}
154156
public int Compare(object x, object y)
@@ -166,13 +168,15 @@ public int Compare(object x, object y)
166168
}
167169

168170
public void assertAllCloseAccordingToType<T>(
169-
T[] expected,
170-
T[] given,
171+
ICollection expected,
172+
ICollection<T> given,
171173
double eps = 1e-6,
172174
float float_eps = 1e-6f)
173175
{
174176
// TODO: check if any of the arguments is not double and change the tolerance
175-
CollectionAssert.AreEqual(expected, given, new CollectionComparer(eps));
177+
// remove givenAsDouble and cast expected instead
178+
var givenAsDouble = given.Select(x => Convert.ToDouble(x)).ToArray();
179+
CollectionAssert.AreEqual(expected, givenAsDouble, new CollectionComparer(eps));
176180
}
177181

178182
public void assertProtoEquals(object toProto, object o)
@@ -241,17 +245,25 @@ public T evaluate<T>(Tensor tensor)
241245
// return self._eval_helper(tensors)
242246
// else:
243247
{
244-
var sess = tf.Session();
248+
var sess = tf.get_default_session();
245249
var ndarray = tensor.eval(sess);
246-
if (typeof(T) == typeof(double))
250+
if (typeof(T) == typeof(double)
251+
|| typeof(T) == typeof(float)
252+
|| typeof(T) == typeof(int))
253+
{
254+
result = Convert.ChangeType(ndarray, typeof(T));
255+
}
256+
else if (typeof(T) == typeof(double[]))
257+
{
258+
result = ndarray.ToMultiDimArray<double>();
259+
}
260+
else if (typeof(T) == typeof(float[]))
247261
{
248-
double x = ndarray;
249-
result = x;
262+
result = ndarray.ToMultiDimArray<float>();
250263
}
251-
else if (typeof(T) == typeof(int))
264+
else if (typeof(T) == typeof(int[]))
252265
{
253-
int x = ndarray;
254-
result = x;
266+
result = ndarray.ToMultiDimArray<int>();
255267
}
256268
else
257269
{
@@ -457,12 +469,12 @@ private Session _get_cached_session(
457469
else
458470
{
459471

460-
if (crash_if_inconsistent_args && !self._cached_graph.Equals(graph))
472+
if (crash_if_inconsistent_args && self._cached_graph != null && !self._cached_graph.Equals(graph))
461473
throw new ValueError(@"The graph used to get the cached session is
462474
different than the one that was used to create the
463475
session. Maybe create a new session with
464476
self.session()");
465-
if (crash_if_inconsistent_args && !self._cached_config.Equals(config))
477+
if (crash_if_inconsistent_args && self._cached_config != null && !self._cached_config.Equals(config))
466478
{
467479
throw new ValueError(@"The config used to get the cached session is
468480
different than the one that was used to create the

test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using System;
33
using System.Linq;
4-
using System.Runtime.Intrinsics.X86;
5-
using System.Security.AccessControl;
64
using Tensorflow.NumPy;
75
using TensorFlowNET.UnitTest;
86
using static Tensorflow.Binding;
@@ -12,18 +10,23 @@ namespace Tensorflow.Keras.UnitTest.Optimizers
1210
[TestClass]
1311
public class GradientDescentOptimizerTest : PythonTest
1412
{
15-
private void TestBasicGeneric<T>() where T : struct
13+
private static TF_DataType GetTypeForNumericType<T>() where T : struct
1614
{
17-
var dtype = Type.GetTypeCode(typeof(T)) switch
15+
return Type.GetTypeCode(typeof(T)) switch
1816
{
1917
TypeCode.Single => np.float32,
2018
TypeCode.Double => np.float64,
2119
_ => throw new NotImplementedException(),
2220
};
21+
}
22+
23+
private void TestBasicGeneric<T>() where T : struct
24+
{
25+
var dtype = GetTypeForNumericType<T>();
2326

2427
// train.GradientDescentOptimizer is V1 only API.
2528
tf.Graph().as_default();
26-
using (self.cached_session())
29+
using (var sess = self.cached_session())
2730
{
2831
var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype);
2932
var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
@@ -36,21 +39,25 @@ private void TestBasicGeneric<T>() where T : struct
3639
};
3740
var sgd_op = optimizer.apply_gradients(grads_and_vars);
3841

39-
var global_variables = variables.global_variables_initializer();
40-
self.evaluate<T>(global_variables);
42+
var global_variables = tf.global_variables_initializer();
43+
sess.run(global_variables);
44+
4145
// Fetch params to validate initial values
46+
var initialVar0 = sess.run(var0);
47+
var valu = var0.eval(sess);
48+
var initialVar1 = sess.run(var1);
4249
// TODO: use self.evaluate<T[]> instead of self.evaluate<double[]>
43-
self.assertAllCloseAccordingToType(new double[] { 1.0, 2.0 }, self.evaluate<double[]>(var0));
44-
self.assertAllCloseAccordingToType(new double[] { 3.0, 4.0 }, self.evaluate<double[]>(var1));
50+
self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0));
51+
self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1));
4552
// Run 1 step of sgd
4653
sgd_op.run();
4754
// Validate updated params
4855
self.assertAllCloseAccordingToType(
49-
new double[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 },
50-
self.evaluate<double[]>(var0));
56+
new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 },
57+
self.evaluate<T[]>(var0));
5158
self.assertAllCloseAccordingToType(
52-
new double[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
53-
self.evaluate<double[]>(var1));
59+
new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
60+
self.evaluate<T[]>(var1));
5461
// TODO: self.assertEqual(0, len(optimizer.variables()));
5562
}
5663
}

0 commit comments

Comments
 (0)