diff --git a/src/coreclr/vm/comcallablewrapper.cpp b/src/coreclr/vm/comcallablewrapper.cpp index 3bd5be1351ff10..a04559c8ea9030 100644 --- a/src/coreclr/vm/comcallablewrapper.cpp +++ b/src/coreclr/vm/comcallablewrapper.cpp @@ -4206,6 +4206,69 @@ ComMethodTable* ComCallWrapperTemplate::CreateComMethodTableForBasic(MethodTable RETURN pComMT; } +//-------------------------------------------------------------------------- +// Returns TRUE if the parent's ComMethodTable for pItfMT can be reused for +// pClassMT. This requires that no class between pClassMT and pParentMT has +// re-implemented pItfMT in its dispatch map, and that the interface methods +// resolve to the same MethodDescs on both pClassMT and pParentMT. +//-------------------------------------------------------------------------- +static bool CanShareComMethodTableWithParent(MethodTable* pClassMT, MethodTable* pParentMT, MethodTable* pItfMT) +{ + CONTRACTL + { + THROWS; + GC_TRIGGERS; + MODE_ANY; + PRECONDITION(pClassMT != NULL && !pClassMT->IsInterface()); + PRECONDITION(pParentMT != NULL && !pParentMT->IsInterface()); + PRECONDITION(pItfMT != NULL && pItfMT->IsInterface()); + } + CONTRACTL_END; + + // Check for explicit interface re-implementations in the dispatch map. + MethodTable* pMT = pClassMT; + do + { + DispatchMap::EncodedMapIterator mapIt(pMT); + for (; mapIt.IsValid(); mapIt.Next()) + { + DispatchMapEntry *pEntry = mapIt.Entry(); + if (pMT->DispatchMapTypeMatchesMethodTable(pEntry->GetTypeID(), pItfMT)) + { + return false; + } + } + + pMT = pMT->GetParentMethodTable(); + _ASSERTE(pMT != NULL); + } + while (pMT != pParentMT); + + // Check that interface methods resolve to the same MethodDescs on both + // this class and pParentMT. With the baked-in dispatch target model, the + // ComMethodTable stores the resolved MethodDesc at layout time, so the + // table can only be shared if the targets are identical. + for (unsigned i = 0; i < pItfMT->GetNumVirtuals(); i++) + { + MethodDesc *pItfMD = pItfMT->GetMethodDescForSlot_NoThrow(i); + _ASSERTE(pItfMD != NULL); + + if (pItfMD->IsAsyncMethod()) + continue; + + DispatchSlot childSlot(pClassMT->FindDispatchSlotForInterfaceMD(pItfMD, FALSE /* throwOnConflict */)); + DispatchSlot parentSlot(pParentMT->FindDispatchSlotForInterfaceMD(pItfMD, FALSE /* throwOnConflict */)); + + if (childSlot.IsNull() || parentSlot.IsNull()) + return false; + + if (childSlot.GetMethodDesc() != parentSlot.GetMethodDesc()) + return false; + } + + return true; +} + //-------------------------------------------------------------------------- // Creates a ComMethodTable for an interface and stores it in the m_rgpIPtr array. //-------------------------------------------------------------------------- @@ -4222,22 +4285,16 @@ ComMethodTable *ComCallWrapperTemplate::InitializeForInterface(MethodTable *pPar ComMethodTable *pItfComMT = NULL; if (m_pParent != NULL) { - pItfComMT = m_pParent->GetComMTForItf(pItfMT); - if (pItfComMT != NULL) + // Check if we can reuse the parent's ComMethodTable for this interface. + ComMethodTable* pParentComMT = m_pParent->GetComMTForItf(pItfMT); + if (pParentComMT != NULL && CanShareComMethodTableWithParent(m_thClass.GetMethodTable(), pParentMT, pItfMT)) { - // if the parent COM MT is not a trivial aggregate, simple MethodTable slot check is enough - if (!m_thClass.GetMethodTable()->ImplementsInterfaceWithSameSlotsAsParent(pItfMT, pParentMT)) - { - // the interface is implemented by parent but this class reimplemented - // its method(s) so we will need to build a new COM vtable for it - pItfComMT = NULL; - } + pItfComMT = pParentComMT; } } if (pItfComMT == NULL) { - // we couldn't use parent's vtable so we create a new one pItfComMT = CreateComMethodTableForInterface(pItfMT); } diff --git a/src/coreclr/vm/methodtable.cpp b/src/coreclr/vm/methodtable.cpp index 50e628b85b0cec..935d282119e605 100644 --- a/src/coreclr/vm/methodtable.cpp +++ b/src/coreclr/vm/methodtable.cpp @@ -5977,39 +5977,6 @@ UINT32 MethodTable::LookupTypeID() return AppDomain::GetCurrentDomain()->LookupTypeID(pMT); } -//========================================================================================== -BOOL MethodTable::ImplementsInterfaceWithSameSlotsAsParent(MethodTable *pItfMT, MethodTable *pParentMT) -{ - CONTRACTL - { - THROWS; - GC_TRIGGERS; - PRECONDITION(!IsInterface() && !pParentMT->IsInterface()); - PRECONDITION(pItfMT->IsInterface()); - } CONTRACTL_END; - - MethodTable *pMT = this; - do - { - DispatchMap::EncodedMapIterator it(pMT); - for (; it.IsValid(); it.Next()) - { - DispatchMapEntry *pCurEntry = it.Entry(); - if (DispatchMapTypeMatchesMethodTable(pCurEntry->GetTypeID(), pItfMT)) - { - // this class and its parents up to pParentMT must have no mappings for the interface - return FALSE; - } - } - - pMT = pMT->GetParentMethodTable(); - _ASSERTE(pMT != NULL); - } - while (pMT != pParentMT); - - return TRUE; -} - #endif // !DACCESS_COMPILE //========================================================================================== diff --git a/src/coreclr/vm/methodtable.h b/src/coreclr/vm/methodtable.h index 3a7c7e0cf2b2f7..b72bfe7a7f0fda 100644 --- a/src/coreclr/vm/methodtable.h +++ b/src/coreclr/vm/methodtable.h @@ -2572,11 +2572,6 @@ class MethodTable MethodTable *LookupDispatchMapType(DispatchMapTypeID typeID); bool DispatchMapTypeMatchesMethodTable(DispatchMapTypeID typeID, MethodTable* pMT); - // Determines whether all methods in the given interface have their final implementing - // slot in a parent class. I.e. if this returns TRUE, it is trivial (no VSD lookup) to - // dispatch pItfMT methods on this class if one knows how to dispatch them on pParentMT. - BOOL ImplementsInterfaceWithSameSlotsAsParent(MethodTable *pItfMT, MethodTable *pParentMT); - // Try to resolve a given static virtual method override on this type. Return nullptr // when not found. MethodDesc *TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, ResolveVirtualStaticMethodFlags resolveVirtualStaticMethodFlags, ClassLoadLevel level); diff --git a/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.cs b/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.cs new file mode 100644 index 00000000000000..67b643361e7a05 --- /dev/null +++ b/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.cs @@ -0,0 +1,129 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; +using Xunit; + +[ComVisible(true)] +[Guid("A1111111-0000-0000-0000-000000000001")] +public interface IFoo +{ + void DoWork(); +} + +[ComVisible(true)] +[Guid("A1111111-0000-0000-0000-000000000002")] +[ComDefaultInterface(typeof(IFoo))] +public class Foo : IFoo +{ + public virtual void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(Foo); +} + +[ComVisible(true)] +[Guid("A1111111-0000-0000-0000-000000000003")] +[ComDefaultInterface(typeof(IFoo))] +public class FooDerived : Foo +{ + public override void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(FooDerived); +} + +[ComVisible(true)] +[Guid("B2222222-0000-0000-0000-000000000001")] +public interface IBar +{ + void DoWork(); +} + +[ComVisible(true)] +[Guid("B2222222-0000-0000-0000-000000000002")] +[ComDefaultInterface(typeof(IBar))] +public class Bar : IBar +{ + public virtual void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(Bar); +} + +[ComVisible(true)] +[Guid("B2222222-0000-0000-0000-000000000003")] +[ComDefaultInterface(typeof(IBar))] +public class BarDerived : Bar +{ + public override void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(BarDerived); +} + +/// +/// Tests that COM-to-CLR dispatch correctly resolves virtual method overrides +/// regardless of whether the base or derived class is accessed via COM first. +/// +public class VirtualMethodOverrideTest +{ + internal static string? LastCalledType; + + [UnmanagedFunctionPointer(CallingConvention.StdCall)] + delegate int DoWorkDelegate(IntPtr pThis); + + private static int CallDoWork(IntPtr pInterface, int slot) + { + IntPtr vtbl = Marshal.ReadIntPtr(pInterface); + IntPtr fnPtr = Marshal.ReadIntPtr(vtbl, slot * IntPtr.Size); + Assert.NotEqual(IntPtr.Zero, fnPtr); + + var fn = Marshal.GetDelegateForFunctionPointer(fnPtr); + return fn(pInterface); + } + + [Fact] + public static void DerivedFirst() + { + int doWorkSlot = Marshal.GetStartComSlot(typeof(IFoo)); + IntPtr pDerived = IntPtr.Zero; + IntPtr pBase = IntPtr.Zero; + try + { + pDerived = Marshal.GetComInterfaceForObject(new FooDerived(), typeof(IFoo)); + pBase = Marshal.GetComInterfaceForObject(new Foo(), typeof(IFoo)); + + LastCalledType = null; + Assert.True(CallDoWork(pDerived, doWorkSlot) >= 0); + Assert.Equal(nameof(FooDerived), LastCalledType); + + LastCalledType = null; + Assert.True(CallDoWork(pBase, doWorkSlot) >= 0); + Assert.Equal(nameof(Foo), LastCalledType); + } + finally + { + if (pDerived != IntPtr.Zero) + Marshal.Release(pDerived); + + if (pBase != IntPtr.Zero) + Marshal.Release(pBase); + } + } + + [Fact] + public static void BaseFirst() + { + int doWorkSlot = Marshal.GetStartComSlot(typeof(IBar)); + IntPtr pBase = IntPtr.Zero; + IntPtr pDerived = IntPtr.Zero; + try + { + pBase = Marshal.GetComInterfaceForObject(new Bar(), typeof(IBar)); + pDerived = Marshal.GetComInterfaceForObject(new BarDerived(), typeof(IBar)); + + LastCalledType = null; + Assert.True(CallDoWork(pBase, doWorkSlot) >= 0); + Assert.Equal(nameof(Bar), LastCalledType); + + LastCalledType = null; + Assert.True(CallDoWork(pDerived, doWorkSlot) >= 0); + Assert.Equal(nameof(BarDerived), LastCalledType); + } + finally + { + if (pBase != IntPtr.Zero) Marshal.Release(pBase); + if (pDerived != IntPtr.Zero) Marshal.Release(pDerived); + } + } +} diff --git a/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.csproj b/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.csproj new file mode 100644 index 00000000000000..412688de2f640c --- /dev/null +++ b/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.csproj @@ -0,0 +1,10 @@ + + + true + true + true + + + + +