diff --git a/src/coreclr/vm/comcallablewrapper.cpp b/src/coreclr/vm/comcallablewrapper.cpp
index 3bd5be1351ff10..a04559c8ea9030 100644
--- a/src/coreclr/vm/comcallablewrapper.cpp
+++ b/src/coreclr/vm/comcallablewrapper.cpp
@@ -4206,6 +4206,69 @@ ComMethodTable* ComCallWrapperTemplate::CreateComMethodTableForBasic(MethodTable
RETURN pComMT;
}
+//--------------------------------------------------------------------------
+// Returns TRUE if the parent's ComMethodTable for pItfMT can be reused for
+// pClassMT. This requires that no class between pClassMT and pParentMT has
+// re-implemented pItfMT in its dispatch map, and that the interface methods
+// resolve to the same MethodDescs on both pClassMT and pParentMT.
+//--------------------------------------------------------------------------
+static bool CanShareComMethodTableWithParent(MethodTable* pClassMT, MethodTable* pParentMT, MethodTable* pItfMT)
+{
+ CONTRACTL
+ {
+ THROWS;
+ GC_TRIGGERS;
+ MODE_ANY;
+ PRECONDITION(pClassMT != NULL && !pClassMT->IsInterface());
+ PRECONDITION(pParentMT != NULL && !pParentMT->IsInterface());
+ PRECONDITION(pItfMT != NULL && pItfMT->IsInterface());
+ }
+ CONTRACTL_END;
+
+ // Check for explicit interface re-implementations in the dispatch map.
+ MethodTable* pMT = pClassMT;
+ do
+ {
+ DispatchMap::EncodedMapIterator mapIt(pMT);
+ for (; mapIt.IsValid(); mapIt.Next())
+ {
+ DispatchMapEntry *pEntry = mapIt.Entry();
+ if (pMT->DispatchMapTypeMatchesMethodTable(pEntry->GetTypeID(), pItfMT))
+ {
+ return false;
+ }
+ }
+
+ pMT = pMT->GetParentMethodTable();
+ _ASSERTE(pMT != NULL);
+ }
+ while (pMT != pParentMT);
+
+ // Check that interface methods resolve to the same MethodDescs on both
+ // this class and pParentMT. With the baked-in dispatch target model, the
+ // ComMethodTable stores the resolved MethodDesc at layout time, so the
+ // table can only be shared if the targets are identical.
+ for (unsigned i = 0; i < pItfMT->GetNumVirtuals(); i++)
+ {
+ MethodDesc *pItfMD = pItfMT->GetMethodDescForSlot_NoThrow(i);
+ _ASSERTE(pItfMD != NULL);
+
+ if (pItfMD->IsAsyncMethod())
+ continue;
+
+ DispatchSlot childSlot(pClassMT->FindDispatchSlotForInterfaceMD(pItfMD, FALSE /* throwOnConflict */));
+ DispatchSlot parentSlot(pParentMT->FindDispatchSlotForInterfaceMD(pItfMD, FALSE /* throwOnConflict */));
+
+ if (childSlot.IsNull() || parentSlot.IsNull())
+ return false;
+
+ if (childSlot.GetMethodDesc() != parentSlot.GetMethodDesc())
+ return false;
+ }
+
+ return true;
+}
+
//--------------------------------------------------------------------------
// Creates a ComMethodTable for an interface and stores it in the m_rgpIPtr array.
//--------------------------------------------------------------------------
@@ -4222,22 +4285,16 @@ ComMethodTable *ComCallWrapperTemplate::InitializeForInterface(MethodTable *pPar
ComMethodTable *pItfComMT = NULL;
if (m_pParent != NULL)
{
- pItfComMT = m_pParent->GetComMTForItf(pItfMT);
- if (pItfComMT != NULL)
+ // Check if we can reuse the parent's ComMethodTable for this interface.
+ ComMethodTable* pParentComMT = m_pParent->GetComMTForItf(pItfMT);
+ if (pParentComMT != NULL && CanShareComMethodTableWithParent(m_thClass.GetMethodTable(), pParentMT, pItfMT))
{
- // if the parent COM MT is not a trivial aggregate, simple MethodTable slot check is enough
- if (!m_thClass.GetMethodTable()->ImplementsInterfaceWithSameSlotsAsParent(pItfMT, pParentMT))
- {
- // the interface is implemented by parent but this class reimplemented
- // its method(s) so we will need to build a new COM vtable for it
- pItfComMT = NULL;
- }
+ pItfComMT = pParentComMT;
}
}
if (pItfComMT == NULL)
{
- // we couldn't use parent's vtable so we create a new one
pItfComMT = CreateComMethodTableForInterface(pItfMT);
}
diff --git a/src/coreclr/vm/methodtable.cpp b/src/coreclr/vm/methodtable.cpp
index 50e628b85b0cec..935d282119e605 100644
--- a/src/coreclr/vm/methodtable.cpp
+++ b/src/coreclr/vm/methodtable.cpp
@@ -5977,39 +5977,6 @@ UINT32 MethodTable::LookupTypeID()
return AppDomain::GetCurrentDomain()->LookupTypeID(pMT);
}
-//==========================================================================================
-BOOL MethodTable::ImplementsInterfaceWithSameSlotsAsParent(MethodTable *pItfMT, MethodTable *pParentMT)
-{
- CONTRACTL
- {
- THROWS;
- GC_TRIGGERS;
- PRECONDITION(!IsInterface() && !pParentMT->IsInterface());
- PRECONDITION(pItfMT->IsInterface());
- } CONTRACTL_END;
-
- MethodTable *pMT = this;
- do
- {
- DispatchMap::EncodedMapIterator it(pMT);
- for (; it.IsValid(); it.Next())
- {
- DispatchMapEntry *pCurEntry = it.Entry();
- if (DispatchMapTypeMatchesMethodTable(pCurEntry->GetTypeID(), pItfMT))
- {
- // this class and its parents up to pParentMT must have no mappings for the interface
- return FALSE;
- }
- }
-
- pMT = pMT->GetParentMethodTable();
- _ASSERTE(pMT != NULL);
- }
- while (pMT != pParentMT);
-
- return TRUE;
-}
-
#endif // !DACCESS_COMPILE
//==========================================================================================
diff --git a/src/coreclr/vm/methodtable.h b/src/coreclr/vm/methodtable.h
index 3a7c7e0cf2b2f7..b72bfe7a7f0fda 100644
--- a/src/coreclr/vm/methodtable.h
+++ b/src/coreclr/vm/methodtable.h
@@ -2572,11 +2572,6 @@ class MethodTable
MethodTable *LookupDispatchMapType(DispatchMapTypeID typeID);
bool DispatchMapTypeMatchesMethodTable(DispatchMapTypeID typeID, MethodTable* pMT);
- // Determines whether all methods in the given interface have their final implementing
- // slot in a parent class. I.e. if this returns TRUE, it is trivial (no VSD lookup) to
- // dispatch pItfMT methods on this class if one knows how to dispatch them on pParentMT.
- BOOL ImplementsInterfaceWithSameSlotsAsParent(MethodTable *pItfMT, MethodTable *pParentMT);
-
// Try to resolve a given static virtual method override on this type. Return nullptr
// when not found.
MethodDesc *TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, ResolveVirtualStaticMethodFlags resolveVirtualStaticMethodFlags, ClassLoadLevel level);
diff --git a/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.cs b/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.cs
new file mode 100644
index 00000000000000..67b643361e7a05
--- /dev/null
+++ b/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.cs
@@ -0,0 +1,129 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Runtime.InteropServices;
+using Xunit;
+
+[ComVisible(true)]
+[Guid("A1111111-0000-0000-0000-000000000001")]
+public interface IFoo
+{
+ void DoWork();
+}
+
+[ComVisible(true)]
+[Guid("A1111111-0000-0000-0000-000000000002")]
+[ComDefaultInterface(typeof(IFoo))]
+public class Foo : IFoo
+{
+ public virtual void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(Foo);
+}
+
+[ComVisible(true)]
+[Guid("A1111111-0000-0000-0000-000000000003")]
+[ComDefaultInterface(typeof(IFoo))]
+public class FooDerived : Foo
+{
+ public override void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(FooDerived);
+}
+
+[ComVisible(true)]
+[Guid("B2222222-0000-0000-0000-000000000001")]
+public interface IBar
+{
+ void DoWork();
+}
+
+[ComVisible(true)]
+[Guid("B2222222-0000-0000-0000-000000000002")]
+[ComDefaultInterface(typeof(IBar))]
+public class Bar : IBar
+{
+ public virtual void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(Bar);
+}
+
+[ComVisible(true)]
+[Guid("B2222222-0000-0000-0000-000000000003")]
+[ComDefaultInterface(typeof(IBar))]
+public class BarDerived : Bar
+{
+ public override void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(BarDerived);
+}
+
+///
+/// Tests that COM-to-CLR dispatch correctly resolves virtual method overrides
+/// regardless of whether the base or derived class is accessed via COM first.
+///
+public class VirtualMethodOverrideTest
+{
+ internal static string? LastCalledType;
+
+ [UnmanagedFunctionPointer(CallingConvention.StdCall)]
+ delegate int DoWorkDelegate(IntPtr pThis);
+
+ private static int CallDoWork(IntPtr pInterface, int slot)
+ {
+ IntPtr vtbl = Marshal.ReadIntPtr(pInterface);
+ IntPtr fnPtr = Marshal.ReadIntPtr(vtbl, slot * IntPtr.Size);
+ Assert.NotEqual(IntPtr.Zero, fnPtr);
+
+ var fn = Marshal.GetDelegateForFunctionPointer(fnPtr);
+ return fn(pInterface);
+ }
+
+ [Fact]
+ public static void DerivedFirst()
+ {
+ int doWorkSlot = Marshal.GetStartComSlot(typeof(IFoo));
+ IntPtr pDerived = IntPtr.Zero;
+ IntPtr pBase = IntPtr.Zero;
+ try
+ {
+ pDerived = Marshal.GetComInterfaceForObject(new FooDerived(), typeof(IFoo));
+ pBase = Marshal.GetComInterfaceForObject(new Foo(), typeof(IFoo));
+
+ LastCalledType = null;
+ Assert.True(CallDoWork(pDerived, doWorkSlot) >= 0);
+ Assert.Equal(nameof(FooDerived), LastCalledType);
+
+ LastCalledType = null;
+ Assert.True(CallDoWork(pBase, doWorkSlot) >= 0);
+ Assert.Equal(nameof(Foo), LastCalledType);
+ }
+ finally
+ {
+ if (pDerived != IntPtr.Zero)
+ Marshal.Release(pDerived);
+
+ if (pBase != IntPtr.Zero)
+ Marshal.Release(pBase);
+ }
+ }
+
+ [Fact]
+ public static void BaseFirst()
+ {
+ int doWorkSlot = Marshal.GetStartComSlot(typeof(IBar));
+ IntPtr pBase = IntPtr.Zero;
+ IntPtr pDerived = IntPtr.Zero;
+ try
+ {
+ pBase = Marshal.GetComInterfaceForObject(new Bar(), typeof(IBar));
+ pDerived = Marshal.GetComInterfaceForObject(new BarDerived(), typeof(IBar));
+
+ LastCalledType = null;
+ Assert.True(CallDoWork(pBase, doWorkSlot) >= 0);
+ Assert.Equal(nameof(Bar), LastCalledType);
+
+ LastCalledType = null;
+ Assert.True(CallDoWork(pDerived, doWorkSlot) >= 0);
+ Assert.Equal(nameof(BarDerived), LastCalledType);
+ }
+ finally
+ {
+ if (pBase != IntPtr.Zero) Marshal.Release(pBase);
+ if (pDerived != IntPtr.Zero) Marshal.Release(pDerived);
+ }
+ }
+}
diff --git a/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.csproj b/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.csproj
new file mode 100644
index 00000000000000..412688de2f640c
--- /dev/null
+++ b/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.csproj
@@ -0,0 +1,10 @@
+
+
+ true
+ true
+ true
+
+
+
+
+