Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,12 @@ protected override IReadOnlyList<MethodBodyStatement> BuildAttributes()

// Process all providers from the output library to discover types from methods and properties
var providers = ScmCodeModelGenerator.Instance.OutputLibrary.TypeProviders;
var outputLibraryProviders = new HashSet<TypeProvider>(providers, s_typeProviderNameComparer);

// Process each provider recursively
foreach (var provider in providers)
{
CollectBuildableTypeProvidersRecursive(provider, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes);
CollectBuildableTypeProvidersRecursive(provider, outputLibraryProviders, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes);
}

return (buildableTypes, buildableProviders);
Expand All @@ -120,6 +121,7 @@ private void CollectBuildableTypesRecursive(
/// </summary>
private void CollectBuildableTypeProvidersRecursive(
TypeProvider currentProvider,
HashSet<TypeProvider> outputLibraryProviders,
HashSet<CSharpType> visitedTypes,
HashSet<TypeProvider> visitedTypeProviders,
HashSet<TypeProvider> buildableProviders,
Expand All @@ -131,21 +133,22 @@ private void CollectBuildableTypeProvidersRecursive(
return;
}

// Only add to buildableProviders if it implements MRW
if (ImplementsModelReaderWriter(currentProvider))
// Only add to buildableProviders if it implements MRW and belongs to the output library.
if (outputLibraryProviders.Contains(currentProvider) && ImplementsModelReaderWriter(currentProvider))
{
buildableProviders.Add(currentProvider);
}

// Process all providers to discover types from methods and properties
if (currentProvider is not null)
{
CollectBuildableTypesRecursiveCore(currentProvider, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes);
CollectBuildableTypesRecursiveCore(currentProvider, outputLibraryProviders, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes);
}
}

private void CollectBuildableTypesRecursiveCore(
TypeProvider provider,
HashSet<TypeProvider> outputLibraryProviders,
HashSet<CSharpType> visitedTypes,
HashSet<TypeProvider> visitedTypeProviders,
HashSet<TypeProvider> buildableProviders,
Expand Down Expand Up @@ -182,8 +185,7 @@ private void CollectBuildableTypesRecursiveCore(

if (provider is ModelProvider modelProvider && modelProvider.BaseModelProvider != null)
{
// For base model types, we need to process their properties as well, but we don't need to add the base model type itself
CollectBuildableTypesRecursiveCore(modelProvider.BaseModelProvider, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes);
CollectBuildableTypeProvidersRecursive(modelProvider.BaseModelProvider, outputLibraryProviders, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1114,16 +1114,14 @@ private List<MethodBodyStatement> BuildDeserializePropertiesStatements(ScopedApi
var rawBinaryData = _rawDataField;
if (rawBinaryData == null)
{
var baseModelProvider = _model.BaseModelProvider;
while (baseModelProvider != null)
foreach (var baseModelProvider in EnumerateBaseModelProviders())
{
var field = baseModelProvider.Fields.FirstOrDefault(f => f.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName);
if (field != null)
{
rawBinaryData = field;
break;
}
baseModelProvider = baseModelProvider.BaseModelProvider;
}
}

Expand Down Expand Up @@ -1737,8 +1735,7 @@ private MethodBodyStatement[] CreateWritePropertiesStatements(bool isDynamicMode

if (isDynamicModelWithNonDynamicBase)
{
var baseModelProvider = _model.BaseModelProvider;
while (baseModelProvider != null)
foreach (var baseModelProvider in EnumerateBaseModelProviders())
{
foreach (var property in baseModelProvider.CanonicalView.Properties)
{
Expand All @@ -1759,8 +1756,6 @@ private MethodBodyStatement[] CreateWritePropertiesStatements(bool isDynamicMode

propertyStatements.Add(CreateWritePropertyStatement(field.WireInfo, field.Type, field.Name, field, field.WireInfo?.SerializationFormat));
}

baseModelProvider = baseModelProvider.BaseModelProvider;
}
}

Expand Down Expand Up @@ -2575,21 +2570,20 @@ private MethodBodyStatement CreateWriteAdditionalPropertiesStatement()
PropertyProvider? property = _model.Properties.FirstOrDefault(
p => p.BackingField?.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName);
// search in the base model if the property is not found in the current model
return property ?? _model.BaseModelProvider?.Properties.FirstOrDefault(
p => p.BackingField?.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName);
return property ?? EnumerateBaseModelProviders()
.SelectMany(m => m.Properties)
.FirstOrDefault(p => p.BackingField?.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName);
}

private MethodProvider? FindCustomHookMethod(string hookName)
{
var model = _model;
while (model != null)
foreach (var model in EnumerateModelAndBaseModelProviders())
{
var method = model.CanonicalView.Methods.FirstOrDefault(m => m.Signature.Name == hookName);
if (method != null)
{
return method;
}
model = model.BaseModelProvider;
}
return null;
}
Expand Down Expand Up @@ -2636,9 +2630,8 @@ private List<AttributeStatement> GetSerializationAttributes()
List<AttributeStatement> serializationAttributes = _model.CustomCodeView?.Attributes
.Where(a => a.Type.Name == CodeGenAttributes.CodeGenSerializationAttributeName)
.ToList() ?? [];
var baseModelProvider = _model.BaseModelProvider;

while (baseModelProvider != null)
foreach (var baseModelProvider in EnumerateBaseModelProviders())
{
var customCodeView = baseModelProvider.CustomCodeView;
if (customCodeView != null)
Expand All @@ -2647,12 +2640,33 @@ private List<AttributeStatement> GetSerializationAttributes()
.AddRange(customCodeView.Attributes
.Where(a => a.Type.Name == CodeGenAttributes.CodeGenSerializationAttributeName));
}
baseModelProvider = baseModelProvider.BaseModelProvider;
}

return serializationAttributes;
}

private IEnumerable<ModelProvider> EnumerateModelAndBaseModelProviders()
{
var visited = new HashSet<ModelProvider>();
var model = _model;
while (model != null && visited.Add(model))
{
yield return model;
model = model.BaseModelProvider;
}
}

private IEnumerable<ModelProvider> EnumerateBaseModelProviders()
{
var visited = new HashSet<ModelProvider> { _model };
var model = _model.BaseModelProvider;
while (model != null && visited.Add(model))
{
yield return model;
model = model.BaseModelProvider;
}
}

private static bool TypeRequiresNullCheckInSerialization(CSharpType type)
{
if (type.IsCollection)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ protected override FormattableString BuildDescription()
private readonly Type _additionalPropsUnknownType = typeof(BinaryData);
private Lazy<bool> _useObjectAdditionalProperties;
private FieldProvider? _rawDataField;
private bool _buildingRawDataField;
private List<FieldProvider>? _additionalPropertyFields;
private List<PropertyProvider>? _additionalPropertyProperties;
private ModelProvider? _baseModelProvider;
Expand Down Expand Up @@ -178,7 +179,31 @@ public override void Reset()
_fullConstructor = null;
}

protected FieldProvider? RawDataField => _rawDataField ??= BuildRawDataField();
protected FieldProvider? RawDataField
{
get
{
if (_rawDataField is not null)
{
return _rawDataField;
}

if (_buildingRawDataField)
{
return null;
}

_buildingRawDataField = true;
try
{
return _rawDataField = BuildRawDataField();
}
finally
{
_buildingRawDataField = false;
}
}
}
protected virtual bool ShouldSkipDerivedModelProperties => false;
/// <summary>
/// Gets whether derived models should skip overriding serialization methods from this base model.
Expand Down Expand Up @@ -648,8 +673,9 @@ private IEnumerable<InputModelType> EnumerateBaseModels()

private IEnumerable<ModelProvider> EnumerateBaseModelProviders()
{
HashSet<ModelProvider> visited = [this];
var model = BaseModelProvider;
while (model != null)
while (model != null && visited.Add(model))
{
yield return model;
model = model.BaseModelProvider;
Expand Down Expand Up @@ -859,9 +885,8 @@ private bool ParametersMatch(IReadOnlyList<ParameterProvider> params1, IReadOnly
private IEnumerable<PropertyProvider> GetAllBasePropertiesForConstructorInitialization(bool includeAllHierarchyDiscriminator = false)
{
var properties = new Stack<List<PropertyProvider>>();
var modelProvider = BaseModelProvider;
bool isDirectBase = true;
while (modelProvider != null)
foreach (var modelProvider in EnumerateBaseModelProviders())
{
properties.Push([]);
foreach (var property in modelProvider.CanonicalView.Properties)
Expand All @@ -880,7 +905,6 @@ private IEnumerable<PropertyProvider> GetAllBasePropertiesForConstructorInitiali
}
}

modelProvider = modelProvider.BaseModelProvider;
isDirectBase = false;
}

Expand All @@ -891,15 +915,13 @@ private IEnumerable<PropertyProvider> GetAllBasePropertiesForConstructorInitiali
private IEnumerable<FieldProvider> GetAllBaseFieldsForConstructorInitialization()
{
var fields = new Stack<List<FieldProvider>>();
var modelProvider = BaseModelProvider;
while (modelProvider != null)
foreach (var modelProvider in EnumerateBaseModelProviders())
{
fields.Push([]);
foreach (var field in modelProvider.CanonicalView.Fields)
{
fields.Peek().Add(field);
}
modelProvider = modelProvider.BaseModelProvider;
}

return fields.SelectMany(l => l);
Expand All @@ -918,7 +940,7 @@ private IEnumerable<FieldProvider> GetAllBaseFieldsForConstructorInitialization(
baseProperties = GetAllBasePropertiesForConstructorInitialization(includeDiscriminatorParameter);
baseFields = GetAllBaseFieldsForConstructorInitialization();
}
else if (BaseModelProvider?.FullConstructor.Signature != null)
else if (BaseModelProvider is not null && !HasBaseModelProviderCycle())
{
baseParameters.AddRange(BaseModelProvider.FullConstructor.Signature.Parameters);
}
Expand Down Expand Up @@ -1003,6 +1025,23 @@ p.Property is null
return (constructorParameters, constructorInitializer);
}

private bool HasBaseModelProviderCycle()
{
HashSet<ModelProvider> visited = [this];
var modelProvider = BaseModelProvider;
while (modelProvider != null)
{
if (!visited.Add(modelProvider))
{
return true;
}

modelProvider = modelProvider.BaseModelProvider;
}

return false;
}

private ValueExpression? EnsureDiscriminatorValueExpression()
{
if (_inputModel.BaseModel is not null && _inputModel.DiscriminatorValue is not null)
Expand Down Expand Up @@ -1300,14 +1339,12 @@ private static ValueExpression GetConversion(PropertyProvider? property = defaul
}

// check if there is a raw data field on any of the base models, if so, we do not have to have one here.
var baseModelProvider = BaseModelProvider;
while (baseModelProvider != null)
foreach (var baseModelProvider in EnumerateBaseModelProviders())
{
if (baseModelProvider.RawDataField != null)
{
return null;
}
baseModelProvider = baseModelProvider.BaseModelProvider;
}

var modifiers = FieldModifiers.Private;
Expand Down
Loading