diff --git a/src/coreclr/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs b/src/coreclr/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs index 5ac78f35cdcee0..9e36c087c0bbb0 100644 --- a/src/coreclr/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs +++ b/src/coreclr/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs @@ -9,6 +9,7 @@ using System.Reflection; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; using System.Runtime.Loader; using System.Runtime.Versioning; @@ -706,9 +707,14 @@ internal sealed class LicenseInteropProxy private const string LicenseRefTypeName = "System.ComponentModel.License&, System.ComponentModel.TypeConverter"; private const string LicInfoHelperLicenseContextTypeName = "System.ComponentModel.LicenseManager+LicInfoHelperLicenseContext, System.ComponentModel.TypeConverter"; - // RCW Activation - private object? _licContext; - private Type? _targetRcwType; + private readonly object _licContext; + private readonly Type _targetRcwType; + + private LicenseInteropProxy(object licContext, Type targetRcwType) + { + _licContext = licContext; + _targetRcwType = targetRcwType; + } [UnsafeAccessor(UnsafeAccessorKind.Method)] private static extern void SetSavedLicenseKey( @@ -734,7 +740,7 @@ private static extern bool ValidateAndRetrieveLicenseDetails( [UnsafeAccessor(UnsafeAccessorKind.StaticMethod)] [return: UnsafeAccessorType(LicenseContextTypeName)] - private static extern object? GetCurrentContextInfo( + private static extern object GetCurrentContextInfo( [UnsafeAccessorType(LicenseInteropHelperTypeName)] object? licInteropHelper, Type type, out bool isDesignTime, @@ -762,12 +768,6 @@ private static extern bool Contains( [UnsafeAccessorType(LicInfoHelperLicenseContextTypeName)] object? licInfoHelperContext, string assemblyName); - // Helper function to create an object from the native side - public static object Create() - { - return new LicenseInteropProxy(); - } - // Determine if the type supports licensing public static bool HasLicense(Type type) { @@ -866,31 +866,42 @@ public static object AllocateAndValidateLicense([DynamicallyAccessedMembers(Dyna } } - // See usage in native RCW code - public void GetCurrentContextInfo(RuntimeTypeHandle rth, out bool isDesignTime, out IntPtr bstrKey) + [UnmanagedCallersOnly] + private static unsafe void GetCurrentContextInfoAndProxy(MethodTable* pMT, bool* pIsDesignTime, ushort** pBstrKey, object* pProxy, Exception* pException) { - Type targetRcwTypeMaybe = Type.GetTypeFromHandle(rth)!; - - _licContext = GetCurrentContextInfo(null, targetRcwTypeMaybe, out isDesignTime, out string? key); + try + { + RuntimeType targetRcwTypeMaybe = RuntimeTypeHandle.GetRuntimeType(pMT); + object licContext = GetCurrentContextInfo(null, targetRcwTypeMaybe, out *pIsDesignTime, out string? key); - _targetRcwType = targetRcwTypeMaybe; - bstrKey = Marshal.StringToBSTR((string)key!); + *pBstrKey = BStrStringMarshaller.ConvertToUnmanaged(key); + *pProxy = new LicenseInteropProxy(licContext, targetRcwTypeMaybe); + } + catch (Exception ex) + { + *pException = ex; + } } // The CLR invokes this when instantiating a licensed COM - // object inside a designtime license context. + // object inside a design-time license context. // It's purpose is to save away the license key that the CLR // retrieved using RequestLicKey(). - public void SaveKeyInCurrentContext(IntPtr bstrKey) + [UnmanagedCallersOnly] + private static unsafe void SaveKeyInCurrentContext(LicenseInteropProxy* pProxy, ushort* bstrKey, Exception* pException) { - if (bstrKey == IntPtr.Zero) + try { - return; + string? key = BStrStringMarshaller.ConvertToManaged(bstrKey); + if (key is not null) + { + SetSavedLicenseKey(pProxy->_licContext, pProxy->_targetRcwType, key); + } + } + catch (Exception ex) + { + *pException = ex; } - - string key = Marshal.PtrToStringBSTR(bstrKey); - - SetSavedLicenseKey(_licContext!, _targetRcwType!, key); } } } diff --git a/src/coreclr/inc/holder.h b/src/coreclr/inc/holder.h index 026173b372b3ce..a32c622710278f 100644 --- a/src/coreclr/inc/holder.h +++ b/src/coreclr/inc/holder.h @@ -1244,6 +1244,58 @@ class HKEYHolder }; #endif // HOST_WINDOWS +#ifdef FEATURE_COMINTEROP +class BSTRHolder final +{ + BSTR m_str; +public: + BSTRHolder() + : m_str{} + { + STATIC_CONTRACT_LEAF; + } + explicit BSTRHolder(BSTR str) + : m_str{ str } + { + STATIC_CONTRACT_LEAF; + } + ~BSTRHolder() noexcept + { + STATIC_CONTRACT_WRAPPER; + Free(); + } + + BSTRHolder(const BSTRHolder&) = delete; + BSTRHolder& operator=(const BSTRHolder&) = delete; + BSTRHolder(BSTRHolder&&) = delete; + BSTRHolder& operator=(BSTRHolder&&) = delete; + + void Free() + { + STATIC_CONTRACT_WRAPPER; + ::SysFreeString(m_str); + m_str = NULL; + } + + void Attach(BSTR str) + { + STATIC_CONTRACT_WRAPPER; + Free(); + _ASSERTE(m_str == NULL); + m_str = str; + } + + BSTR* operator&() + { + STATIC_CONTRACT_LEAF; + _ASSERTE(m_str == NULL); + return &m_str; + } + + operator BSTR() const { STATIC_CONTRACT_LEAF; return m_str; } +}; +#endif // FEATURE_COMINTEROP + //---------------------------------------------------------------------------- // // External data access does not want certain holder implementations diff --git a/src/coreclr/inc/utilcode.h b/src/coreclr/inc/utilcode.h index 5f433ef7010bea..d074869642da63 100644 --- a/src/coreclr/inc/utilcode.h +++ b/src/coreclr/inc/utilcode.h @@ -3453,12 +3453,6 @@ namespace util INDEBUG(BOOL DbgIsExecutable(LPVOID lpMem, SIZE_T length);) -#ifdef FEATURE_COMINTEROP -FORCEINLINE void HolderSysFreeString(BSTR str) { CONTRACT_VIOLATION(ThrowsViolation); SysFreeString(str); } - -typedef Wrapper BSTRHolder; -#endif - BOOL IsIPInModule(PTR_VOID pModuleBaseAddress, PCODE ip); namespace UtilCode diff --git a/src/coreclr/utilcode/comex.cpp b/src/coreclr/utilcode/comex.cpp index 83785fbf46b221..0e6321242323ac 100644 --- a/src/coreclr/utilcode/comex.cpp +++ b/src/coreclr/utilcode/comex.cpp @@ -44,7 +44,7 @@ void COMException::GetMessage(SString &string) if (m_pErrorInfo != NULL) { - BSTRHolder message(NULL); + BSTRHolder message; if (SUCCEEDED(m_pErrorInfo->GetDescription(&message))) string.Set(message, SysStringLen(message)); } diff --git a/src/coreclr/vm/corelib.h b/src/coreclr/vm/corelib.h index c0f18baa3aeb7b..605ebd168b148c 100644 --- a/src/coreclr/vm/corelib.h +++ b/src/coreclr/vm/corelib.h @@ -185,9 +185,8 @@ DEFINE_METHOD(COM_OBJECT, CTOR, .ctor, #endif // FOR_ILLINK DEFINE_CLASS(LICENSE_INTEROP_PROXY, InternalInteropServices, LicenseInteropProxy) -DEFINE_METHOD(LICENSE_INTEROP_PROXY, CREATE, Create, SM_RetObj) -DEFINE_METHOD(LICENSE_INTEROP_PROXY, GETCURRENTCONTEXTINFO, GetCurrentContextInfo, IM_RuntimeTypeHandle_RefBool_RefIntPtr_RetVoid) -DEFINE_METHOD(LICENSE_INTEROP_PROXY, SAVEKEYINCURRENTCONTEXT, SaveKeyInCurrentContext, IM_IntPtr_RetVoid) +DEFINE_METHOD(LICENSE_INTEROP_PROXY, GETCURRENTCONTEXTINFO_AND_PROXY, GetCurrentContextInfoAndProxy, NoSig) +DEFINE_METHOD(LICENSE_INTEROP_PROXY, SAVEKEYINCURRENTCONTEXT, SaveKeyInCurrentContext, NoSig) #endif // FEATURE_COMINTEROP END_ILLINK_FEATURE_SWITCH() diff --git a/src/coreclr/vm/interoputil.cpp b/src/coreclr/vm/interoputil.cpp index 7d54ffe15a3831..2bb31168ea1523 100644 --- a/src/coreclr/vm/interoputil.cpp +++ b/src/coreclr/vm/interoputil.cpp @@ -3323,7 +3323,7 @@ void IUInvokeDispMethod( // We managed to retrieve an IDispatchEx IP so we will use it to // retrieve the DISPID. - BSTRHolder bstrTmpName = SysAllocString(aNamesToConvert[0]); + BSTRHolder bstrTmpName{ SysAllocString(aNamesToConvert[0]) }; if (!bstrTmpName) COMPlusThrowOM(); diff --git a/src/coreclr/vm/metasig.h b/src/coreclr/vm/metasig.h index 2a53071e05d781..ebb56e70d3c583 100644 --- a/src/coreclr/vm/metasig.h +++ b/src/coreclr/vm/metasig.h @@ -287,15 +287,6 @@ DEFINE_METASIG(SM(Obj_RetVoid, j, v)) DEFINE_METASIG(SM(Obj_RetInt, j, i)) DEFINE_METASIG(SM(Obj_RetIntPtr, j, I)) -#ifdef FEATURE_COMINTEROP -DEFINE_METASIG_T(SM(Obj_Int_RefComVariant_RetVoid, j i r(g(COMVARIANT)), v)) -DEFINE_METASIG_T(SM(Obj_RefComVariant_RetVoid, j r(g(COMVARIANT)), v)) -DEFINE_METASIG_T(SM(RefComVariant_RetObject, r(g(COMVARIANT)), j)) -DEFINE_METASIG_T(IM(RuntimeTypeHandle_RefBool_RefIntPtr_RetVoid, g(RT_TYPE_HANDLE) r(F) r(I), v)) - -#endif - - DEFINE_METASIG(SM(Str_RetInt, s, i)) DEFINE_METASIG(SM(Int_Str_RetIntPtr, i s, I)) DEFINE_METASIG(SM(Int_Str_IntPtr_RetIntPtr, i s I, I)) @@ -313,7 +304,6 @@ DEFINE_METASIG_T(IM(RetModule, _, C(MODULE))) DEFINE_METASIG_T(IM(PtrNativeAssemblyNameParts, P(g(NATIVE_ASSEMBLY_NAME_PARTS)), v)) DEFINE_METASIG(SM(PtrCharPtrVoid, P(u) P(v), v)) DEFINE_METASIG(IM(RetObj, _, j)) -DEFINE_METASIG(SM(RetObj, _, j)) DEFINE_METASIG(IM(RetStr, _, s)) DEFINE_METASIG_T(IM(RetType, _, C(TYPE))) diff --git a/src/coreclr/vm/runtimecallablewrapper.cpp b/src/coreclr/vm/runtimecallablewrapper.cpp index 3993c5293f138b..6c028b2108c14c 100644 --- a/src/coreclr/vm/runtimecallablewrapper.cpp +++ b/src/coreclr/vm/runtimecallablewrapper.cpp @@ -99,7 +99,7 @@ IUnknown *ComClassFactory::CreateInstanceFromClassFactory(IClassFactory *pClassF HRESULT hr = S_OK; SafeComHolder pClassFact2 = NULL; SafeComHolder pUnk = NULL; - BSTRHolder bstrKey = NULL; + BSTRHolder bstrKey; // If the class doesn't support licensing or if it is missing a managed // type to use for querying a license, just use IClassFactory. @@ -121,40 +121,25 @@ IUnknown *ComClassFactory::CreateInstanceFromClassFactory(IClassFactory *pClassF } else { - _ASSERTE(m_pClassMT != NULL); - // Get the type to query for licensing. - TypeHandle rth = TypeHandle(m_pClassMT); + _ASSERTE(m_pClassMT != NULL); struct { OBJECTREF pProxy; - OBJECTREF pType; } gc; - gc.pProxy = NULL; // LicenseInteropProxy - gc.pType = NULL; - + gc.pProxy = NULL; GCPROTECT_BEGIN(gc); - // Create an instance of the object - MethodDescCallSite createObj(METHOD__LICENSE_INTEROP_PROXY__CREATE); - gc.pProxy = createObj.Call_RetOBJECTREF(NULL); - gc.pType = rth.GetManagedClassObject(); + // Create instance and query the current licensing context + UnmanagedCallersOnlyCaller getCurrentContextInfoAndProxy(METHOD__LICENSE_INTEROP_PROXY__GETCURRENTCONTEXTINFO_AND_PROXY); - // Query the current licensing context - MethodDescCallSite getCurrentContextInfo(METHOD__LICENSE_INTEROP_PROXY__GETCURRENTCONTEXTINFO, &gc.pProxy); CLR_BOOL fDesignTime = FALSE; - ARG_SLOT args[4]; - args[0] = ObjToArgSlot(gc.pProxy); - args[1] = ObjToArgSlot(gc.pType); - args[2] = (ARG_SLOT)&fDesignTime; - args[3] = (ARG_SLOT)(BSTR*)&bstrKey; - - getCurrentContextInfo.Call(args); + getCurrentContextInfoAndProxy.InvokeThrowing(m_pClassMT, &fDesignTime, &bstrKey, &gc.pProxy); if (fDesignTime) { - // If designtime, we're supposed to obtain the runtime license key + // If design-time, we're supposed to obtain the runtime license key // from the component and save it away in the license context. // (the design tool can then grab it and embedded it into the // app it is creating) @@ -163,9 +148,7 @@ IUnknown *ComClassFactory::CreateInstanceFromClassFactory(IClassFactory *pClassF // It's illegal for our helper to return a non-null bstrKey // when the context is design-time. But we'll try to do the // right thing anyway. - _ASSERTE(!"We're not supposed to get here, but we'll try to cope anyway."); - SysFreeString(bstrKey); - bstrKey = NULL; + bstrKey.Free(); } { @@ -181,11 +164,8 @@ IUnknown *ComClassFactory::CreateInstanceFromClassFactory(IClassFactory *pClassF // Store the requested license key if (SUCCEEDED(hr)) { - MethodDescCallSite saveKeyInCurrentContext(METHOD__LICENSE_INTEROP_PROXY__SAVEKEYINCURRENTCONTEXT, &gc.pProxy); - - args[0] = ObjToArgSlot(gc.pProxy); - args[1] = (ARG_SLOT)(BSTR)bstrKey; - saveKeyInCurrentContext.Call(args); + UnmanagedCallersOnlyCaller saveKeyInCurrentContext(METHOD__LICENSE_INTEROP_PROXY__SAVEKEYINCURRENTCONTEXT); + saveKeyInCurrentContext.InvokeThrowing(&gc.pProxy, (BSTR)bstrKey); } } @@ -210,7 +190,6 @@ IUnknown *ComClassFactory::CreateInstanceFromClassFactory(IClassFactory *pClassF else { // It is runtime and we have a license key. - _ASSERTE(bstrKey != NULL); hr = pClassFact2->CreateInstanceLic(punkOuter, NULL, IID_IUnknown, bstrKey, (void**)&pUnk); if (FAILED(hr) && punkOuter) { diff --git a/src/tests/Interop/COM/NETClients/Licensing/Program.cs b/src/tests/Interop/COM/NETClients/Licensing/Program.cs index e8f21273b4f76d..13bce04d9eff9b 100644 --- a/src/tests/Interop/COM/NETClients/Licensing/Program.cs +++ b/src/tests/Interop/COM/NETClients/Licensing/Program.cs @@ -75,9 +75,9 @@ public override void SetSavedLicenseKey(Type type, string key) } } - static void ActivateUnderDesigntimeContext() + static void ActivateUnderDesignTimeContext() { - Console.WriteLine($"Calling {nameof(ActivateUnderDesigntimeContext)}..."); + Console.WriteLine($"Calling {nameof(ActivateUnderDesignTimeContext)}..."); LicenseContext prev = LicenseManager.CurrentContext; try @@ -134,7 +134,7 @@ public static int TestEntryPoint() try { ActivateLicensedObject(); - ActivateUnderDesigntimeContext(); + ActivateUnderDesignTimeContext(); ActivateUnderRuntimeContext(); } catch (Exception e) diff --git a/src/tests/Interop/COM/NativeServer/LicenseTesting.h b/src/tests/Interop/COM/NativeServer/LicenseTesting.h index a67676c38c7f25..013089f3fe2323 100644 --- a/src/tests/Interop/COM/NativeServer/LicenseTesting.h +++ b/src/tests/Interop/COM/NativeServer/LicenseTesting.h @@ -28,10 +28,13 @@ class LicenseTesting : public UnknownImpl, public ILicenseTesting public: LicenseTesting(_In_opt_ BSTR lic) - : _lic{ lic } + : _lic{ TP_SysAllocString(lic) } { if (s_DenyLicense) + { + CoreClrBStrFree(_lic); throw CLASS_E_NOTLICENSED; + } } ~LicenseTesting()