diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ModelReaderWriterContextDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ModelReaderWriterContextDefinition.cs index a7080cbf090..c0dd4b128db 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ModelReaderWriterContextDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ModelReaderWriterContextDefinition.cs @@ -89,11 +89,12 @@ protected override IReadOnlyList BuildAttributes() // Process all providers from the output library to discover types from methods and properties var providers = ScmCodeModelGenerator.Instance.OutputLibrary.TypeProviders; + var outputLibraryProviders = new HashSet(providers, s_typeProviderNameComparer); // Process each provider recursively foreach (var provider in providers) { - CollectBuildableTypeProvidersRecursive(provider, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes); + CollectBuildableTypeProvidersRecursive(provider, outputLibraryProviders, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes); } return (buildableTypes, buildableProviders); @@ -120,6 +121,7 @@ private void CollectBuildableTypesRecursive( /// private void CollectBuildableTypeProvidersRecursive( TypeProvider currentProvider, + HashSet outputLibraryProviders, HashSet visitedTypes, HashSet visitedTypeProviders, HashSet buildableProviders, @@ -131,8 +133,8 @@ private void CollectBuildableTypeProvidersRecursive( return; } - // Only add to buildableProviders if it implements MRW - if (ImplementsModelReaderWriter(currentProvider)) + // Only add to buildableProviders if it implements MRW and belongs to the output library. + if (outputLibraryProviders.Contains(currentProvider) && ImplementsModelReaderWriter(currentProvider)) { buildableProviders.Add(currentProvider); } @@ -140,12 +142,13 @@ private void CollectBuildableTypeProvidersRecursive( // Process all providers to discover types from methods and properties if (currentProvider is not null) { - CollectBuildableTypesRecursiveCore(currentProvider, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes); + CollectBuildableTypesRecursiveCore(currentProvider, outputLibraryProviders, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes); } } private void CollectBuildableTypesRecursiveCore( TypeProvider provider, + HashSet outputLibraryProviders, HashSet visitedTypes, HashSet visitedTypeProviders, HashSet buildableProviders, @@ -182,8 +185,7 @@ private void CollectBuildableTypesRecursiveCore( if (provider is ModelProvider modelProvider && modelProvider.BaseModelProvider != null) { - // For base model types, we need to process their properties as well, but we don't need to add the base model type itself - CollectBuildableTypesRecursiveCore(modelProvider.BaseModelProvider, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes); + CollectBuildableTypeProvidersRecursive(modelProvider.BaseModelProvider, outputLibraryProviders, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes); } else { diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs index 0d02ecba187..c838feb4750 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs @@ -1114,8 +1114,7 @@ private List BuildDeserializePropertiesStatements(ScopedApi var rawBinaryData = _rawDataField; if (rawBinaryData == null) { - var baseModelProvider = _model.BaseModelProvider; - while (baseModelProvider != null) + foreach (var baseModelProvider in EnumerateBaseModelProviders()) { var field = baseModelProvider.Fields.FirstOrDefault(f => f.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName); if (field != null) @@ -1123,7 +1122,6 @@ private List BuildDeserializePropertiesStatements(ScopedApi rawBinaryData = field; break; } - baseModelProvider = baseModelProvider.BaseModelProvider; } } @@ -1737,8 +1735,7 @@ private MethodBodyStatement[] CreateWritePropertiesStatements(bool isDynamicMode if (isDynamicModelWithNonDynamicBase) { - var baseModelProvider = _model.BaseModelProvider; - while (baseModelProvider != null) + foreach (var baseModelProvider in EnumerateBaseModelProviders()) { foreach (var property in baseModelProvider.CanonicalView.Properties) { @@ -1759,8 +1756,6 @@ private MethodBodyStatement[] CreateWritePropertiesStatements(bool isDynamicMode propertyStatements.Add(CreateWritePropertyStatement(field.WireInfo, field.Type, field.Name, field, field.WireInfo?.SerializationFormat)); } - - baseModelProvider = baseModelProvider.BaseModelProvider; } } @@ -2575,21 +2570,20 @@ private MethodBodyStatement CreateWriteAdditionalPropertiesStatement() PropertyProvider? property = _model.Properties.FirstOrDefault( p => p.BackingField?.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName); // search in the base model if the property is not found in the current model - return property ?? _model.BaseModelProvider?.Properties.FirstOrDefault( - p => p.BackingField?.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName); + return property ?? EnumerateBaseModelProviders() + .SelectMany(m => m.Properties) + .FirstOrDefault(p => p.BackingField?.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName); } private MethodProvider? FindCustomHookMethod(string hookName) { - var model = _model; - while (model != null) + foreach (var model in EnumerateModelAndBaseModelProviders()) { var method = model.CanonicalView.Methods.FirstOrDefault(m => m.Signature.Name == hookName); if (method != null) { return method; } - model = model.BaseModelProvider; } return null; } @@ -2636,9 +2630,8 @@ private List GetSerializationAttributes() List serializationAttributes = _model.CustomCodeView?.Attributes .Where(a => a.Type.Name == CodeGenAttributes.CodeGenSerializationAttributeName) .ToList() ?? []; - var baseModelProvider = _model.BaseModelProvider; - while (baseModelProvider != null) + foreach (var baseModelProvider in EnumerateBaseModelProviders()) { var customCodeView = baseModelProvider.CustomCodeView; if (customCodeView != null) @@ -2647,12 +2640,33 @@ private List GetSerializationAttributes() .AddRange(customCodeView.Attributes .Where(a => a.Type.Name == CodeGenAttributes.CodeGenSerializationAttributeName)); } - baseModelProvider = baseModelProvider.BaseModelProvider; } return serializationAttributes; } + private IEnumerable EnumerateModelAndBaseModelProviders() + { + var visited = new HashSet(); + var model = _model; + while (model != null && visited.Add(model)) + { + yield return model; + model = model.BaseModelProvider; + } + } + + private IEnumerable EnumerateBaseModelProviders() + { + var visited = new HashSet { _model }; + var model = _model.BaseModelProvider; + while (model != null && visited.Add(model)) + { + yield return model; + model = model.BaseModelProvider; + } + } + private static bool TypeRequiresNullCheckInSerialization(CSharpType type) { if (type.IsCollection) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs index c8b4adbb74e..2af6016e335 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs @@ -59,6 +59,7 @@ protected override FormattableString BuildDescription() private readonly Type _additionalPropsUnknownType = typeof(BinaryData); private Lazy _useObjectAdditionalProperties; private FieldProvider? _rawDataField; + private bool _buildingRawDataField; private List? _additionalPropertyFields; private List? _additionalPropertyProperties; private ModelProvider? _baseModelProvider; @@ -178,7 +179,31 @@ public override void Reset() _fullConstructor = null; } - protected FieldProvider? RawDataField => _rawDataField ??= BuildRawDataField(); + protected FieldProvider? RawDataField + { + get + { + if (_rawDataField is not null) + { + return _rawDataField; + } + + if (_buildingRawDataField) + { + return null; + } + + _buildingRawDataField = true; + try + { + return _rawDataField = BuildRawDataField(); + } + finally + { + _buildingRawDataField = false; + } + } + } protected virtual bool ShouldSkipDerivedModelProperties => false; /// /// Gets whether derived models should skip overriding serialization methods from this base model. @@ -648,8 +673,9 @@ private IEnumerable EnumerateBaseModels() private IEnumerable EnumerateBaseModelProviders() { + HashSet visited = [this]; var model = BaseModelProvider; - while (model != null) + while (model != null && visited.Add(model)) { yield return model; model = model.BaseModelProvider; @@ -859,9 +885,8 @@ private bool ParametersMatch(IReadOnlyList params1, IReadOnly private IEnumerable GetAllBasePropertiesForConstructorInitialization(bool includeAllHierarchyDiscriminator = false) { var properties = new Stack>(); - var modelProvider = BaseModelProvider; bool isDirectBase = true; - while (modelProvider != null) + foreach (var modelProvider in EnumerateBaseModelProviders()) { properties.Push([]); foreach (var property in modelProvider.CanonicalView.Properties) @@ -880,7 +905,6 @@ private IEnumerable GetAllBasePropertiesForConstructorInitiali } } - modelProvider = modelProvider.BaseModelProvider; isDirectBase = false; } @@ -891,15 +915,13 @@ private IEnumerable GetAllBasePropertiesForConstructorInitiali private IEnumerable GetAllBaseFieldsForConstructorInitialization() { var fields = new Stack>(); - var modelProvider = BaseModelProvider; - while (modelProvider != null) + foreach (var modelProvider in EnumerateBaseModelProviders()) { fields.Push([]); foreach (var field in modelProvider.CanonicalView.Fields) { fields.Peek().Add(field); } - modelProvider = modelProvider.BaseModelProvider; } return fields.SelectMany(l => l); @@ -918,7 +940,7 @@ private IEnumerable GetAllBaseFieldsForConstructorInitialization( baseProperties = GetAllBasePropertiesForConstructorInitialization(includeDiscriminatorParameter); baseFields = GetAllBaseFieldsForConstructorInitialization(); } - else if (BaseModelProvider?.FullConstructor.Signature != null) + else if (BaseModelProvider is not null && !HasBaseModelProviderCycle()) { baseParameters.AddRange(BaseModelProvider.FullConstructor.Signature.Parameters); } @@ -1003,6 +1025,23 @@ p.Property is null return (constructorParameters, constructorInitializer); } + private bool HasBaseModelProviderCycle() + { + HashSet visited = [this]; + var modelProvider = BaseModelProvider; + while (modelProvider != null) + { + if (!visited.Add(modelProvider)) + { + return true; + } + + modelProvider = modelProvider.BaseModelProvider; + } + + return false; + } + private ValueExpression? EnsureDiscriminatorValueExpression() { if (_inputModel.BaseModel is not null && _inputModel.DiscriminatorValue is not null) @@ -1300,14 +1339,12 @@ private static ValueExpression GetConversion(PropertyProvider? property = defaul } // check if there is a raw data field on any of the base models, if so, we do not have to have one here. - var baseModelProvider = BaseModelProvider; - while (baseModelProvider != null) + foreach (var baseModelProvider in EnumerateBaseModelProviders()) { if (baseModelProvider.RawDataField != null) { return null; } - baseModelProvider = baseModelProvider.BaseModelProvider; } var modifiers = FieldModifiers.Private;