Skip to content

Commit 6358607

Browse files
authored
Merge pull request #29 from AVSystem/feat/nested-sealed-hierarchies
feat: nested sealed hierarchies for oneOf/anyOf subtypes
2 parents 34d4621 + 268c7d2 commit 6358607

19 files changed

Lines changed: 772 additions & 284 deletions

File tree

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package com.avsystem.justworks.core
2+
3+
import arrow.atomic.Atomic
4+
import arrow.atomic.update
5+
6+
@JvmInline
7+
value class MemoScope private constructor(private val memos: Atomic<Set<Memo<*>>>) {
8+
constructor() : this(Atomic(emptySet()))
9+
10+
fun register(memo: Memo<*>) {
11+
memos.update { it + memo }
12+
}
13+
14+
fun reset() {
15+
memos.get().forEach { it.reset() }
16+
}
17+
}
18+
19+
class Memo<T>(private val compute: () -> T) {
20+
private val holder = Atomic(lazy(compute))
21+
22+
operator fun getValue(thisRef: Any?, property: Any?): T = holder.get().value
23+
24+
fun reset() {
25+
holder.set(lazy(compute))
26+
}
27+
}
28+
29+
fun <T> memoized(memoScope: MemoScope, compute: () -> T): Memo<T> = Memo(compute).also(memoScope::register)

core/src/main/kotlin/com/avsystem/justworks/core/Utils.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,6 @@ package com.avsystem.justworks.core
22

33
import kotlin.enums.enumEntries
44

5+
internal const val SCHEMA_PREFIX = "#/components/schemas/"
6+
57
inline fun <reified T : Enum<T>> String.toEnumOrNull(): T? = enumEntries<T>().find { it.name.equals(this, true) }

core/src/main/kotlin/com/avsystem/justworks/core/gen/CodeGenerator.kt

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,20 @@ object CodeGenerator {
1919
modelPackage: String,
2020
apiPackage: String,
2121
outputDir: File,
22-
): Result = context(ModelPackage(modelPackage), ApiPackage(apiPackage)) {
23-
val modelRegistry = NameRegistry()
24-
val apiRegistry = NameRegistry()
22+
): Result {
23+
val hierarchy = Hierarchy(ModelPackage(modelPackage)).apply { addSchemas(spec.schemas) }
2524

26-
val (modelFiles, resolvedSpec) = ModelGenerator.generateWithResolvedSpec(spec, modelRegistry)
25+
val (modelFiles, resolvedSpec) = context(hierarchy, NameRegistry()) {
26+
ModelGenerator.generateWithResolvedSpec(spec)
27+
}
2728

2829
modelFiles.forEach { it.writeTo(outputDir) }
2930

3031
val hasPolymorphicTypes = modelFiles.any { it.name == SERIALIZERS_MODULE.simpleName }
3132

32-
val clientFiles = ClientGenerator.generate(resolvedSpec, hasPolymorphicTypes, apiRegistry)
33+
val clientFiles = context(hierarchy, ApiPackage(apiPackage), NameRegistry()) {
34+
ClientGenerator.generate(resolvedSpec, hasPolymorphicTypes)
35+
}
3336

3437
clientFiles.forEach { it.writeTo(outputDir) }
3538

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package com.avsystem.justworks.core.gen
2+
3+
import com.avsystem.justworks.core.MemoScope
4+
import com.avsystem.justworks.core.memoized
5+
import com.avsystem.justworks.core.model.SchemaModel
6+
import com.avsystem.justworks.core.model.TypeRef
7+
import com.squareup.kotlinpoet.ClassName
8+
9+
internal class Hierarchy(val modelPackage: ModelPackage) {
10+
private val schemas = mutableSetOf<SchemaModel>()
11+
12+
private val memoScope = MemoScope()
13+
14+
/**
15+
* Updates the underlying schemas and invalidates all cached derived views.
16+
* This is necessary when schemas are updated (e.g., after inlining types).
17+
*/
18+
fun addSchemas(newSchemas: List<SchemaModel>) {
19+
schemas += newSchemas
20+
memoScope.reset()
21+
}
22+
23+
/** All schemas indexed by name for quick lookup. */
24+
val schemasById: Map<String, SchemaModel> by memoized(memoScope) {
25+
schemas.associateBy { it.name }
26+
}
27+
28+
/** Schemas that define polymorphic variants via oneOf or anyOf. */
29+
private val polymorphicSchemas: List<SchemaModel> by memoized(memoScope) {
30+
schemas.filterNot { it.variants().isNullOrEmpty() }
31+
}
32+
33+
/** Maps parent schema name to its variant schema names (for both oneOf and anyOf). */
34+
val sealedHierarchies: Map<String, List<String>> by memoized(memoScope) {
35+
polymorphicSchemas
36+
.associate { schema ->
37+
schema.name to schema
38+
.variants()
39+
?.filterIsInstance<TypeRef.Reference>()
40+
?.map { it.schemaName }
41+
.orEmpty()
42+
}
43+
}
44+
45+
/** Parent schema names that use anyOf without a discriminator (JsonContentPolymorphicSerializer pattern). */
46+
val anyOfWithoutDiscriminator: Set<String> by memoized(memoScope) {
47+
polymorphicSchemas
48+
.asSequence()
49+
.filter { !it.anyOf.isNullOrEmpty() && it.discriminator == null }
50+
.map { it.name }
51+
.toSet()
52+
}
53+
54+
/** Inverse of [sealedHierarchies] for anyOf-without-discriminator: variant name to its parent names. */
55+
val anyOfParents: Map<String, Set<String>> by memoized(memoScope) {
56+
sealedHierarchies
57+
.asSequence()
58+
.filter { (parent, _) -> parent in anyOfWithoutDiscriminator }
59+
.flatMap { (parent, variants) -> variants.map { it to parent } }
60+
.groupBy({ it.first }, { it.second })
61+
.mapValues { (_, parents) -> parents.toSet() }
62+
}
63+
64+
/** Maps schema name to its [ClassName], using nested class for discriminated hierarchy variants. */
65+
private val lookup: Map<String, ClassName> by memoized(memoScope) {
66+
sealedHierarchies
67+
.asSequence()
68+
.filterNot { (parent, _) -> parent in anyOfWithoutDiscriminator }
69+
.flatMap { (parent, variants) ->
70+
val parentClass = ClassName(modelPackage, parent)
71+
variants.map { variant -> variant to parentClass.nestedClass(variant.toPascalCase()) } +
72+
(parent to parentClass)
73+
}.toMap()
74+
}
75+
76+
/** Resolves a schema name to its [ClassName], falling back to a flat top-level class. */
77+
operator fun get(name: String): ClassName = lookup[name] ?: ClassName(modelPackage, name)
78+
}
79+
80+
private fun SchemaModel.variants() = oneOf ?: anyOf

core/src/main/kotlin/com/avsystem/justworks/core/gen/Utils.kt

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
package com.avsystem.justworks.core.gen
22

3+
import com.avsystem.justworks.core.SCHEMA_PREFIX
34
import com.avsystem.justworks.core.model.PrimitiveType
45
import com.avsystem.justworks.core.model.PropertyModel
6+
import com.avsystem.justworks.core.model.SchemaModel
57
import com.avsystem.justworks.core.model.TypeRef
68
import com.squareup.kotlinpoet.BOOLEAN
79
import com.squareup.kotlinpoet.BYTE_ARRAY
8-
import com.squareup.kotlinpoet.ClassName
910
import com.squareup.kotlinpoet.DOUBLE
1011
import com.squareup.kotlinpoet.FLOAT
1112
import com.squareup.kotlinpoet.INT
@@ -28,7 +29,7 @@ internal val TypeRef.requiredProperties: Set<String>
2829
is TypeRef.Array, is TypeRef.Map, is TypeRef.Primitive, is TypeRef.Reference, TypeRef.Unknown -> emptySet()
2930
}
3031

31-
context(modelPackage: ModelPackage)
32+
context(hierarchy: Hierarchy)
3233
internal fun TypeRef.toTypeName(): TypeName = when (this) {
3334
is TypeRef.Primitive -> {
3435
when (type) {
@@ -54,7 +55,7 @@ internal fun TypeRef.toTypeName(): TypeName = when (this) {
5455
}
5556

5657
is TypeRef.Reference -> {
57-
ClassName(modelPackage, schemaName)
58+
hierarchy[schemaName]
5859
}
5960

6061
is TypeRef.Inline -> {
@@ -67,3 +68,13 @@ internal fun TypeRef.toTypeName(): TypeName = when (this) {
6768
}
6869

6970
internal fun TypeRef.isBinaryUpload(): Boolean = this is TypeRef.Primitive && this.type == PrimitiveType.BYTE_ARRAY
71+
72+
/**
73+
* Resolves the @SerialName value for a variant within a oneOf schema.
74+
*/
75+
internal fun SchemaModel.resolveSerialName(variantSchemaName: String): String = discriminator
76+
?.mapping
77+
?.firstNotNullOfOrNull { (serialName, refPath) ->
78+
serialName.takeIf { refPath.removePrefix(SCHEMA_PREFIX) == variantSchemaName }
79+
}
80+
?: variantSchemaName

core/src/main/kotlin/com/avsystem/justworks/core/gen/client/ClientGenerator.kt

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import com.avsystem.justworks.core.gen.GENERATED_SERIALIZERS_MODULE
99
import com.avsystem.justworks.core.gen.HTTP_CLIENT
1010
import com.avsystem.justworks.core.gen.HTTP_ERROR
1111
import com.avsystem.justworks.core.gen.HTTP_SUCCESS
12-
import com.avsystem.justworks.core.gen.ModelPackage
12+
import com.avsystem.justworks.core.gen.Hierarchy
1313
import com.avsystem.justworks.core.gen.NameRegistry
1414
import com.avsystem.justworks.core.gen.RAISE
1515
import com.avsystem.justworks.core.gen.TOKEN
@@ -51,29 +51,22 @@ internal object ClientGenerator {
5151
private const val DEFAULT_TAG = "Default"
5252
private const val API_SUFFIX = "Api"
5353

54-
context(_: ModelPackage, _: ApiPackage)
55-
fun generate(
56-
spec: ApiSpec,
57-
hasPolymorphicTypes: Boolean,
58-
nameRegistry: NameRegistry,
59-
): List<FileSpec> {
54+
context(_: Hierarchy, _: ApiPackage, _: NameRegistry)
55+
fun generate(spec: ApiSpec, hasPolymorphicTypes: Boolean): List<FileSpec> {
6056
val grouped = spec.endpoints.groupBy { it.tags.firstOrNull() ?: DEFAULT_TAG }
61-
return grouped.map { (tag, endpoints) ->
62-
generateClientFile(tag, endpoints, hasPolymorphicTypes, nameRegistry)
63-
}
57+
return grouped.map { (tag, endpoints) -> generateClientFile(tag, endpoints, hasPolymorphicTypes) }
6458
}
6559

66-
context(modelPackage: ModelPackage, apiPackage: ApiPackage)
60+
context(hierarchy: Hierarchy, apiPackage: ApiPackage, nameRegistry: NameRegistry)
6761
private fun generateClientFile(
6862
tag: String,
6963
endpoints: List<Endpoint>,
7064
hasPolymorphicTypes: Boolean,
71-
nameRegistry: NameRegistry,
7265
): FileSpec {
7366
val className = ClassName(apiPackage, nameRegistry.register("${tag.toPascalCase()}$API_SUFFIX"))
7467

7568
val clientInitializer = if (hasPolymorphicTypes) {
76-
val generatedSerializersModule = MemberName(modelPackage, GENERATED_SERIALIZERS_MODULE)
69+
val generatedSerializersModule = MemberName(hierarchy.modelPackage, GENERATED_SERIALIZERS_MODULE)
7770
CodeBlock.of("${CREATE_HTTP_CLIENT}(%M)", generatedSerializersModule)
7871
} else {
7972
CodeBlock.of("${CREATE_HTTP_CLIENT}()")
@@ -101,17 +94,18 @@ internal object ClientGenerator {
10194
.primaryConstructor(primaryConstructor)
10295
.addProperty(httpClientProperty)
10396

104-
val methodRegistry = NameRegistry()
105-
classBuilder.addFunctions(endpoints.map { generateEndpointFunction(it, methodRegistry) })
97+
context(NameRegistry()) {
98+
classBuilder.addFunctions(endpoints.map { generateEndpointFunction(it) })
99+
}
106100

107101
return FileSpec
108102
.builder(className)
109103
.addType(classBuilder.build())
110104
.build()
111105
}
112106

113-
context(_: ModelPackage)
114-
private fun generateEndpointFunction(endpoint: Endpoint, methodRegistry: NameRegistry): FunSpec {
107+
context(_: Hierarchy, methodRegistry: NameRegistry)
108+
private fun generateEndpointFunction(endpoint: Endpoint): FunSpec {
115109
val functionName = methodRegistry.register(endpoint.operationId.toCamelCase())
116110
val returnBodyType = resolveReturnType(endpoint)
117111
val returnType = HTTP_SUCCESS.parameterizedBy(returnBodyType)
@@ -166,7 +160,7 @@ internal object ClientGenerator {
166160
return funBuilder.build()
167161
}
168162

169-
context(_: ModelPackage)
163+
context(_: Hierarchy)
170164
private fun resolveReturnType(endpoint: Endpoint): TypeName = endpoint.responses.entries
171165
.asSequence()
172166
.filter { it.key.startsWith("2") }

core/src/main/kotlin/com/avsystem/justworks/core/gen/client/ParametersGenerator.kt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package com.avsystem.justworks.core.gen.client
33
import com.avsystem.justworks.core.gen.BODY
44
import com.avsystem.justworks.core.gen.CHANNEL_PROVIDER
55
import com.avsystem.justworks.core.gen.CONTENT_TYPE_CLASS
6-
import com.avsystem.justworks.core.gen.ModelPackage
6+
import com.avsystem.justworks.core.gen.Hierarchy
77
import com.avsystem.justworks.core.gen.isBinaryUpload
88
import com.avsystem.justworks.core.gen.properties
99
import com.avsystem.justworks.core.gen.requiredProperties
@@ -16,7 +16,7 @@ import com.squareup.kotlinpoet.ParameterSpec
1616
import com.squareup.kotlinpoet.STRING
1717

1818
internal object ParametersGenerator {
19-
context(_: ModelPackage)
19+
context(_: Hierarchy)
2020
fun buildMultipartParameters(requestBody: RequestBody): List<ParameterSpec> =
2121
requestBody.schema.properties.flatMap { prop ->
2222
val name = prop.name.toCamelCase()
@@ -33,13 +33,13 @@ internal object ParametersGenerator {
3333
}
3434
}
3535

36-
context(_: ModelPackage)
36+
context(_: Hierarchy)
3737
fun buildFormParameters(requestBody: RequestBody): List<ParameterSpec> = requestBody.schema.properties.map { prop ->
3838
val isRequired = requestBody.required && prop.name in requestBody.schema.requiredProperties
3939
buildNullableParameter(prop.type, prop.name, isRequired)
4040
}
4141

42-
context(_: ModelPackage)
42+
context(_: Hierarchy)
4343
fun buildNullableParameter(
4444
typeRef: TypeRef,
4545
name: String,
@@ -50,7 +50,7 @@ internal object ParametersGenerator {
5050
return builder.build()
5151
}
5252

53-
context(_: ModelPackage)
53+
context(_: Hierarchy)
5454
fun buildBodyParams(requestBody: RequestBody) = when (requestBody.contentType) {
5555
ContentType.MULTIPART_FORM_DATA -> buildMultipartParameters(requestBody)
5656
ContentType.FORM_URL_ENCODED -> buildFormParameters(requestBody)

0 commit comments

Comments
 (0)