diff --git a/src/embed_tests/ClassManagerTests.cs b/src/embed_tests/ClassManagerTests.cs index 2fd38f272..264509c2a 100644 --- a/src/embed_tests/ClassManagerTests.cs +++ b/src/embed_tests/ClassManagerTests.cs @@ -1179,6 +1179,236 @@ def contains(dictionary, key): Assert.IsFalse(result); } + [Test] + public void SupportsLenOperatorForIEnumerableWithCountProperty() + { + using var _ = Py.GIL(); + + var module = PyModule.FromString("SupportsLenOperatorForIEnumerableWithCountProperty", $@" +from clr import AddReference +AddReference(""Python.EmbeddingTest"") + +from Python.EmbeddingTest import * + +def length(enumerable): + return len(enumerable) +"); + + using var length = module.GetAttr("length"); + + Assert.Multiple(() => + { + var enumerableWithCount = new EnumerableWithCount(); + using var pyEnumerableWithCount = enumerableWithCount.ToPython(); + var count = length.Invoke(pyEnumerableWithCount).As(); + Assert.AreEqual(enumerableWithCount.Count, count); + + var genericEnumerableWithCount = new GenericEnumerableWithCount(); + using var pyGenericEnumerableWithCount = genericEnumerableWithCount.ToPython(); + count = length.Invoke(pyGenericEnumerableWithCount).As(); + Assert.AreEqual(genericEnumerableWithCount.Count, count); + + var derivedEnumerableWithCount = new DerivedEnumerableWithCount(); + using var pyDerivedEnumerableWithCount = derivedEnumerableWithCount.ToPython(); + count = length.Invoke(pyDerivedEnumerableWithCount).As(); + Assert.AreEqual(derivedEnumerableWithCount.Count, count); + }); + } + + private class EnumerableWithCount : IEnumerable + { + public int Count => 123; + public IEnumerator GetEnumerator() + { + for (int i = 0; i < Count; i++) + { + yield return i; + } + } + } + + private class GenericEnumerableWithCount : IEnumerable + { + public int Count => 123; + + public IEnumerator GetEnumerator() + { + for (int i = 0; i < Count; i++) + { + yield return i; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } + + private class DerivedEnumerableWithCount : GenericEnumerableWithCount + { + } + + [Test] + public void SupportsLenOperatorForICollection() + { + using var _ = Py.GIL(); + + var module = PyModule.FromString("SupportsLenOperatorForICollection", $@" +from clr import AddReference +AddReference(""Python.EmbeddingTest"") + +from Python.EmbeddingTest import * + +def length(enumerable): + return len(enumerable) +"); + + using var length = module.GetAttr("length"); + + Assert.Multiple(() => + { + var collection = new BasicCollection(); + using var pyCollection = collection.ToPython(); + var count = length.Invoke(pyCollection).As(); + Assert.AreEqual(collection.Count, count); + + var genericCollection = new GenericCollection(); + using var pyGenericCollection = genericCollection.ToPython(); + count = length.Invoke(pyGenericCollection).As(); + Assert.AreEqual(genericCollection.Count, count); + + var collectionWithExplicitInterfaceImplementation = new CollectionWithExplicitInterfaceImplementation(); + using var pyCollectionWithExplicitInterfaceImplementation = collectionWithExplicitInterfaceImplementation.ToPython(); + count = length.Invoke(pyCollectionWithExplicitInterfaceImplementation).As(); + Assert.AreEqual(((ICollection)collectionWithExplicitInterfaceImplementation).Count, count); + }); + } + + private class BasicCollection : ICollection + { + public int Count => 123; + public bool IsSynchronized => false; + public object SyncRoot => this; + public void CopyTo(Array array, int index) + { + for (int i = 0; i < Count; i++) + { + array.SetValue(i, index + i); + } + } + public IEnumerator GetEnumerator() + { + for (int i = 0; i < Count; i++) + { + yield return i; + } + } + } + + private class GenericCollection : ICollection + { + public int Count => 123; + public bool IsSynchronized => false; + public object SyncRoot => this; + + public bool IsReadOnly => throw new NotImplementedException(); + + public void Add(int item) + { + throw new NotImplementedException(); + } + + public void Clear() + { + throw new NotImplementedException(); + } + + public bool Contains(int item) + { + throw new NotImplementedException(); + } + + public void CopyTo(int[] array, int index) + { + for (int i = 0; i < Count; i++) + { + array[index + i] = i; + } + } + public IEnumerator GetEnumerator() + { + for (int i = 0; i < Count; i++) + { + yield return i; + } + } + + public bool Remove(int item) + { + throw new NotImplementedException(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } + + private class CollectionWithExplicitInterfaceImplementation : ICollection + { + public bool IsSynchronized => false; + public object SyncRoot => this; + + int ICollection.Count => 123; + + bool ICollection.IsReadOnly => true; + + void ICollection.CopyTo(int[] array, int index) + { + for (int i = 0; i < ((ICollection)this).Count; i++) + { + array[index + i] = i; + } + } + public IEnumerator GetEnumerator() + { + for (int i = 0; i < ((ICollection)this).Count; i++) + { + yield return i; + } + } + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + void ICollection.Add(int item) + { + throw new NotImplementedException(); + } + + void ICollection.Clear() + { + throw new NotImplementedException(); + } + + bool ICollection.Contains(int item) + { + throw new NotImplementedException(); + } + + bool ICollection.Remove(int item) + { + throw new NotImplementedException(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + throw new NotImplementedException(); + } + } + public class TestDictionary : IDictionary { private readonly Dictionary _data = new(); diff --git a/src/perf_tests/Python.PerformanceTests.csproj b/src/perf_tests/Python.PerformanceTests.csproj index 210552748..17af4024c 100644 --- a/src/perf_tests/Python.PerformanceTests.csproj +++ b/src/perf_tests/Python.PerformanceTests.csproj @@ -13,7 +13,7 @@ runtime; build; native; contentfiles; analyzers; buildtransitive - + compile @@ -25,7 +25,7 @@ - + diff --git a/src/runtime/Properties/AssemblyInfo.cs b/src/runtime/Properties/AssemblyInfo.cs index b17e8cd57..06f73394d 100644 --- a/src/runtime/Properties/AssemblyInfo.cs +++ b/src/runtime/Properties/AssemblyInfo.cs @@ -4,5 +4,5 @@ [assembly: InternalsVisibleTo("Python.EmbeddingTest, PublicKey=00240000048000009400000006020000002400005253413100040000110000005ffd8f49fb44ab0641b3fd8d55e749f716e6dd901032295db641eb98ee46063cbe0d4a1d121ef0bc2af95f8a7438d7a80a3531316e6b75c2dae92fb05a99f03bf7e0c03980e1c3cfb74ba690aca2f3339ef329313bcc5dccced125a4ffdc4531dcef914602cd5878dc5fbb4d4c73ddfbc133f840231343e013762884d6143189")] [assembly: InternalsVisibleTo("Python.Test, PublicKey=00240000048000009400000006020000002400005253413100040000110000005ffd8f49fb44ab0641b3fd8d55e749f716e6dd901032295db641eb98ee46063cbe0d4a1d121ef0bc2af95f8a7438d7a80a3531316e6b75c2dae92fb05a99f03bf7e0c03980e1c3cfb74ba690aca2f3339ef329313bcc5dccced125a4ffdc4531dcef914602cd5878dc5fbb4d4c73ddfbc133f840231343e013762884d6143189")] -[assembly: AssemblyVersion("2.0.53")] -[assembly: AssemblyFileVersion("2.0.53")] +[assembly: AssemblyVersion("2.0.54")] +[assembly: AssemblyFileVersion("2.0.54")] diff --git a/src/runtime/Python.Runtime.csproj b/src/runtime/Python.Runtime.csproj index 981767b9e..953fdcba0 100644 --- a/src/runtime/Python.Runtime.csproj +++ b/src/runtime/Python.Runtime.csproj @@ -5,7 +5,7 @@ Python.Runtime Python.Runtime QuantConnect.pythonnet - 2.0.53 + 2.0.54 false LICENSE https://github.com/pythonnet/pythonnet diff --git a/src/runtime/Types/MpLengthSlot.cs b/src/runtime/Types/MpLengthSlot.cs index 9e4865fe0..479ee73b9 100644 --- a/src/runtime/Types/MpLengthSlot.cs +++ b/src/runtime/Types/MpLengthSlot.cs @@ -1,7 +1,6 @@ using System; using System.Collections; using System.Collections.Generic; -using System.Diagnostics; using System.Linq; using System.Reflection; @@ -9,20 +8,23 @@ namespace Python.Runtime.Slots { internal static class MpLengthSlot { + private static Dictionary _countGettersCache = new(); + public static bool CanAssign(Type clrType) { - if (typeof(ICollection).IsAssignableFrom(clrType)) + if (typeof(IEnumerable).IsAssignableFrom(clrType) && TryGetCountGetter(clrType, clrType, out _)) { return true; } - if (clrType.GetInterfaces().Any(x => x.IsGenericType && x.GetGenericTypeDefinition() == typeof(ICollection<>))) - { - return true; - } - if (clrType.IsInterface && clrType.IsGenericType && clrType.GetGenericTypeDefinition() == typeof(ICollection<>)) + + var iface = clrType.GetInterfaces().FirstOrDefault(x => x.IsGenericType && x.GetGenericTypeDefinition() == typeof(ICollection<>)); + if (iface != null) { + // Get and cache the Count getter for this type and interface + TryGetCountGetter(clrType, iface, out _); return true; } + return false; } @@ -46,24 +48,31 @@ internal static nint impl(BorrowedReference ob) } Type clrType = co.inst.GetType(); - - // now look for things that implement ICollection directly (non-explicitly) - PropertyInfo p = clrType.GetProperty("Count"); - if (p != null && clrType.GetInterfaces().Any(x => x.IsGenericType && x.GetGenericTypeDefinition() == typeof(ICollection<>))) + if (TryGetCountGetter(clrType, clrType, out var getter)) { - return (int)p.GetValue(co.inst, null); + return (int)getter.Invoke(co.inst, null); } - // finally look for things that implement the interface explicitly - var iface = clrType.GetInterfaces().FirstOrDefault(x => x.IsGenericType && x.GetGenericTypeDefinition() == typeof(ICollection<>)); - if (iface != null) + Exceptions.SetError(Exceptions.TypeError, $"object of type '{clrType.Name}' has no len()"); + return -1; + } + + /// + /// Will get the Count getter for the given parent type and cache it for the given clr type. + /// This allows us to cache the Count getter for the give type when it's defined as a private interface implementation. + /// + private static bool TryGetCountGetter(Type clrType, Type parentType, out MethodInfo getter) + { + if (!_countGettersCache.TryGetValue(clrType, out getter)) { - p = iface.GetProperty(nameof(ICollection.Count)); - return (int)p.GetValue(co.inst, null); + var countProperty = parentType.GetProperty("Count"); + if (countProperty != null) + { + _countGettersCache[clrType] = getter = countProperty.GetMethod; + } } - Exceptions.SetError(Exceptions.TypeError, $"object of type '{clrType.Name}' has no len()"); - return -1; + return getter != null; } } }