diff --git a/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.IList1.cs b/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.IList1.cs index 84a9ce8ff..e3620df85 100644 --- a/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.IList1.cs +++ b/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.IList1.cs @@ -982,6 +982,92 @@ public static void ImplType( ModuleDefinition module, out TypeDefinition implType) { + TypeSignature elementType = listType.TypeArguments[0]; + + // Define the 'GetAt' method + MethodDefinition getAtMethod = InteropMethodDefinitionFactory.IReadOnlyList1Impl.GetAt( + readOnlyListType: listType, + getAtMethod: interopReferences.IListAdapter1GetAt(elementType), + interopReferences: interopReferences, + emitState: emitState, + module: module); + + // Define the 'get_Size' method + MethodDefinition sizeMethod = InteropMethodDefinitionFactory.IReadOnlyList1Impl.get_Size( + readOnlyListType: listType, + sizeMethod: interopReferences.IListAdapter1Size(elementType), + interopReferences: interopReferences, + module: module); + + // Define the 'GetView' method + MethodDefinition getViewMethod = InteropMethodDefinitionFactory.IList1Impl.GetView( + listType: listType, + interopReferences: interopReferences, + emitState: emitState, + module: module); + + // Define the 'IndexOf' method + MethodDefinition indexOfMethod = InteropMethodDefinitionFactory.IReadOnlyList1Impl.IndexOf( + readOnlyListType: listType, + indexOfMethod: interopReferences.IListAdapter1IndexOf(elementType), + interopReferences: interopReferences, + emitState: emitState, + module: module); + + // Define the 'SetAt' method + MethodDefinition setAtMethod = InteropMethodDefinitionFactory.IList1Impl.SetAt( + listType: listType, + interopReferences: interopReferences, + emitState: emitState, + module: module); + + // Define the 'InsertAt' method + MethodDefinition insertAtMethod = InteropMethodDefinitionFactory.IList1Impl.InsertAt( + listType: listType, + interopReferences: interopReferences, + emitState: emitState, + module: module); + + // Define the 'RemoveAt' method + MethodDefinition removeAtMethod = InteropMethodDefinitionFactory.IList1Impl.RemoveAt( + listType: listType, + interopReferences: interopReferences, + module: module); + + // Define the 'Append' method + MethodDefinition appendMethod = InteropMethodDefinitionFactory.IList1Impl.Append( + listType: listType, + interopReferences: interopReferences, + emitState: emitState, + module: module); + + // Define the 'RemoveAtEnd' method + MethodDefinition removeAtEndMethod = InteropMethodDefinitionFactory.IList1Impl.RemoveAtEnd( + listType: listType, + interopReferences: interopReferences, + module: module); + + // Define the 'Clear' method + MethodDefinition clearMethod = InteropMethodDefinitionFactory.IList1Impl.Clear( + listType: listType, + interopReferences: interopReferences, + module: module); + + // Define the 'GetMany' method + MethodDefinition getManyMethod = InteropMethodDefinitionFactory.IReadOnlyList1Impl.GetMany( + readOnlyListType: listType, + getAtMethod: interopReferences.IListAdapter1GetAt(elementType), + interopReferences: interopReferences, + emitState: emitState, + module: module); + + // Define the 'ReplaceAll' method + MethodDefinition replaceAllMethod = InteropMethodDefinitionFactory.IList1Impl.ReplaceAll( + listType: listType, + interopReferences: interopReferences, + emitState: emitState, + module: module); + Impl( interfaceType: ComInterfaceType.InterfaceIsIInspectable, ns: InteropUtf8NameFactory.TypeNamespace(listType), @@ -991,7 +1077,19 @@ public static void ImplType( interopReferences: interopReferences, module: module, implType: out implType, - vtableMethods: []); + vtableMethods: [ + getAtMethod, + sizeMethod, + getViewMethod, + indexOfMethod, + setAtMethod, + insertAtMethod, + removeAtMethod, + appendMethod, + removeAtEndMethod, + clearMethod, + getManyMethod, + replaceAllMethod]); // Track the type (it may be needed by COM interface entries for user-defined types) emitState.TrackTypeDefinition(implType, listType, "Impl"); diff --git a/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.IReadOnlyList1.cs b/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.IReadOnlyList1.cs index ca474b1ef..13898edb1 100644 --- a/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.IReadOnlyList1.cs +++ b/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.IReadOnlyList1.cs @@ -426,9 +426,12 @@ public static void ImplType( ModuleDefinition module, out TypeDefinition implType) { + TypeSignature elementType = readOnlyListType.TypeArguments[0]; + // Define the 'GetAt' method MethodDefinition getAtMethod = InteropMethodDefinitionFactory.IReadOnlyList1Impl.GetAt( readOnlyListType: readOnlyListType, + getAtMethod: interopReferences.IReadOnlyListAdapter1GetAt(elementType), interopReferences: interopReferences, emitState: emitState, module: module); @@ -436,12 +439,14 @@ public static void ImplType( // Define the 'get_Size' method MethodDefinition sizeMethod = InteropMethodDefinitionFactory.IReadOnlyList1Impl.get_Size( readOnlyListType: readOnlyListType, + sizeMethod: interopReferences.IReadOnlyListAdapter1Size(elementType), interopReferences: interopReferences, module: module); // Define the 'IndexOf' method MethodDefinition indexOfMethod = InteropMethodDefinitionFactory.IReadOnlyList1Impl.IndexOf( readOnlyListType: readOnlyListType, + indexOfMethod: interopReferences.IReadOnlyListAdapter1IndexOf(elementType), interopReferences: interopReferences, emitState: emitState, module: module); @@ -449,6 +454,7 @@ public static void ImplType( // Define the 'GetMany' method MethodDefinition getManyMethod = InteropMethodDefinitionFactory.IReadOnlyList1Impl.GetMany( readOnlyListType: readOnlyListType, + getAtMethod: interopReferences.IReadOnlyListAdapter1GetAt(elementType), interopReferences: interopReferences, emitState: emitState, module: module); diff --git a/src/WinRT.Interop.Generator/Discovery/InteropTypeDiscovery.Generics.cs b/src/WinRT.Interop.Generator/Discovery/InteropTypeDiscovery.Generics.cs index 6b8294d01..ba291ee87 100644 --- a/src/WinRT.Interop.Generator/Discovery/InteropTypeDiscovery.Generics.cs +++ b/src/WinRT.Interop.Generator/Discovery/InteropTypeDiscovery.Generics.cs @@ -226,6 +226,20 @@ private static void TryTrackWindowsRuntimeGenericInterfaceTypeInstance( else if (SignatureComparer.IgnoreVersion.Equals(typeSignature.GenericType, interopReferences.IList1)) { discoveryState.TrackIList1Type(typeSignature); + + // Whenever we find an 'IList' instantiation, we also need to track the corresponding 'IReadOnlyList' instantiation. + // This is because that interface is needed to marshal the return value of the 'IVector.GetView' method ('IVectorView'). + discoveryState.TrackIReadOnlyList1Type(interopReferences.IReadOnlyList1.MakeGenericReferenceType([.. typeSignature.TypeArguments])); + + // We also need to track the constructed 'ReadOnlyCollection' type, as that is used by 'IListAdapter.GetView' in case the + // input 'IList' instance doesn't implement 'IReadOnlyList' directly. In that case, we return a 'ReadOnlyCollection' + // object instead. This needs special handling because we won't analyze indirect (generated) calls into that adapter type. + TryTrackGenericTypeInstance( + typeSignature: interopReferences.ReadOnlyCollection1.MakeGenericReferenceType([.. typeSignature.TypeArguments]), + args: args, + discoveryState: discoveryState, + interopReferences: interopReferences, + module: module); } else if (SignatureComparer.IgnoreVersion.Equals(typeSignature.GenericType, interopReferences.IReadOnlyList1)) { diff --git a/src/WinRT.Interop.Generator/Discovery/InteropTypeDiscovery.cs b/src/WinRT.Interop.Generator/Discovery/InteropTypeDiscovery.cs index 22d45aae9..6546c8fac 100644 --- a/src/WinRT.Interop.Generator/Discovery/InteropTypeDiscovery.cs +++ b/src/WinRT.Interop.Generator/Discovery/InteropTypeDiscovery.cs @@ -109,6 +109,24 @@ public static void TryTrackExposedUserDefinedType( return; } + // Check if this is the first time that this user-defined type has been seen, and stop immediately if not. + // If the type has been seen before, it means that it either has already been fully processed, or that it + // is currently being processed (possibly by another thread, if multi-threading discovery is enabled). The + // reason for this check is not so much to improve performance (although it does avoid some repeated work), + // but most importantly to avoid stack overlows due to infinite recursion in cases where user-defined types + // implement interfaces that then transitively required the same user-defined type to be tracked. + // + // For instance, consider a scenario where 'List' is being discovered. While processing the implemented + // interfaces, 'IList' will also be discovered. This will then require 'ReadOnlyCollection' to be + // tracked, because it is used by the fallback code for the CCW implementation method of 'IVector.GetView'. + // However, 'ReadOnlyCollection' itself will also implement 'IList', which would then require tracking + // 'ReadOnlyCollection' itself again too, etc. That would just recurse forever without this check, because + // all those interfaces would keep being discovered before the initial processing of the type has finished. + if (!discoveryState.TryMarkUserDefinedType(typeSignature)) + { + return; + } + // Reuse the thread-local builder to track all implemented interfaces for the current type TypeSignatureEquatableSet.Builder interfaces = TypeSignatures ??= new TypeSignatureEquatableSet.Builder(); diff --git a/src/WinRT.Interop.Generator/Extensions/CilInstructionExtensions.cs b/src/WinRT.Interop.Generator/Extensions/CilInstructionExtensions.cs index 140b96d6c..e8fcc6ef4 100644 --- a/src/WinRT.Interop.Generator/Extensions/CilInstructionExtensions.cs +++ b/src/WinRT.Interop.Generator/Extensions/CilInstructionExtensions.cs @@ -98,8 +98,38 @@ public static CilInstruction CreateStind(TypeSignature type, ModuleDefinition mo ElementType.R4 => new CilInstruction(Stind_R4), ElementType.R8 => new CilInstruction(Stind_R8), ElementType.ValueType when type.Resolve() is { IsClass: true, IsEnum: true } => new CilInstruction(Stind_I4), + ElementType.I => new CilInstruction(Stind_I), _ => new CilInstruction(Stobj, type.Import(module).ToTypeDefOrRef()), }; } + + /// + /// Create a new instruction loading a value indirectly from a target location. + /// + /// The type of value to load. + /// The in use. + /// The instruction. + [SuppressMessage("Style", "IDE0072", Justification = "We use 'ldobj' for all other possible types.")] + public static CilInstruction CreateLdind(TypeSignature type, ModuleDefinition module) + { + return type.ElementType switch + { + ElementType.Boolean => new CilInstruction(Ldind_I1), + ElementType.Char => new CilInstruction(Ldind_I2), + ElementType.I1 => new CilInstruction(Ldind_I1), + ElementType.U1 => new CilInstruction(Ldind_I1), + ElementType.I2 => new CilInstruction(Ldind_I2), + ElementType.U2 => new CilInstruction(Ldind_I2), + ElementType.I4 => new CilInstruction(Ldind_I4), + ElementType.U4 => new CilInstruction(Ldind_I4), + ElementType.I8 => new CilInstruction(Ldind_I8), + ElementType.U8 => new CilInstruction(Ldind_I8), + ElementType.R4 => new CilInstruction(Ldind_R4), + ElementType.R8 => new CilInstruction(Ldind_R8), + ElementType.ValueType when type.Resolve() is { IsClass: true, IsEnum: true } => new CilInstruction(Ldind_I4), + ElementType.I => new CilInstruction(Ldind_I), + _ => new CilInstruction(Ldobj, type.Import(module).ToTypeDefOrRef()), + }; + } } } \ No newline at end of file diff --git a/src/WinRT.Interop.Generator/Factories/InteropMethodDefinitionFactory.IEnumerator1Impl.cs b/src/WinRT.Interop.Generator/Factories/InteropMethodDefinitionFactory.IEnumerator1Impl.cs index 12462937c..c1ddf2a59 100644 --- a/src/WinRT.Interop.Generator/Factories/InteropMethodDefinitionFactory.IEnumerator1Impl.cs +++ b/src/WinRT.Interop.Generator/Factories/InteropMethodDefinitionFactory.IEnumerator1Impl.cs @@ -269,7 +269,7 @@ public static MethodDefinition GetMany( } /// - /// Creates a for the get_Current export method. + /// Creates a for the get_Current or MoveNext export method. /// /// The name of the method to generate. /// The adapter method to forward the call to. @@ -307,7 +307,7 @@ private static MethodDefinition HasCurrentOrMoveNext( CilInstruction ldloc_0_returnHResult = new(Ldloc_0); CilInstruction call_catchStartMarshalException = new(Call, interopReferences.RestrictedErrorInfoExceptionMarshallerConvertToUnmanaged.Import(module)); - // Create a method body for the 'get_HasCurrent' method + // Create a method body for the method boolMethod.CilMethodBody = new CilMethodBody() { // Declare 1 variable: diff --git a/src/WinRT.Interop.Generator/Factories/InteropMethodDefinitionFactory.IList1Impl.cs b/src/WinRT.Interop.Generator/Factories/InteropMethodDefinitionFactory.IList1Impl.cs new file mode 100644 index 000000000..fc7ee5629 --- /dev/null +++ b/src/WinRT.Interop.Generator/Factories/InteropMethodDefinitionFactory.IList1Impl.cs @@ -0,0 +1,762 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using AsmResolver; +using AsmResolver.DotNet; +using AsmResolver.DotNet.Code.Cil; +using AsmResolver.DotNet.Signatures; +using AsmResolver.PE.DotNet.Cil; +using AsmResolver.PE.DotNet.Metadata.Tables; +using WindowsRuntime.InteropGenerator.Generation; +using WindowsRuntime.InteropGenerator.References; +using static AsmResolver.PE.DotNet.Cil.CilOpCodes; + +#pragma warning disable IDE1006 + +namespace WindowsRuntime.InteropGenerator.Factories; + +/// +/// A factory for interop method definitions. +/// +internal static partial class InteropMethodDefinitionFactory +{ + /// + /// Helpers for impl types for interfaces. + /// + public static class IList1Impl + { + /// + /// Creates a for the GetView export method. + /// + /// The for the type. + /// The instance to use. + /// The emit state for this invocation. + /// The interop module being built. + public static MethodDefinition GetView( + GenericInstanceTypeSignature listType, + InteropReferences interopReferences, + InteropGeneratorEmitState emitState, + ModuleDefinition module) + { + TypeSignature elementType = listType.TypeArguments[0]; + + // Define the 'GetView' method as follows: + // + // [UnmanagedCallersOnly(CallConvs = [typeof(CallConvMemberFunction)])] + // private static int GetView(void* thisPtr, void** result) + MethodDefinition getViewMethod = new( + name: "GetView"u8, + attributes: MethodAttributes.Private | MethodAttributes.HideBySig | MethodAttributes.Static, + signature: MethodSignature.CreateStatic( + returnType: module.CorLibTypeFactory.Int32, + parameterTypes: [ + module.CorLibTypeFactory.Void.MakePointerType(), + module.CorLibTypeFactory.Void.MakePointerType().MakePointerType()])) + { + CustomAttributes = { InteropCustomAttributeFactory.UnmanagedCallersOnly(interopReferences, module) } + }; + + // Declare the local variables: + // [0]: '' (for 'thisObject') + // [1]: 'int' (the 'HRESULT' to return) + CilLocalVariable loc_0_thisObject = new(listType.Import(module)); + CilLocalVariable loc_1_hresult = new(module.CorLibTypeFactory.Int32); + + // Labels for jumps + CilInstruction nop_beforeTry = new(Nop); + CilInstruction ldarg_0_tryStart = new(Ldarg_0); + CilInstruction ldloc_1_returnHResult = new(Ldloc_1); + CilInstruction call_catchStartMarshalException = new(Call, interopReferences.RestrictedErrorInfoExceptionMarshallerConvertToUnmanaged.Import(module)); + CilInstruction nop_convertToUnmanaged = new(Nop); + + // Create a method body for the 'GetView' method + getViewMethod.CilMethodBody = new CilMethodBody() + { + LocalVariables = { loc_0_thisObject, loc_1_hresult }, + Instructions = + { + // Return 'E_POINTER' if the argument is 'null' + { Ldarg_1 }, + { Ldc_I4_0 }, + { Conv_U }, + { Bne_Un_S, nop_beforeTry.CreateLabel() }, + { Ldc_I4, unchecked((int)0x80004003) }, + { Ret }, + { nop_beforeTry }, + + // '.try' code + { ldarg_0_tryStart }, + { Call, interopReferences.ComInterfaceDispatchGetInstance.MakeGenericInstanceMethod(listType).Import(module) }, + { Stloc_0 }, + { Ldarg_1 }, + { Ldloc_0 }, + { Call, interopReferences.IListAdapter1GetView(elementType).Import(module) }, + { nop_convertToUnmanaged }, + { Ldc_I4_0 }, + { Stloc_1 }, + { Leave_S, ldloc_1_returnHResult.CreateLabel() }, + + // '.catch' code + { call_catchStartMarshalException }, + { Stloc_1 }, + { Leave_S, ldloc_1_returnHResult.CreateLabel() }, + + // Return the 'HRESULT' from location [1] + { ldloc_1_returnHResult }, + { Ret } + }, + ExceptionHandlers = + { + new CilExceptionHandler + { + HandlerType = CilExceptionHandlerType.Exception, + TryStart = ldarg_0_tryStart.CreateLabel(), + TryEnd = call_catchStartMarshalException.CreateLabel(), + HandlerStart = call_catchStartMarshalException.CreateLabel(), + HandlerEnd = ldloc_1_returnHResult.CreateLabel(), + ExceptionType = interopReferences.Exception.Import(module) + } + } + }; + + // Track the method for rewrite to marshal the result value + emitState.TrackRetValValueMethodRewrite( + retValType: interopReferences.IReadOnlyList1.MakeGenericReferenceType(elementType), + method: getViewMethod, + marker: nop_convertToUnmanaged); + + return getViewMethod; + } + + /// + /// Creates a for the SetAt export method. + /// + /// The for the type. + /// The instance to use. + /// The emit state for this invocation. + /// The interop module being built. + public static MethodDefinition SetAt( + GenericInstanceTypeSignature listType, + InteropReferences interopReferences, + InteropGeneratorEmitState emitState, + ModuleDefinition module) + { + TypeSignature elementType = listType.TypeArguments[0]; + + return SetAtOrInsertAt( + methodName: "SetAt"u8, + adapterMethod: interopReferences.IListAdapter1SetAt(elementType), + listType: listType, + interopReferences: interopReferences, + emitState: emitState, + module: module); + } + + /// + /// Creates a for the InsertAt export method. + /// + /// The for the type. + /// The instance to use. + /// The emit state for this invocation. + /// The interop module being built. + public static MethodDefinition InsertAt( + GenericInstanceTypeSignature listType, + InteropReferences interopReferences, + InteropGeneratorEmitState emitState, + ModuleDefinition module) + { + TypeSignature elementType = listType.TypeArguments[0]; + + return SetAtOrInsertAt( + methodName: "InsertAt"u8, + adapterMethod: interopReferences.IListAdapter1InsertAt(elementType), + listType: listType, + interopReferences: interopReferences, + emitState: emitState, + module: module); + } + + /// + /// Creates a for the RemoveAt export method. + /// + /// The for the type. + /// The instance to use. + /// The interop module being built. + public static MethodDefinition RemoveAt( + GenericInstanceTypeSignature listType, + InteropReferences interopReferences, + ModuleDefinition module) + { + TypeSignature elementType = listType.TypeArguments[0]; + + // Define the 'RemoveAt' method as follows: + // + // [UnmanagedCallersOnly(CallConvs = [typeof(CallConvMemberFunction)])] + // private static int RemoveAt(void* thisPtr, uint index) + MethodDefinition removeAtMethod = new( + name: "RemoveAt"u8, + attributes: MethodAttributes.Private | MethodAttributes.HideBySig | MethodAttributes.Static, + signature: MethodSignature.CreateStatic( + returnType: module.CorLibTypeFactory.Int32, + parameterTypes: [ + module.CorLibTypeFactory.Void.MakePointerType(), + module.CorLibTypeFactory.UInt32])) + { + CustomAttributes = { InteropCustomAttributeFactory.UnmanagedCallersOnly(interopReferences, module) } + }; + + // Declare the local variables: + // [0]: '' (for 'thisObject') + // [1]: 'int' (the 'HRESULT' to return) + CilLocalVariable loc_0_thisObject = new(listType.Import(module)); + CilLocalVariable loc_1_hresult = new(module.CorLibTypeFactory.Int32); + + // Labels for jumps + CilInstruction ldarg_0_tryStart = new(Ldarg_0); + CilInstruction ldloc_1_returnHResult = new(Ldloc_1); + CilInstruction call_catchStartMarshalException = new(Call, interopReferences.RestrictedErrorInfoExceptionMarshallerConvertToUnmanaged.Import(module)); + + // Create a method body for the 'RemoveAt' method + removeAtMethod.CilMethodBody = new CilMethodBody() + { + LocalVariables = { loc_0_thisObject, loc_1_hresult }, + Instructions = + { + // '.try' code + { ldarg_0_tryStart }, + { Call, interopReferences.ComInterfaceDispatchGetInstance.MakeGenericInstanceMethod(listType).Import(module) }, + { Stloc_0 }, + { Ldloc_0 }, + { Ldarg_1 }, + { Call, interopReferences.IListAdapter1RemoveAt(elementType).Import(module) }, + { Ldc_I4_0 }, + { Stloc_1 }, + { Leave_S, ldloc_1_returnHResult.CreateLabel() }, + + // '.catch' code + { call_catchStartMarshalException }, + { Stloc_1 }, + { Leave_S, ldloc_1_returnHResult.CreateLabel() }, + + // Return the 'HRESULT' from location [1] + { ldloc_1_returnHResult }, + { Ret } + }, + ExceptionHandlers = + { + new CilExceptionHandler + { + HandlerType = CilExceptionHandlerType.Exception, + TryStart = ldarg_0_tryStart.CreateLabel(), + TryEnd = call_catchStartMarshalException.CreateLabel(), + HandlerStart = call_catchStartMarshalException.CreateLabel(), + HandlerEnd = ldloc_1_returnHResult.CreateLabel(), + ExceptionType = interopReferences.Exception.Import(module) + } + } + }; + + return removeAtMethod; + } + + /// + /// Creates a for the Append export method. + /// + /// The for the type. + /// The instance to use. + /// The emit state for this invocation. + /// The interop module being built. + public static MethodDefinition Append( + GenericInstanceTypeSignature listType, + InteropReferences interopReferences, + InteropGeneratorEmitState emitState, + ModuleDefinition module) + { + TypeSignature elementType = listType.TypeArguments[0]; + + // Define the 'Append' method as follows: + // + // [UnmanagedCallersOnly(CallConvs = [typeof(CallConvMemberFunction)])] + // private static int Append(void* thisPtr, value) + MethodDefinition appendMethod = new( + name: "Append"u8, + attributes: MethodAttributes.Private | MethodAttributes.HideBySig | MethodAttributes.Static, + signature: MethodSignature.CreateStatic( + returnType: module.CorLibTypeFactory.Int32, + parameterTypes: [ + module.CorLibTypeFactory.Void.MakePointerType(), + elementType.GetAbiType(interopReferences).Import(module)])) + { + CustomAttributes = { InteropCustomAttributeFactory.UnmanagedCallersOnly(interopReferences, module) } + }; + + // Declare the local variables: + // [0]: '' (for 'thisObject') + // [1]: 'int' (the 'HRESULT' to return) + CilLocalVariable loc_0_thisObject = new(listType.Import(module)); + CilLocalVariable loc_1_hresult = new(module.CorLibTypeFactory.Int32); + + // Labels for jumps + CilInstruction ldarg_0_tryStart = new(Ldarg_0); + CilInstruction nop_parameter1Rewrite = new(Nop); + CilInstruction ldloc_1_returnHResult = new(Ldloc_1); + CilInstruction call_catchStartMarshalException = new(Call, interopReferences.RestrictedErrorInfoExceptionMarshallerConvertToUnmanaged.Import(module)); + + // Create a method body for the 'Append' method + appendMethod.CilMethodBody = new CilMethodBody() + { + LocalVariables = { loc_0_thisObject, loc_1_hresult }, + Instructions = + { + // '.try' code + { ldarg_0_tryStart }, + { Call, interopReferences.ComInterfaceDispatchGetInstance.MakeGenericInstanceMethod(listType).Import(module) }, + { Stloc_0 }, + { Ldloc_0 }, + { nop_parameter1Rewrite }, + { Callvirt, interopReferences.ICollection1Add(elementType).Import(module) }, + { Ldc_I4_0 }, + { Stloc_1 }, + { Leave_S, ldloc_1_returnHResult.CreateLabel() }, + + // '.catch' code + { call_catchStartMarshalException }, + { Stloc_1 }, + { Leave_S, ldloc_1_returnHResult.CreateLabel() }, + + // Return the 'HRESULT' from location [1] + { ldloc_1_returnHResult }, + { Ret } + }, + ExceptionHandlers = + { + new CilExceptionHandler + { + HandlerType = CilExceptionHandlerType.Exception, + TryStart = ldarg_0_tryStart.CreateLabel(), + TryEnd = call_catchStartMarshalException.CreateLabel(), + HandlerStart = call_catchStartMarshalException.CreateLabel(), + HandlerEnd = ldloc_1_returnHResult.CreateLabel(), + ExceptionType = interopReferences.Exception.Import(module) + } + } + }; + + // Track rewriting the parameter for this method + emitState.TrackManagedParameterMethodRewrite( + parameterType: elementType, + method: appendMethod, + marker: nop_parameter1Rewrite, + parameterIndex: 1); + + return appendMethod; + } + + /// + /// Creates a for the RemoveAtEnd export method. + /// + /// The for the type. + /// The instance to use. + /// The interop module being built. + public static MethodDefinition RemoveAtEnd( + GenericInstanceTypeSignature listType, + InteropReferences interopReferences, + ModuleDefinition module) + { + TypeSignature elementType = listType.TypeArguments[0]; + + // Define the 'RemoveAtEnd' method as follows: + // + // [UnmanagedCallersOnly(CallConvs = [typeof(CallConvMemberFunction)])] + // private static int RemoveAtEnd(void* thisPtr) + MethodDefinition removeAtEndMethod = new( + name: "RemoveAtEnd"u8, + attributes: MethodAttributes.Private | MethodAttributes.HideBySig | MethodAttributes.Static, + signature: MethodSignature.CreateStatic( + returnType: module.CorLibTypeFactory.Int32, + parameterTypes: [module.CorLibTypeFactory.Void.MakePointerType()])) + { + CustomAttributes = { InteropCustomAttributeFactory.UnmanagedCallersOnly(interopReferences, module) } + }; + + // Declare the local variables: + // [0]: '' (for 'thisObject') + // [1]: 'int' (the 'HRESULT' to return) + CilLocalVariable loc_0_thisObject = new(listType.Import(module)); + CilLocalVariable loc_1_hresult = new(module.CorLibTypeFactory.Int32); + + // Labels for jumps + CilInstruction ldarg_0_tryStart = new(Ldarg_0); + CilInstruction ldloc_1_returnHResult = new(Ldloc_1); + CilInstruction call_catchStartMarshalException = new(Call, interopReferences.RestrictedErrorInfoExceptionMarshallerConvertToUnmanaged.Import(module)); + + // Create a method body for the 'RemoveAtEnd' method + removeAtEndMethod.CilMethodBody = new CilMethodBody() + { + LocalVariables = { loc_0_thisObject, loc_1_hresult }, + Instructions = + { + // '.try' code + { ldarg_0_tryStart }, + { Call, interopReferences.ComInterfaceDispatchGetInstance.MakeGenericInstanceMethod(listType).Import(module) }, + { Stloc_0 }, + { Ldloc_0 }, + { Call, interopReferences.IListAdapter1RemoveAtEnd(elementType).Import(module) }, + { Ldc_I4_0 }, + { Stloc_1 }, + { Leave_S, ldloc_1_returnHResult.CreateLabel() }, + + // '.catch' code + { call_catchStartMarshalException }, + { Stloc_1 }, + { Leave_S, ldloc_1_returnHResult.CreateLabel() }, + + // Return the 'HRESULT' from location [1] + { ldloc_1_returnHResult }, + { Ret } + }, + ExceptionHandlers = + { + new CilExceptionHandler + { + HandlerType = CilExceptionHandlerType.Exception, + TryStart = ldarg_0_tryStart.CreateLabel(), + TryEnd = call_catchStartMarshalException.CreateLabel(), + HandlerStart = call_catchStartMarshalException.CreateLabel(), + HandlerEnd = ldloc_1_returnHResult.CreateLabel(), + ExceptionType = interopReferences.Exception.Import(module) + } + } + }; + + return removeAtEndMethod; + } + + /// + /// Creates a for the Clear export method. + /// + /// The for the type. + /// The instance to use. + /// The interop module being built. + public static MethodDefinition Clear( + GenericInstanceTypeSignature listType, + InteropReferences interopReferences, + ModuleDefinition module) + { + TypeSignature elementType = listType.TypeArguments[0]; + + // Define the 'Clear' method as follows: + // + // [UnmanagedCallersOnly(CallConvs = [typeof(CallConvMemberFunction)])] + // private static int Clear(void* thisPtr) + MethodDefinition clearMethod = new( + name: "Clear"u8, + attributes: MethodAttributes.Private | MethodAttributes.HideBySig | MethodAttributes.Static, + signature: MethodSignature.CreateStatic( + returnType: module.CorLibTypeFactory.Int32, + parameterTypes: [module.CorLibTypeFactory.Void.MakePointerType()])) + { + CustomAttributes = { InteropCustomAttributeFactory.UnmanagedCallersOnly(interopReferences, module) } + }; + + // Declare the local variables: + // [0]: '' (for 'thisObject') + // [1]: 'int' (the 'HRESULT' to return) + CilLocalVariable loc_0_thisObject = new(listType.Import(module)); + CilLocalVariable loc_1_hresult = new(module.CorLibTypeFactory.Int32); + + // Labels for jumps + CilInstruction ldarg_0_tryStart = new(Ldarg_0); + CilInstruction ldloc_1_returnHResult = new(Ldloc_1); + CilInstruction call_catchStartMarshalException = new(Call, interopReferences.RestrictedErrorInfoExceptionMarshallerConvertToUnmanaged.Import(module)); + + // Create a method body for the 'Clear' method + clearMethod.CilMethodBody = new CilMethodBody() + { + LocalVariables = { loc_0_thisObject, loc_1_hresult }, + Instructions = + { + // '.try' code + { ldarg_0_tryStart }, + { Call, interopReferences.ComInterfaceDispatchGetInstance.MakeGenericInstanceMethod(listType).Import(module) }, + { Stloc_0 }, + { Ldloc_0 }, + { Call, interopReferences.ICollection1Clear(elementType).Import(module) }, + { Ldc_I4_0 }, + { Stloc_1 }, + { Leave_S, ldloc_1_returnHResult.CreateLabel() }, + + // '.catch' code + { call_catchStartMarshalException }, + { Stloc_1 }, + { Leave_S, ldloc_1_returnHResult.CreateLabel() }, + + // Return the 'HRESULT' from location [1] + { ldloc_1_returnHResult }, + { Ret } + }, + ExceptionHandlers = + { + new CilExceptionHandler + { + HandlerType = CilExceptionHandlerType.Exception, + TryStart = ldarg_0_tryStart.CreateLabel(), + TryEnd = call_catchStartMarshalException.CreateLabel(), + HandlerStart = call_catchStartMarshalException.CreateLabel(), + HandlerEnd = ldloc_1_returnHResult.CreateLabel(), + ExceptionType = interopReferences.Exception.Import(module) + } + } + }; + + return clearMethod; + } + + /// + /// Creates a for the ReplaceAll export method. + /// + /// The for the type. + /// The instance to use. + /// The emit state for this invocation. + /// The interop module being built. + public static MethodDefinition ReplaceAll( + GenericInstanceTypeSignature listType, + InteropReferences interopReferences, + InteropGeneratorEmitState emitState, + ModuleDefinition module) + { + TypeSignature elementType = listType.TypeArguments[0]; + TypeSignature elementAbiType = elementType.GetAbiType(interopReferences); + + // Define the 'ReplaceAll' method as follows: + // + // [UnmanagedCallersOnly(CallConvs = [typeof(CallConvMemberFunction)])] + // private static int ReplaceAll(void* thisPtr, uint size, * items) + MethodDefinition replaceAllMethod = new( + name: "ReplaceAll"u8, + attributes: MethodAttributes.Private | MethodAttributes.HideBySig | MethodAttributes.Static, + signature: MethodSignature.CreateStatic( + returnType: module.CorLibTypeFactory.Int32, + parameterTypes: [ + module.CorLibTypeFactory.Void.MakePointerType(), + module.CorLibTypeFactory.UInt32, + elementAbiType.Import(module).MakePointerType()])) + { + CustomAttributes = { InteropCustomAttributeFactory.UnmanagedCallersOnly(interopReferences, module) } + }; + + // Declare the local variables: + // [0]: '' (for 'thisObject') + // [1]: 'uint' (for the 'i' loop variable) + // [2]: 'int' (the 'HRESULT' to return) + CilLocalVariable loc_0_thisObject = new(listType.Import(module)); + CilLocalVariable loc_1_i = new(module.CorLibTypeFactory.UInt32); + CilLocalVariable loc_2_hresult = new(module.CorLibTypeFactory.Int32); + + // Labels for jumps + CilInstruction ldarg_2_nullCheck = new(Ldarg_2); + CilInstruction nop_beforeTry = new(Nop); + CilInstruction ldarg_0_tryStart = new(Ldarg_0); + CilInstruction ldloc_1_loopCheck = new(Ldloc_1); + CilInstruction ldloc_0_loopStart = new(Ldloc_0); + CilInstruction nop_convertToManaged = new(Nop); + CilInstruction ldloc_2_returnHResult = new(Ldloc_2); + CilInstruction call_catchStartMarshalException = new(Call, interopReferences.RestrictedErrorInfoExceptionMarshallerConvertToUnmanaged.Import(module)); + + // Create a method body for the 'ReplaceAll' method + replaceAllMethod.CilMethodBody = new CilMethodBody() + { + LocalVariables = { loc_0_thisObject, loc_1_i, loc_2_hresult }, + Instructions = + { + // Return 'S_OK' if the size is '0' + { Ldarg_1 }, + { Brtrue_S, ldarg_2_nullCheck.CreateLabel() }, + { Ldc_I4_0 }, + { Ret }, + + // Return 'E_POINTER' if the array is 'null' + { ldarg_2_nullCheck }, + { Ldc_I4_0 }, + { Conv_U }, + { Bne_Un_S, nop_beforeTry.CreateLabel() }, + { Ldc_I4, unchecked((int)0x80004003) }, + { Ret }, + { nop_beforeTry }, + + // '.try' code to load the list + { ldarg_0_tryStart }, + { Call, interopReferences.ComInterfaceDispatchGetInstance.MakeGenericInstanceMethod(listType).Import(module) }, + { Stloc_0 }, + + // list.Clear(); + { Ldloc_0 }, + { Callvirt, interopReferences.ICollection1Clear(elementType).Import(module) }, + + // int i = 0; + { Ldc_I4_0 }, + { Stloc_1 }, + { Br_S, ldloc_1_loopCheck.CreateLabel() }, + + // list.Add((items[i])); + { ldloc_0_loopStart }, + { Ldarg_2 }, + { Ldloc_1 }, + { Conv_U8 }, + { Sizeof, elementAbiType.Import(module).ToTypeDefOrRef() }, + { Conv_I8 }, + { Mul }, + { Conv_I }, + { Add }, + { CilInstruction.CreateLdind(elementAbiType, module) }, + { nop_convertToManaged }, + { Callvirt, interopReferences.ICollection1Add(elementType).Import(module) }, + + // i++; + { Ldloc_1 }, + { Ldc_I4_1 }, + { Add }, + { Stloc_1 }, + + // if (i < size) goto LoopStart; + { ldloc_1_loopCheck }, + { Ldarg_1 }, + { Blt_Un_S, ldloc_0_loopStart.CreateLabel() }, + + // return S_OK + { Ldc_I4_0 }, + { Stloc_2 }, + { Leave_S, ldloc_2_returnHResult.CreateLabel() }, + + // '.catch' code + { call_catchStartMarshalException }, + { Stloc_2 }, + { Leave_S, ldloc_2_returnHResult.CreateLabel() }, + + // Return the 'HRESULT' from location [1] + { ldloc_2_returnHResult }, + { Ret } + }, + ExceptionHandlers = + { + new CilExceptionHandler + { + HandlerType = CilExceptionHandlerType.Exception, + TryStart = ldarg_0_tryStart.CreateLabel(), + TryEnd = call_catchStartMarshalException.CreateLabel(), + HandlerStart = call_catchStartMarshalException.CreateLabel(), + HandlerEnd = ldloc_2_returnHResult.CreateLabel(), + ExceptionType = interopReferences.Exception.Import(module) + } + } + }; + + // Track rewriting each item for this method + emitState.TrackManagedValueMethodRewrite( + parameterType: elementType, + method: replaceAllMethod, + marker: nop_convertToManaged); + + return replaceAllMethod; + } + + /// + /// Creates a for the SetAt or InsertAt export method. + /// + /// The name of the method to generate. + /// The adapter method to forward the call to. + /// The for the type. + /// The instance to use. + /// The emit state for this invocation. + /// The interop module being built. + private static MethodDefinition SetAtOrInsertAt( + Utf8String methodName, + MemberReference adapterMethod, + GenericInstanceTypeSignature listType, + InteropReferences interopReferences, + InteropGeneratorEmitState emitState, + ModuleDefinition module) + { + TypeSignature elementType = listType.TypeArguments[0]; + + // Define the 'SetAt' or 'InsertAt' method as follows: + // + // [UnmanagedCallersOnly(CallConvs = [typeof(CallConvMemberFunction)])] + // private static int (void* thisPtr, uint index, value) + MethodDefinition setAtOrInsertAtMethod = new( + name: methodName, + attributes: MethodAttributes.Private | MethodAttributes.HideBySig | MethodAttributes.Static, + signature: MethodSignature.CreateStatic( + returnType: module.CorLibTypeFactory.Int32, + parameterTypes: [ + module.CorLibTypeFactory.Void.MakePointerType(), + module.CorLibTypeFactory.UInt32, + elementType.GetAbiType(interopReferences).Import(module)])) + { + CustomAttributes = { InteropCustomAttributeFactory.UnmanagedCallersOnly(interopReferences, module) } + }; + + // Declare the local variables: + // [0]: '' (for 'thisObject') + // [1]: 'int' (the 'HRESULT' to return) + CilLocalVariable loc_0_thisObject = new(listType.Import(module)); + CilLocalVariable loc_1_hresult = new(module.CorLibTypeFactory.Int32); + + // Labels for jumps + CilInstruction ldarg_0_tryStart = new(Ldarg_0); + CilInstruction nop_parameter2Rewrite = new(Nop); + CilInstruction ldloc_1_returnHResult = new(Ldloc_1); + CilInstruction call_catchStartMarshalException = new(Call, interopReferences.RestrictedErrorInfoExceptionMarshallerConvertToUnmanaged.Import(module)); + + // Create a method body for the 'SetAt' or 'InsertAt' method + setAtOrInsertAtMethod.CilMethodBody = new CilMethodBody() + { + LocalVariables = { loc_0_thisObject, loc_1_hresult }, + Instructions = + { + // '.try' code + { ldarg_0_tryStart }, + { Call, interopReferences.ComInterfaceDispatchGetInstance.MakeGenericInstanceMethod(listType).Import(module) }, + { Stloc_0 }, + { Ldloc_0 }, + { Ldarg_1 }, + { nop_parameter2Rewrite }, + { Call, adapterMethod.Import(module) }, + { Ldc_I4_0 }, + { Stloc_1 }, + { Leave_S, ldloc_1_returnHResult.CreateLabel() }, + + // '.catch' code + { call_catchStartMarshalException }, + { Stloc_1 }, + { Leave_S, ldloc_1_returnHResult.CreateLabel() }, + + // Return the 'HRESULT' from location [1] + { ldloc_1_returnHResult }, + { Ret } + }, + ExceptionHandlers = + { + new CilExceptionHandler + { + HandlerType = CilExceptionHandlerType.Exception, + TryStart = ldarg_0_tryStart.CreateLabel(), + TryEnd = call_catchStartMarshalException.CreateLabel(), + HandlerStart = call_catchStartMarshalException.CreateLabel(), + HandlerEnd = ldloc_1_returnHResult.CreateLabel(), + ExceptionType = interopReferences.Exception.Import(module) + } + } + }; + + // Track rewriting the parameter for this method + emitState.TrackManagedParameterMethodRewrite( + parameterType: elementType, + method: setAtOrInsertAtMethod, + marker: nop_parameter2Rewrite, + parameterIndex: 2); + + return setAtOrInsertAtMethod; + } + } +} \ No newline at end of file diff --git a/src/WinRT.Interop.Generator/Factories/InteropMethodDefinitionFactory.IReadOnlyList1Impl.cs b/src/WinRT.Interop.Generator/Factories/InteropMethodDefinitionFactory.IReadOnlyList1Impl.cs index a91a51881..82a02a032 100644 --- a/src/WinRT.Interop.Generator/Factories/InteropMethodDefinitionFactory.IReadOnlyList1Impl.cs +++ b/src/WinRT.Interop.Generator/Factories/InteropMethodDefinitionFactory.IReadOnlyList1Impl.cs @@ -28,11 +28,16 @@ public static class IReadOnlyList1Impl /// Creates a for the GetAt export method. /// /// The for the type. + /// The interface method to invoke on . /// The instance to use. /// The emit state for this invocation. /// The interop module being built. + /// + /// This method can also be used to define the GetAt method for interfaces. + /// public static MethodDefinition GetAt( GenericInstanceTypeSignature readOnlyListType, + MemberReference getAtMethod, InteropReferences interopReferences, InteropGeneratorEmitState emitState, ModuleDefinition module) @@ -43,7 +48,7 @@ public static MethodDefinition GetAt( // // [UnmanagedCallersOnly(CallConvs = [typeof(CallConvMemberFunction)])] // private static int GetAt(void* thisPtr, uint index, * result) - MethodDefinition getAtMethod = new( + MethodDefinition getAtImplMethod = new( name: "GetAt"u8, attributes: MethodAttributes.Private | MethodAttributes.HideBySig | MethodAttributes.Static, signature: MethodSignature.CreateStatic( @@ -70,7 +75,7 @@ public static MethodDefinition GetAt( CilInstruction nop_convertToUnmanaged = new(Nop); // Create a method body for the 'GetAt' method - getAtMethod.CilMethodBody = new CilMethodBody() + getAtImplMethod.CilMethodBody = new CilMethodBody() { LocalVariables = { loc_0_thisObject, loc_1_hresult }, Instructions = @@ -91,7 +96,7 @@ public static MethodDefinition GetAt( { Ldarg_2 }, { Ldloc_0 }, { Ldarg_1 }, - { Call, interopReferences.IReadOnlyListAdapter1GetAt(elementType).Import(module) }, + { Call, getAtMethod.Import(module) }, { nop_convertToUnmanaged }, { Ldc_I4_0 }, { Stloc_1 }, @@ -123,30 +128,33 @@ public static MethodDefinition GetAt( // Track the method for rewrite to marshal the result value emitState.TrackRetValValueMethodRewrite( retValType: elementType, - method: getAtMethod, + method: getAtImplMethod, marker: nop_convertToUnmanaged); - return getAtMethod; + return getAtImplMethod; } /// /// Creates a for the get_Size export method. /// /// The for the type. + /// The interface method to invoke on . /// The instance to use. /// The interop module being built. + /// + /// This method can also be used to define the GetAt method for interfaces. + /// public static MethodDefinition get_Size( GenericInstanceTypeSignature readOnlyListType, + MemberReference sizeMethod, InteropReferences interopReferences, ModuleDefinition module) { - TypeSignature elementType = readOnlyListType.TypeArguments[0]; - // Define the 'get_Size' method as follows: // // [UnmanagedCallersOnly(CallConvs = [typeof(CallConvMemberFunction)])] // private static int get_Size(void* thisPtr, uint* result) - MethodDefinition sizeMethod = new( + MethodDefinition sizeImplMethod = new( name: "get_Size"u8, attributes: MethodAttributes.Private | MethodAttributes.HideBySig | MethodAttributes.Static, signature: MethodSignature.CreateStatic( @@ -171,7 +179,7 @@ public static MethodDefinition get_Size( CilInstruction call_catchStartMarshalException = new(Call, interopReferences.RestrictedErrorInfoExceptionMarshallerConvertToUnmanaged.Import(module)); // Create a method body for the 'get_Size' method - sizeMethod.CilMethodBody = new CilMethodBody() + sizeImplMethod.CilMethodBody = new CilMethodBody() { LocalVariables = { loc_0_thisObject, loc_1_hresult }, Instructions = @@ -191,7 +199,7 @@ public static MethodDefinition get_Size( { Stloc_0 }, { Ldarg_1 }, { Ldloc_0 }, - { Call, interopReferences.IReadOnlyListAdapter1Size(elementType).Import(module) }, + { Call, sizeMethod.Import(module) }, { Stind_I4 }, { Ldc_I4_0 }, { Stloc_1 }, @@ -220,18 +228,20 @@ public static MethodDefinition get_Size( } }; - return sizeMethod; + return sizeImplMethod; } /// /// Creates a for the IndexOf export method. /// /// The for the type. + /// The interface method to invoke on . /// The instance to use. /// The emit state for this invocation. /// The interop module being built. public static MethodDefinition IndexOf( GenericInstanceTypeSignature readOnlyListType, + MemberReference indexOfMethod, InteropReferences interopReferences, InteropGeneratorEmitState emitState, ModuleDefinition module) @@ -242,7 +252,7 @@ public static MethodDefinition IndexOf( // // [UnmanagedCallersOnly(CallConvs = [typeof(CallConvMemberFunction)])] // private static int IndexOf(void* thisPtr, value, uint* index, bool* result) - MethodDefinition indexOfMethod = new( + MethodDefinition indexOfImplMethod = new( name: "IndexOf"u8, attributes: MethodAttributes.Private | MethodAttributes.HideBySig | MethodAttributes.Static, signature: MethodSignature.CreateStatic( @@ -270,13 +280,23 @@ public static MethodDefinition IndexOf( CilInstruction ldloc_1_returnHResult = new(Ldloc_1); CilInstruction call_catchStartMarshalException = new(Call, interopReferences.RestrictedErrorInfoExceptionMarshallerConvertToUnmanaged.Import(module)); + MemberReference adapterIndexOfMethod; + // Get the target 'IndexOf' method (we can optimize for 'string' types) - MemberReference adapterIndexOfMethod = elementType.IsTypeOfString() - ? interopReferences.IReadOnlyListAdapterOfStringIndexOf() - : interopReferences.IReadOnlyListAdapter1IndexOf(elementType); + if (elementType.IsTypeOfString()) + { + adapterIndexOfMethod = SignatureComparer.IgnoreVersion.Equals(readOnlyListType.GenericType, interopReferences.IReadOnlyList1) + ? interopReferences.IReadOnlyListAdapterOfStringIndexOf() + : interopReferences.IListAdapterOfStringIndexOf(); + } + else + { + // Otherwise use the provided method directly (it will always be valid) + adapterIndexOfMethod = indexOfMethod; + } // Create a method body for the 'IndexOf' method - indexOfMethod.CilMethodBody = new CilMethodBody() + indexOfImplMethod.CilMethodBody = new CilMethodBody() { LocalVariables = { loc_0_thisObject, loc_1_hresult }, Instructions = @@ -339,11 +359,11 @@ public static MethodDefinition IndexOf( // Track rewriting the two parameters for this method emitState.TrackManagedParameterMethodRewrite( parameterType: parameterType, - method: indexOfMethod, + method: indexOfImplMethod, marker: nop_parameter1Rewrite, parameterIndex: 1); - return indexOfMethod; + return indexOfImplMethod; } #pragma warning disable IDE0017 @@ -351,11 +371,16 @@ public static MethodDefinition IndexOf( /// Creates a for the GetMany export method. /// /// The for the type. + /// The interface method to invoke on . /// The instance to use. /// The emit state for this invocation. /// The interop module being built. + /// /// + /// This method can also be used to define the GetMany method for interfaces. + /// public static MethodDefinition GetMany( GenericInstanceTypeSignature readOnlyListType, + MemberReference getAtMethod, InteropReferences interopReferences, InteropGeneratorEmitState emitState, ModuleDefinition module) diff --git a/src/WinRT.Interop.Generator/Factories/InteropMethodRewriteFactory.ManagedParameter.cs b/src/WinRT.Interop.Generator/Factories/InteropMethodRewriteFactory.ManagedParameter.cs index 3feba88bb..b878b08d9 100644 --- a/src/WinRT.Interop.Generator/Factories/InteropMethodRewriteFactory.ManagedParameter.cs +++ b/src/WinRT.Interop.Generator/Factories/InteropMethodRewriteFactory.ManagedParameter.cs @@ -69,9 +69,10 @@ public static void RewriteMethod( throw WellKnownInteropExceptions.MethodRewriteSourceParameterTypeMismatchError(source.ParameterType, parameterType, method); } + // See comments in the marshalling code for 'ManagedValue' for additional details on the code below. + // The two are identical, the only difference is this method also loads the parameters on the stack. if (parameterType.IsValueType) { - // If the return type is blittable, we can just load it directly it directly (simplest case) if (parameterType.IsBlittable(interopReferences)) { body.Instructions.ReferenceReplaceRange(marker, CilInstruction.CreateLdarg(parameterIndex)); @@ -80,26 +81,20 @@ public static void RewriteMethod( { InteropMarshallerType marshallerType = InteropMarshallerTypeResolver.GetMarshallerType(parameterType, interopReferences, emitState); - // For 'Nullable' parameters (i.e. we have an 'IReference' interface pointer), we unbox the underlying type body.Instructions.ReferenceReplaceRange(marker, [ CilInstruction.CreateLdarg(parameterIndex), new CilInstruction(Call, marshallerType.UnboxToManaged().Import(module))]); } else if (SignatureComparer.IgnoreVersion.Equals(parameterType, interopReferences.ReadOnlySpanChar)) { - // When marshalling 'ReadOnlySpan' values, we also use 'HStringMarshaller', but without materializing the 'string' object body.Instructions.ReferenceReplaceRange(marker, [ CilInstruction.CreateLdarg(parameterIndex), new CilInstruction(Call, interopReferences.HStringMarshallerConvertToManagedUnsafe.Import(module))]); } else { - // The last case handles all other value types. It doesn't matter if they possibly hold some unmanaged - // resources, as they're only being used as parameters. That means the caller is responsible for disposal. - // This case can also handle 'KeyValuePair<,>' instantiations, which are just marshalled normally too. InteropMarshallerType marshallerType = InteropMarshallerTypeResolver.GetMarshallerType(parameterType, interopReferences, emitState); - // We can directly call the marshaller and return it, no 'try/finally' complexity is needed body.Instructions.ReferenceReplaceRange(marker, [ CilInstruction.CreateLdarg(parameterIndex), new CilInstruction(Call, marshallerType.ConvertToManaged().Import(module))]); @@ -107,17 +102,14 @@ public static void RewriteMethod( } else if (parameterType.IsTypeOfString()) { - // When marshalling 'string' values, we must use 'HStringMarshaller' (the ABI type is not actually a COM object) body.Instructions.ReferenceReplaceRange(marker, [ CilInstruction.CreateLdarg(parameterIndex), new CilInstruction(Call, interopReferences.HStringMarshallerConvertToManaged.Import(module))]); } else { - // Get the marshaller type for all other reference types (including generics) InteropMarshallerType marshallerType = InteropMarshallerTypeResolver.GetMarshallerType(parameterType, interopReferences, emitState); - // Marshal the value normally (the caller will own the native resource) body.Instructions.ReferenceReplaceRange(marker, [ CilInstruction.CreateLdarg(parameterIndex), new CilInstruction(Call, marshallerType.ConvertToManaged().Import(module))]); diff --git a/src/WinRT.Interop.Generator/Factories/InteropMethodRewriteFactory.ManagedValue.cs b/src/WinRT.Interop.Generator/Factories/InteropMethodRewriteFactory.ManagedValue.cs new file mode 100644 index 000000000..d4d0fe9b1 --- /dev/null +++ b/src/WinRT.Interop.Generator/Factories/InteropMethodRewriteFactory.ManagedValue.cs @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using AsmResolver.DotNet; +using AsmResolver.DotNet.Code.Cil; +using AsmResolver.DotNet.Signatures; +using AsmResolver.PE.DotNet.Cil; +using WindowsRuntime.InteropGenerator.Errors; +using WindowsRuntime.InteropGenerator.Generation; +using WindowsRuntime.InteropGenerator.References; +using WindowsRuntime.InteropGenerator.Resolvers; +using static AsmResolver.PE.DotNet.Cil.CilOpCodes; + +namespace WindowsRuntime.InteropGenerator.Factories; + +/// +/// A factory to rewrite interop method definitons, and add marshalling code as needed. +/// +internal static partial class InteropMethodRewriteFactory +{ + /// + /// Contains the logic for marshalling managed values (i.e. parameters that are passed to managed methods, already on the stack). + /// + public static class ManagedValue + { + /// + /// Performs two-pass code generation on a target method to marshal an unmanaged parameter. + /// + /// The parameter type that needs to be marshalled. + /// The target method to perform two-pass code generation on. + /// The target IL instruction to replace with the right set of specialized instructions. + /// The instance to use. + /// The emit state for this invocation. + /// The interop module being built. + public static void RewriteMethod( + TypeSignature parameterType, + MethodDefinition method, + CilInstruction marker, + InteropReferences interopReferences, + InteropGeneratorEmitState emitState, + ModuleDefinition module) + { + // Validate that we do have some IL body for the input method (this should always be the case) + if (method.CilMethodBody is not CilMethodBody body) + { + throw WellKnownInteropExceptions.MethodRewriteMissingBodyError(method); + } + + // If we didn't find the marker, it means the target method is either invalid + if (!body.Instructions.ReferenceContains(marker)) + { + throw WellKnownInteropExceptions.MethodRewriteMarkerInstructionNotFoundError(marker, method); + } + + if (parameterType.IsValueType) + { + // If the return type is blittable, we have nothing else to do (the value is already loaded) + if (parameterType.IsBlittable(interopReferences)) + { + return; + } + + // Handle the other possible value types + if (parameterType.IsConstructedNullableValueType(interopReferences)) + { + InteropMarshallerType marshallerType = InteropMarshallerTypeResolver.GetMarshallerType(parameterType, interopReferences, emitState); + + // For 'Nullable' parameters (i.e. we have an 'IReference' interface pointer), we unbox the underlying type + body.Instructions.ReferenceReplaceRange(marker, new CilInstruction(Call, marshallerType.UnboxToManaged().Import(module))); + } + else if (SignatureComparer.IgnoreVersion.Equals(parameterType, interopReferences.ReadOnlySpanChar)) + { + // When marshalling 'ReadOnlySpan' values, we also use 'HStringMarshaller', but without materializing the 'string' object + body.Instructions.ReferenceReplaceRange(marker, new CilInstruction(Call, interopReferences.HStringMarshallerConvertToManagedUnsafe.Import(module))); + } + else + { + // The last case handles all other value types. It doesn't matter if they possibly hold some unmanaged + // resources, as they're only being used as parameters. That means the caller is responsible for disposal. + // This case can also handle 'KeyValuePair<,>' instantiations, which are just marshalled normally too. + InteropMarshallerType marshallerType = InteropMarshallerTypeResolver.GetMarshallerType(parameterType, interopReferences, emitState); + + // We can directly call the marshaller and return it, no 'try/finally' complexity is needed + body.Instructions.ReferenceReplaceRange(marker, new CilInstruction(Call, marshallerType.ConvertToManaged().Import(module))); + } + } + else if (parameterType.IsTypeOfString()) + { + // When marshalling 'string' values, we must use 'HStringMarshaller' (the ABI type is not actually a COM object) + body.Instructions.ReferenceReplaceRange(marker, new CilInstruction(Call, interopReferences.HStringMarshallerConvertToManaged.Import(module))); + } + else + { + // Get the marshaller type for all other reference types (including generics) + InteropMarshallerType marshallerType = InteropMarshallerTypeResolver.GetMarshallerType(parameterType, interopReferences, emitState); + + // Marshal the value normally (the caller will own the native resource) + body.Instructions.ReferenceReplaceRange(marker, new CilInstruction(Call, marshallerType.ConvertToManaged().Import(module))); + } + } + } +} \ No newline at end of file diff --git a/src/WinRT.Interop.Generator/Generation/InteropGenerator.Emit.cs b/src/WinRT.Interop.Generator/Generation/InteropGenerator.Emit.cs index d45f21950..4a92f1f3c 100644 --- a/src/WinRT.Interop.Generator/Generation/InteropGenerator.Emit.cs +++ b/src/WinRT.Interop.Generator/Generation/InteropGenerator.Emit.cs @@ -2040,6 +2040,17 @@ private static void RewriteMethodDefinitions( module: module); break; + // Rewrite managed values + case MethodRewriteInfo.ManagedValue managedValueInfo: + InteropMethodRewriteFactory.ManagedValue.RewriteMethod( + parameterType: managedValueInfo.Type, + method: managedValueInfo.Method, + marker: managedValueInfo.Marker, + interopReferences: interopReferences, + emitState: emitState, + module: module); + break; + // Rewrite managed parameters case MethodRewriteInfo.ManagedParameter managedParameterInfo: InteropMethodRewriteFactory.ManagedParameter.RewriteMethod( diff --git a/src/WinRT.Interop.Generator/Generation/InteropGeneratorDiscoveryState.cs b/src/WinRT.Interop.Generator/Generation/InteropGeneratorDiscoveryState.cs index 1ef8c2e9d..a0636c7e0 100644 --- a/src/WinRT.Interop.Generator/Generation/InteropGeneratorDiscoveryState.cs +++ b/src/WinRT.Interop.Generator/Generation/InteropGeneratorDiscoveryState.cs @@ -66,6 +66,9 @@ internal sealed class InteropGeneratorDiscoveryState /// Backing field for . private readonly ConcurrentDictionary _szArrayTypes = new(SignatureComparer.IgnoreVersion); + /// Backing field to support . + private readonly ConcurrentDictionary _markedUserDefinedTypes = new(SignatureComparer.IgnoreVersion); + /// Backing field for . private readonly ConcurrentDictionary _userDefinedTypes = new(SignatureComparer.IgnoreVersion); @@ -389,6 +392,17 @@ public void TrackSzArrayType(SzArrayTypeSignature szArrayType) _ = _szArrayTypes.TryAdd(szArrayType, 0); } + /// + /// Tries to mark a user-defined type as having been seen the first time, + /// and indicating that it's in the process of being processed. + /// + /// The user-defined type. + /// Whether this was the first time that was seen. + public bool TryMarkUserDefinedType(TypeSignature userDefinedType) + { + return _markedUserDefinedTypes.TryAdd(userDefinedType, 0); + } + /// /// Tracks a user-defined type. /// diff --git a/src/WinRT.Interop.Generator/Generation/InteropGeneratorEmitState.cs b/src/WinRT.Interop.Generator/Generation/InteropGeneratorEmitState.cs index 6718ea006..1009530b2 100644 --- a/src/WinRT.Interop.Generator/Generation/InteropGeneratorEmitState.cs +++ b/src/WinRT.Interop.Generator/Generation/InteropGeneratorEmitState.cs @@ -156,6 +156,27 @@ public void TrackManagedParameterMethodRewrite( }); } + /// + /// Tracks a method rewrite that involves marshalling a managed value in the specified method. + /// + /// + /// + /// + public void TrackManagedValueMethodRewrite( + TypeSignature parameterType, + MethodDefinition method, + CilInstruction marker) + { + ThrowIfReadOnly(); + + _methodRewriteInfos.Add(new MethodRewriteInfo.ManagedValue + { + Type = parameterType, + Method = method, + Marker = marker + }); + } + /// /// Tracks a method rewrite that involves loading a native parameter in the specified method. /// diff --git a/src/WinRT.Interop.Generator/Models/MethodRewriteInfo/MethodRewriteInfo.ManagedValue.cs b/src/WinRT.Interop.Generator/Models/MethodRewriteInfo/MethodRewriteInfo.ManagedValue.cs new file mode 100644 index 000000000..2987cb263 --- /dev/null +++ b/src/WinRT.Interop.Generator/Models/MethodRewriteInfo/MethodRewriteInfo.ManagedValue.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace WindowsRuntime.InteropGenerator.Models; + +/// +internal partial class MethodRewriteInfo +{ + /// + /// Contains info for a target method for two-pass IL generation, for a managed value. + /// + /// + public sealed class ManagedValue : MethodRewriteInfo + { + /// + public override int CompareTo(MethodRewriteInfo? other) + { + // 'ManagedValue' objects have no additional state, so just compare with the base state + return ReferenceEquals(this, other) + ? 0 + : CompareByMethodRewriteInfo(other); + } + } +} diff --git a/src/WinRT.Interop.Generator/References/InteropReferences.cs b/src/WinRT.Interop.Generator/References/InteropReferences.cs index 36fa36124..643651317 100644 --- a/src/WinRT.Interop.Generator/References/InteropReferences.cs +++ b/src/WinRT.Interop.Generator/References/InteropReferences.cs @@ -338,6 +338,11 @@ public InteropReferences( /// public TypeReference KeyValuePair2 => field ??= _corLibTypeFactory.CorLibScope.CreateTypeReference("System.Collections.Generic"u8, "KeyValuePair`2"u8); + /// + /// Gets the for . + /// + public TypeReference ReadOnlyCollection1 => field ??= _corLibTypeFactory.CorLibScope.CreateTypeReference("System.Collections.ObjectModel"u8, "ReadOnlyCollection`1"u8); + /// /// Gets the for . /// @@ -658,6 +663,16 @@ public InteropReferences( /// public TypeReference IListMethods1 => field ??= _windowsRuntimeModule.CreateTypeReference("WindowsRuntime.InteropServices"u8, "IListMethods`1"u8); + /// + /// Gets the for WindowsRuntime.InteropServices.IListAdapter<T>. + /// + public TypeReference IListAdapter1 => field ??= _windowsRuntimeModule.CreateTypeReference("WindowsRuntime.InteropServices"u8, "IListAdapter`1"u8); + + /// + /// Gets the for WindowsRuntime.InteropServices.IListAdapterExtensions. + /// + public TypeReference IListAdapterExtensions => field ??= _windowsRuntimeModule.CreateTypeReference("WindowsRuntime.InteropServices"u8, "IListAdapterExtensions"u8); + /// /// Gets the for WindowsRuntime.InteropServices.IReadOnlyListAdapter<T>. /// @@ -3073,52 +3088,160 @@ public MethodSpecification IListMethods1Insert(TypeSignature elementType, TypeDe } /// - /// Gets the for WindowsRuntime.InteropServices.IReadOnlyListAdapter<T>.GetAt. + /// Gets the for WindowsRuntime.InteropServices.IListAdapter<T>.GetAt. /// /// The input element type. - public MemberReference IReadOnlyListAdapter1GetAt(TypeSignature elementType) + public MemberReference IListAdapter1GetAt(TypeSignature elementType) { - return IReadOnlyListAdapter1 + return IListAdapter1 .MakeGenericReferenceType(elementType) .ToTypeDefOrRef() .CreateMemberReference("GetAt"u8, MethodSignature.CreateStatic( returnType: new GenericParameterSignature(GenericParameterType.Type, 0), parameterTypes: [ - IReadOnlyList1.MakeGenericReferenceType(elementType), + IList1.MakeGenericReferenceType(new GenericParameterSignature(GenericParameterType.Type, 0)), _corLibTypeFactory.UInt32])); } /// - /// Gets the for WindowsRuntime.InteropServices.IReadOnlyListAdapter<T>.IndexOf. + /// Gets the for WindowsRuntime.InteropServices.IListAdapter<T>.Size. /// /// The input element type. - public MemberReference IReadOnlyListAdapter1IndexOf(TypeSignature elementType) + public MemberReference IListAdapter1Size(TypeSignature elementType) { - return IReadOnlyListAdapter1 + return IListAdapter1 + .MakeGenericReferenceType(elementType) + .ToTypeDefOrRef() + .CreateMemberReference("Size"u8, MethodSignature.CreateStatic( + returnType: _corLibTypeFactory.UInt32, + parameterTypes: [IList1.MakeGenericReferenceType(new GenericParameterSignature(GenericParameterType.Type, 0))])); + } + + /// + /// Gets the for WindowsRuntime.InteropServices.IListAdapter<T>.GetView. + /// + /// The input element type. + public MemberReference IListAdapter1GetView(TypeSignature elementType) + { + return IListAdapter1 + .MakeGenericReferenceType(elementType) + .ToTypeDefOrRef() + .CreateMemberReference("GetView"u8, MethodSignature.CreateStatic( + returnType: IReadOnlyList1.MakeGenericReferenceType(new GenericParameterSignature(GenericParameterType.Type, 0)), + parameterTypes: [IList1.MakeGenericReferenceType(new GenericParameterSignature(GenericParameterType.Type, 0))])); + } + + /// + /// Gets the for WindowsRuntime.InteropServices.IListAdapter<T>.IndexOf. + /// + /// The input element type. + public MemberReference IListAdapter1IndexOf(TypeSignature elementType) + { + return IListAdapter1 .MakeGenericReferenceType(elementType) .ToTypeDefOrRef() .CreateMemberReference("IndexOf"u8, MethodSignature.CreateStatic( returnType: _corLibTypeFactory.Boolean, parameterTypes: [ - IReadOnlyList1.MakeGenericReferenceType(elementType), + IList1.MakeGenericReferenceType(new GenericParameterSignature(GenericParameterType.Type, 0)), new GenericParameterSignature(GenericParameterType.Type, 0), _corLibTypeFactory.UInt32.MakeByReferenceType()])); } /// - /// Gets the for WindowsRuntime.InteropServices.IReadOnlyListAdapter<string>.IndexOf. + /// Gets the for WindowsRuntime.InteropServices.IListAdapter<string>.IndexOf. /// - public MemberReference IReadOnlyListAdapterOfStringIndexOf() + public MemberReference IListAdapterOfStringIndexOf() { - return IReadOnlyListAdapterExtensions + return IListAdapterExtensions .CreateMemberReference("IndexOf"u8, MethodSignature.CreateStatic( returnType: _corLibTypeFactory.Boolean, parameterTypes: [ - IReadOnlyList1.MakeGenericReferenceType(_corLibTypeFactory.String), + IList1.MakeGenericReferenceType(_corLibTypeFactory.String), ReadOnlySpanChar, _corLibTypeFactory.UInt32.MakeByReferenceType()])); } + /// + /// Gets the for WindowsRuntime.InteropServices.IListAdapter<T>.SetAt. + /// + /// The input element type. + public MemberReference IListAdapter1SetAt(TypeSignature elementType) + { + return IListAdapter1 + .MakeGenericReferenceType(elementType) + .ToTypeDefOrRef() + .CreateMemberReference("SetAt"u8, MethodSignature.CreateStatic( + returnType: _corLibTypeFactory.Void, + parameterTypes: [ + IList1.MakeGenericReferenceType(new GenericParameterSignature(GenericParameterType.Type, 0)), + _corLibTypeFactory.UInt32, + new GenericParameterSignature(GenericParameterType.Type, 0)])); + } + + /// + /// Gets the for WindowsRuntime.InteropServices.IListAdapter<T>.InsertAt. + /// + /// The input element type. + public MemberReference IListAdapter1InsertAt(TypeSignature elementType) + { + return IListAdapter1 + .MakeGenericReferenceType(elementType) + .ToTypeDefOrRef() + .CreateMemberReference("InsertAt"u8, MethodSignature.CreateStatic( + returnType: _corLibTypeFactory.Void, + parameterTypes: [ + IList1.MakeGenericReferenceType(new GenericParameterSignature(GenericParameterType.Type, 0)), + _corLibTypeFactory.UInt32, + new GenericParameterSignature(GenericParameterType.Type, 0)])); + } + + /// + /// Gets the for WindowsRuntime.InteropServices.IListAdapter<T>.RemoveAt. + /// + /// The input element type. + public MemberReference IListAdapter1RemoveAt(TypeSignature elementType) + { + return IListAdapter1 + .MakeGenericReferenceType(elementType) + .ToTypeDefOrRef() + .CreateMemberReference("RemoveAt"u8, MethodSignature.CreateStatic( + returnType: _corLibTypeFactory.Void, + parameterTypes: [ + IList1.MakeGenericReferenceType(new GenericParameterSignature(GenericParameterType.Type, 0)), + _corLibTypeFactory.UInt32])); + } + + /// + /// Gets the for WindowsRuntime.InteropServices.IListAdapter<T>.RemoveAtEnd. + /// + /// The input element type. + public MemberReference IListAdapter1RemoveAtEnd(TypeSignature elementType) + { + return IListAdapter1 + .MakeGenericReferenceType(elementType) + .ToTypeDefOrRef() + .CreateMemberReference("RemoveAtEnd"u8, MethodSignature.CreateStatic( + returnType: _corLibTypeFactory.Void, + parameterTypes: [IList1.MakeGenericReferenceType(new GenericParameterSignature(GenericParameterType.Type, 0))])); + } + + /// + /// Gets the for WindowsRuntime.InteropServices.IReadOnlyListAdapter<T>.GetAt. + /// + /// The input element type. + public MemberReference IReadOnlyListAdapter1GetAt(TypeSignature elementType) + { + return IReadOnlyListAdapter1 + .MakeGenericReferenceType(elementType) + .ToTypeDefOrRef() + .CreateMemberReference("GetAt"u8, MethodSignature.CreateStatic( + returnType: new GenericParameterSignature(GenericParameterType.Type, 0), + parameterTypes: [ + IReadOnlyList1.MakeGenericReferenceType(new GenericParameterSignature(GenericParameterType.Type, 0)), + _corLibTypeFactory.UInt32])); + } + /// /// Gets the for WindowsRuntime.InteropServices.IReadOnlyListAdapter<T>.Size. /// @@ -3130,7 +3253,38 @@ public MemberReference IReadOnlyListAdapter1Size(TypeSignature elementType) .ToTypeDefOrRef() .CreateMemberReference("Size"u8, MethodSignature.CreateStatic( returnType: _corLibTypeFactory.UInt32, - parameterTypes: [IReadOnlyList1.MakeGenericReferenceType(elementType)])); + parameterTypes: [IReadOnlyList1.MakeGenericReferenceType(new GenericParameterSignature(GenericParameterType.Type, 0))])); + } + + /// + /// Gets the for WindowsRuntime.InteropServices.IReadOnlyListAdapter<T>.IndexOf. + /// + /// The input element type. + public MemberReference IReadOnlyListAdapter1IndexOf(TypeSignature elementType) + { + return IReadOnlyListAdapter1 + .MakeGenericReferenceType(elementType) + .ToTypeDefOrRef() + .CreateMemberReference("IndexOf"u8, MethodSignature.CreateStatic( + returnType: _corLibTypeFactory.Boolean, + parameterTypes: [ + IReadOnlyList1.MakeGenericReferenceType(new GenericParameterSignature(GenericParameterType.Type, 0)), + new GenericParameterSignature(GenericParameterType.Type, 0), + _corLibTypeFactory.UInt32.MakeByReferenceType()])); + } + + /// + /// Gets the for WindowsRuntime.InteropServices.IReadOnlyListAdapter<string>.IndexOf. + /// + public MemberReference IReadOnlyListAdapterOfStringIndexOf() + { + return IReadOnlyListAdapterExtensions + .CreateMemberReference("IndexOf"u8, MethodSignature.CreateStatic( + returnType: _corLibTypeFactory.Boolean, + parameterTypes: [ + IReadOnlyList1.MakeGenericReferenceType(_corLibTypeFactory.String), + ReadOnlySpanChar, + _corLibTypeFactory.UInt32.MakeByReferenceType()])); } /// diff --git a/src/WinRT.Runtime2/InteropServices/Collections/IListAdapterExtensions.cs b/src/WinRT.Runtime2/InteropServices/Collections/IListAdapterExtensions.cs new file mode 100644 index 000000000..9bb52083c --- /dev/null +++ b/src/WinRT.Runtime2/InteropServices/Collections/IListAdapterExtensions.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.ComponentModel; + +namespace WindowsRuntime.InteropServices; + +/// +/// Extensions for the type. +/// +[Obsolete(WindowsRuntimeConstants.PrivateImplementationDetailObsoleteMessage, + DiagnosticId = WindowsRuntimeConstants.PrivateImplementationDetailObsoleteDiagnosticId, + UrlFormat = WindowsRuntimeConstants.CsWinRTDiagnosticsUrlFormat)] +[EditorBrowsable(EditorBrowsableState.Never)] +public static class IListAdapterExtensions +{ + extension(IListAdapter) + { + /// + /// + public static bool IndexOf(IList list, ReadOnlySpan value, out uint index) + { + int count = list.Count; + + for (int i = 0; i < count; i++) + { + if (list[i].SequenceEqual(value)) + { + index = (uint)i; + + return true; + } + } + + index = 0; + + return false; + } + } +} \ No newline at end of file diff --git a/src/WinRT.Runtime2/InteropServices/Collections/IListAdapter{T}.cs b/src/WinRT.Runtime2/InteropServices/Collections/IListAdapter{T}.cs new file mode 100644 index 000000000..497d1dd36 --- /dev/null +++ b/src/WinRT.Runtime2/InteropServices/Collections/IListAdapter{T}.cs @@ -0,0 +1,253 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.ComponentModel; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +namespace WindowsRuntime.InteropServices; + +/// +/// A stateless adapter for , to be exposed as Windows.Foundation.Collections.IVector<T>. +/// +/// The type of objects to enumerate. +/// +[Obsolete(WindowsRuntimeConstants.PrivateImplementationDetailObsoleteMessage, + DiagnosticId = WindowsRuntimeConstants.PrivateImplementationDetailObsoleteDiagnosticId, + UrlFormat = WindowsRuntimeConstants.CsWinRTDiagnosticsUrlFormat)] +[EditorBrowsable(EditorBrowsableState.Never)] +public static class IListAdapter +{ + /// + /// Returns the item at the specified index in the vector. + /// + /// The wrapped instance. + /// The zero-based index of the item. + /// The item at the specified index. + /// + public static T GetAt(IList list, uint index) + { + IReadOnlyListAdapterHelpers.EnsureIndexInValidRange(index, list.Count); + + try + { + return list[(int)index]; + } + catch (ArgumentOutOfRangeException e) + { + e.HResult = WellKnownErrorCodes.E_BOUNDS; + + throw; + } + } + + /// + /// Gets the number of items in the vector. + /// + /// The wrapped instance. + /// The number of items in the vector. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint Size(IList list) + { + return (uint)list.Count; + } + + /// + /// Returns an immutable view of the vector. + /// + /// The wrapped instance. + /// The view of the vector. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static IReadOnlyList GetView(IList list) + { + // This list is not really read-only: once marshalled, native code could do + // a 'QueryInterface' call back to 'IVector', which would succeed, and would + // return a modifiable reference for this view. We believe this is accetable, as + // it allows us to gain some performance. For instance, in most situations (because + // pretty much all built-in .NET collection types implementing 'IList' also implement + // 'IReadOnlyList'), this allows us to not allocate anything. That is, when native + // code calls 'GetView()', it would get back the same CCW instance, just through a + // 'QueryInterface' call for 'IVectorView' instead. + return list as IReadOnlyList ?? new ReadOnlyCollection(list); + } + + /// + /// Retrieves the index of a specified item in the vector. + /// + /// The wrapped instance. + /// The item to find in the vector. + /// If the item is found, this is the zero-based index of the item; otherwise, this parameter is 0. + /// if the item is found; otherwise, . + /// + public static bool IndexOf(IList list, T value, out uint index) + { + int count = list.Count; + + // Scan the list and look for the target item + for (int i = 0; i < count; i++) + { + if (EqualityComparer.Default.Equals(value, list[i])) + { + index = (uint)i; + + return true; + } + } + + // Same as 'IndexOf' for 'IReadOnlyList', see notes there + index = 0; + + return false; + } + + /// + /// Sets the value at the specified index in the vector. + /// + /// The wrapped instance. + /// The zero-based index at which to set the value. + /// The item to set. + /// + public static void SetAt(IList list, uint index, T value) + { + IReadOnlyListAdapterHelpers.EnsureIndexInValidRange(index, list.Count); + + try + { + list[(int)index] = value; + } + catch (ArgumentOutOfRangeException e) + { + e.HResult = WellKnownErrorCodes.E_BOUNDS; + + throw; + } + } + + /// + /// Inserts an item at a specified index in the vector. + /// + /// The wrapped instance. + /// The zero-based index. + /// The item to insert. + /// + public static void InsertAt(IList list, uint index, T value) + { + // Inserting at an index one past the end of the list is equivalent to just + // appending an item, so we need to ensure that we're within '[0, count + 1)'. + IReadOnlyListAdapterHelpers.EnsureIndexInValidRange(index, list.Count + 1); + + try + { + list.Insert((int)index, value); + } + catch (ArgumentOutOfRangeException e) + { + e.HResult = WellKnownErrorCodes.E_BOUNDS; + + throw; + } + } + + /// + /// Removes the item at the specified index in the vector. + /// + /// The wrapped instance. + /// The zero-based index of the vector item to remove. + /// + public static void RemoveAt(IList list, uint index) + { + IReadOnlyListAdapterHelpers.EnsureIndexInValidRange(index, list.Count); + + try + { + list.RemoveAt((int)index); + } + catch (ArgumentOutOfRangeException e) + { + e.HResult = WellKnownErrorCodes.E_BOUNDS; + + throw; + } + } + + /// + /// Removes the last item from the vector. + /// + /// The wrapped instance. + /// + public static void RemoveAtEnd(IList list) + { + // Manually hoist the count to avoid doing an interface dispatch call twice. + // We don't need to protect against mutation here: the actual list type would + // already either handle this, or the following 'RemoveAt' call might throw. + int count = list.Count; + + // Check that the list isn't empty, as that would of course be invalid + if (count == 0) + { + [DoesNotReturn] + static void ThrowInvalidOperationException() + { + throw new InvalidOperationException("InvalidOperation_CannotRemoveLastFromEmptyCollection") + { + HResult = WellKnownErrorCodes.E_BOUNDS + }; + } + + ThrowInvalidOperationException(); + } + + try + { + list.RemoveAt(count - 1); + } + catch (ArgumentOutOfRangeException e) + { + e.HResult = WellKnownErrorCodes.E_BOUNDS; + + throw; + } + } + + /// + /// Retrieves multiple items from the vector view beginning at the given index. + /// + /// The wrapped instance. + /// The zero-based index to start at. + /// The target to write items into. + /// The number of items that were retrieved. This value can be less than the size of if the end of the list is reached. + /// + public static int GetMany(IList list, uint startIndex, Span items) + { + int count = list.Count; + + // See notes in 'GetMany' for 'IReadOnlyList' + if (startIndex == count) + { + return 0; + } + + IReadOnlyListAdapterHelpers.EnsureIndexInValidRange(startIndex, count); + + // Empty spans are supported, we just stop immediately + if (items.IsEmpty) + { + return 0; + } + + int itemCount = int.Min(items.Length, count - (int)startIndex); + + // Copy all items to the target span + for (int i = 0; i < itemCount; ++i) + { + items[i] = list[i + (int)startIndex]; + } + + return itemCount; + } +} \ No newline at end of file