From dbe34996fff2e1a19f2391fd1f0576c7d0ffc960 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Wed, 17 Jun 2026 09:30:05 +0000 Subject: [PATCH 1/2] Fix C# generator base model recursion Make base model provider traversal cycle-safe across model construction, MRW serialization, and ModelReaderWriter context collection. Also emit explicit response conversion hiding at the ClientModel layer when a base model already exposes the same conversion. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../ModelReaderWriterContextDefinition.cs | 5 +- .../MrwSerializationTypeDefinition.cs | 67 ++++++++++++++----- .../src/Providers/ModelProvider.cs | 61 +++++++++++++---- 3 files changed, 105 insertions(+), 28 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ModelReaderWriterContextDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ModelReaderWriterContextDefinition.cs index a7080cbf090..c5ee59fe002 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ModelReaderWriterContextDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ModelReaderWriterContextDefinition.cs @@ -183,7 +183,10 @@ private void CollectBuildableTypesRecursiveCore( if (provider is ModelProvider modelProvider && modelProvider.BaseModelProvider != null) { // For base model types, we need to process their properties as well, but we don't need to add the base model type itself - CollectBuildableTypesRecursiveCore(modelProvider.BaseModelProvider, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes); + if (visitedTypeProviders.Add(modelProvider.BaseModelProvider)) + { + CollectBuildableTypesRecursiveCore(modelProvider.BaseModelProvider, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes); + } } else { diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs index 0d02ecba187..f838e429e6b 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs @@ -291,6 +291,10 @@ private MethodProvider BuildExplicitFromClientResult() ScmCodeModelGenerator.Instance.TypeFactory.ClientResponseApi.ClientResponseType); var modifiers = MethodSignatureModifiers.Public | MethodSignatureModifiers.Static | MethodSignatureModifiers.Explicit | MethodSignatureModifiers.Operator; + if (HasBaseExplicitFromClientResult()) + { + modifiers |= MethodSignatureModifiers.New; + } // using PipelineResponse response = result.GetRawResponse(); var response = result.ToApi(); MethodBodyStatement responseDeclaration; @@ -350,6 +354,25 @@ private MethodProvider BuildImplicitToBinaryContent() this); } + private bool HasBaseExplicitFromClientResult() + => EnumerateBaseModelProviders() + .SelectMany(provider => provider.SerializationProviders) + .OfType() + .SelectMany(provider => provider.Methods) + .Any(IsExplicitFromClientResultMethod); + + private static bool IsExplicitFromClientResultMethod(MethodProvider method) + { + if (method.Signature.Parameters is not [var parameter] + || !parameter.Type.AreNamesEqual(ScmCodeModelGenerator.Instance.TypeFactory.ClientResponseApi.ClientResponseType)) + { + return false; + } + + return method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Operator) + && method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Explicit); + } + private MethodProvider BuildToBinaryContentMethod() { var formatParameter = new ParameterProvider("format", $"The format to use for serialization", typeof(string)); @@ -1114,8 +1137,7 @@ private List BuildDeserializePropertiesStatements(ScopedApi var rawBinaryData = _rawDataField; if (rawBinaryData == null) { - var baseModelProvider = _model.BaseModelProvider; - while (baseModelProvider != null) + foreach (var baseModelProvider in EnumerateBaseModelProviders()) { var field = baseModelProvider.Fields.FirstOrDefault(f => f.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName); if (field != null) @@ -1123,7 +1145,6 @@ private List BuildDeserializePropertiesStatements(ScopedApi rawBinaryData = field; break; } - baseModelProvider = baseModelProvider.BaseModelProvider; } } @@ -1737,8 +1758,7 @@ private MethodBodyStatement[] CreateWritePropertiesStatements(bool isDynamicMode if (isDynamicModelWithNonDynamicBase) { - var baseModelProvider = _model.BaseModelProvider; - while (baseModelProvider != null) + foreach (var baseModelProvider in EnumerateBaseModelProviders()) { foreach (var property in baseModelProvider.CanonicalView.Properties) { @@ -1759,8 +1779,6 @@ private MethodBodyStatement[] CreateWritePropertiesStatements(bool isDynamicMode propertyStatements.Add(CreateWritePropertyStatement(field.WireInfo, field.Type, field.Name, field, field.WireInfo?.SerializationFormat)); } - - baseModelProvider = baseModelProvider.BaseModelProvider; } } @@ -2575,21 +2593,20 @@ private MethodBodyStatement CreateWriteAdditionalPropertiesStatement() PropertyProvider? property = _model.Properties.FirstOrDefault( p => p.BackingField?.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName); // search in the base model if the property is not found in the current model - return property ?? _model.BaseModelProvider?.Properties.FirstOrDefault( - p => p.BackingField?.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName); + return property ?? EnumerateBaseModelProviders() + .SelectMany(m => m.Properties) + .FirstOrDefault(p => p.BackingField?.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName); } private MethodProvider? FindCustomHookMethod(string hookName) { - var model = _model; - while (model != null) + foreach (var model in EnumerateModelAndBaseModelProviders()) { var method = model.CanonicalView.Methods.FirstOrDefault(m => m.Signature.Name == hookName); if (method != null) { return method; } - model = model.BaseModelProvider; } return null; } @@ -2636,9 +2653,8 @@ private List GetSerializationAttributes() List serializationAttributes = _model.CustomCodeView?.Attributes .Where(a => a.Type.Name == CodeGenAttributes.CodeGenSerializationAttributeName) .ToList() ?? []; - var baseModelProvider = _model.BaseModelProvider; - while (baseModelProvider != null) + foreach (var baseModelProvider in EnumerateBaseModelProviders()) { var customCodeView = baseModelProvider.CustomCodeView; if (customCodeView != null) @@ -2647,12 +2663,33 @@ private List GetSerializationAttributes() .AddRange(customCodeView.Attributes .Where(a => a.Type.Name == CodeGenAttributes.CodeGenSerializationAttributeName)); } - baseModelProvider = baseModelProvider.BaseModelProvider; } return serializationAttributes; } + private IEnumerable EnumerateModelAndBaseModelProviders() + { + var visited = new HashSet(); + var model = _model; + while (model != null && visited.Add(model)) + { + yield return model; + model = model.BaseModelProvider; + } + } + + private IEnumerable EnumerateBaseModelProviders() + { + var visited = new HashSet { _model }; + var model = _model.BaseModelProvider; + while (model != null && visited.Add(model)) + { + yield return model; + model = model.BaseModelProvider; + } + } + private static bool TypeRequiresNullCheckInSerialization(CSharpType type) { if (type.IsCollection) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs index c8b4adbb74e..2af6016e335 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs @@ -59,6 +59,7 @@ protected override FormattableString BuildDescription() private readonly Type _additionalPropsUnknownType = typeof(BinaryData); private Lazy _useObjectAdditionalProperties; private FieldProvider? _rawDataField; + private bool _buildingRawDataField; private List? _additionalPropertyFields; private List? _additionalPropertyProperties; private ModelProvider? _baseModelProvider; @@ -178,7 +179,31 @@ public override void Reset() _fullConstructor = null; } - protected FieldProvider? RawDataField => _rawDataField ??= BuildRawDataField(); + protected FieldProvider? RawDataField + { + get + { + if (_rawDataField is not null) + { + return _rawDataField; + } + + if (_buildingRawDataField) + { + return null; + } + + _buildingRawDataField = true; + try + { + return _rawDataField = BuildRawDataField(); + } + finally + { + _buildingRawDataField = false; + } + } + } protected virtual bool ShouldSkipDerivedModelProperties => false; /// /// Gets whether derived models should skip overriding serialization methods from this base model. @@ -648,8 +673,9 @@ private IEnumerable EnumerateBaseModels() private IEnumerable EnumerateBaseModelProviders() { + HashSet visited = [this]; var model = BaseModelProvider; - while (model != null) + while (model != null && visited.Add(model)) { yield return model; model = model.BaseModelProvider; @@ -859,9 +885,8 @@ private bool ParametersMatch(IReadOnlyList params1, IReadOnly private IEnumerable GetAllBasePropertiesForConstructorInitialization(bool includeAllHierarchyDiscriminator = false) { var properties = new Stack>(); - var modelProvider = BaseModelProvider; bool isDirectBase = true; - while (modelProvider != null) + foreach (var modelProvider in EnumerateBaseModelProviders()) { properties.Push([]); foreach (var property in modelProvider.CanonicalView.Properties) @@ -880,7 +905,6 @@ private IEnumerable GetAllBasePropertiesForConstructorInitiali } } - modelProvider = modelProvider.BaseModelProvider; isDirectBase = false; } @@ -891,15 +915,13 @@ private IEnumerable GetAllBasePropertiesForConstructorInitiali private IEnumerable GetAllBaseFieldsForConstructorInitialization() { var fields = new Stack>(); - var modelProvider = BaseModelProvider; - while (modelProvider != null) + foreach (var modelProvider in EnumerateBaseModelProviders()) { fields.Push([]); foreach (var field in modelProvider.CanonicalView.Fields) { fields.Peek().Add(field); } - modelProvider = modelProvider.BaseModelProvider; } return fields.SelectMany(l => l); @@ -918,7 +940,7 @@ private IEnumerable GetAllBaseFieldsForConstructorInitialization( baseProperties = GetAllBasePropertiesForConstructorInitialization(includeDiscriminatorParameter); baseFields = GetAllBaseFieldsForConstructorInitialization(); } - else if (BaseModelProvider?.FullConstructor.Signature != null) + else if (BaseModelProvider is not null && !HasBaseModelProviderCycle()) { baseParameters.AddRange(BaseModelProvider.FullConstructor.Signature.Parameters); } @@ -1003,6 +1025,23 @@ p.Property is null return (constructorParameters, constructorInitializer); } + private bool HasBaseModelProviderCycle() + { + HashSet visited = [this]; + var modelProvider = BaseModelProvider; + while (modelProvider != null) + { + if (!visited.Add(modelProvider)) + { + return true; + } + + modelProvider = modelProvider.BaseModelProvider; + } + + return false; + } + private ValueExpression? EnsureDiscriminatorValueExpression() { if (_inputModel.BaseModel is not null && _inputModel.DiscriminatorValue is not null) @@ -1300,14 +1339,12 @@ private static ValueExpression GetConversion(PropertyProvider? property = defaul } // check if there is a raw data field on any of the base models, if so, we do not have to have one here. - var baseModelProvider = BaseModelProvider; - while (baseModelProvider != null) + foreach (var baseModelProvider in EnumerateBaseModelProviders()) { if (baseModelProvider.RawDataField != null) { return null; } - baseModelProvider = baseModelProvider.BaseModelProvider; } var modifiers = FieldModifiers.Private; From f5751392b9840a816021a32832a0930f7759fb11 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Wed, 17 Jun 2026 11:12:37 +0000 Subject: [PATCH 2/2] Fix model reader writer recursion handling --- .../ModelReaderWriterContextDefinition.cs | 17 +++++++------- .../MrwSerializationTypeDefinition.cs | 23 ------------------- 2 files changed, 8 insertions(+), 32 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ModelReaderWriterContextDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ModelReaderWriterContextDefinition.cs index c5ee59fe002..c0dd4b128db 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ModelReaderWriterContextDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ModelReaderWriterContextDefinition.cs @@ -89,11 +89,12 @@ protected override IReadOnlyList BuildAttributes() // Process all providers from the output library to discover types from methods and properties var providers = ScmCodeModelGenerator.Instance.OutputLibrary.TypeProviders; + var outputLibraryProviders = new HashSet(providers, s_typeProviderNameComparer); // Process each provider recursively foreach (var provider in providers) { - CollectBuildableTypeProvidersRecursive(provider, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes); + CollectBuildableTypeProvidersRecursive(provider, outputLibraryProviders, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes); } return (buildableTypes, buildableProviders); @@ -120,6 +121,7 @@ private void CollectBuildableTypesRecursive( /// private void CollectBuildableTypeProvidersRecursive( TypeProvider currentProvider, + HashSet outputLibraryProviders, HashSet visitedTypes, HashSet visitedTypeProviders, HashSet buildableProviders, @@ -131,8 +133,8 @@ private void CollectBuildableTypeProvidersRecursive( return; } - // Only add to buildableProviders if it implements MRW - if (ImplementsModelReaderWriter(currentProvider)) + // Only add to buildableProviders if it implements MRW and belongs to the output library. + if (outputLibraryProviders.Contains(currentProvider) && ImplementsModelReaderWriter(currentProvider)) { buildableProviders.Add(currentProvider); } @@ -140,12 +142,13 @@ private void CollectBuildableTypeProvidersRecursive( // Process all providers to discover types from methods and properties if (currentProvider is not null) { - CollectBuildableTypesRecursiveCore(currentProvider, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes); + CollectBuildableTypesRecursiveCore(currentProvider, outputLibraryProviders, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes); } } private void CollectBuildableTypesRecursiveCore( TypeProvider provider, + HashSet outputLibraryProviders, HashSet visitedTypes, HashSet visitedTypeProviders, HashSet buildableProviders, @@ -182,11 +185,7 @@ private void CollectBuildableTypesRecursiveCore( if (provider is ModelProvider modelProvider && modelProvider.BaseModelProvider != null) { - // For base model types, we need to process their properties as well, but we don't need to add the base model type itself - if (visitedTypeProviders.Add(modelProvider.BaseModelProvider)) - { - CollectBuildableTypesRecursiveCore(modelProvider.BaseModelProvider, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes); - } + CollectBuildableTypeProvidersRecursive(modelProvider.BaseModelProvider, outputLibraryProviders, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes); } else { diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs index f838e429e6b..c838feb4750 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs @@ -291,10 +291,6 @@ private MethodProvider BuildExplicitFromClientResult() ScmCodeModelGenerator.Instance.TypeFactory.ClientResponseApi.ClientResponseType); var modifiers = MethodSignatureModifiers.Public | MethodSignatureModifiers.Static | MethodSignatureModifiers.Explicit | MethodSignatureModifiers.Operator; - if (HasBaseExplicitFromClientResult()) - { - modifiers |= MethodSignatureModifiers.New; - } // using PipelineResponse response = result.GetRawResponse(); var response = result.ToApi(); MethodBodyStatement responseDeclaration; @@ -354,25 +350,6 @@ private MethodProvider BuildImplicitToBinaryContent() this); } - private bool HasBaseExplicitFromClientResult() - => EnumerateBaseModelProviders() - .SelectMany(provider => provider.SerializationProviders) - .OfType() - .SelectMany(provider => provider.Methods) - .Any(IsExplicitFromClientResultMethod); - - private static bool IsExplicitFromClientResultMethod(MethodProvider method) - { - if (method.Signature.Parameters is not [var parameter] - || !parameter.Type.AreNamesEqual(ScmCodeModelGenerator.Instance.TypeFactory.ClientResponseApi.ClientResponseType)) - { - return false; - } - - return method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Operator) - && method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Explicit); - } - private MethodProvider BuildToBinaryContentMethod() { var formatParameter = new ParameterProvider("format", $"The format to use for serialization", typeof(string));