Skip to content

Commit 42b428d

Browse files
committed
feat(generator): add special handling for discriminator cases
1 parent 577eb87 commit 42b428d

File tree

7 files changed

+326
-7
lines changed

7 files changed

+326
-7
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# openapi-request-validation
22

3-
Build-time OpenAPI request validation for serverless and lightweight Node runtimes. Performs request validation based on
4-
OpenAPI specifications. Supports AWS Lambda API Gateway events, lambda-api requests and Express.
3+
Provides request validation based on OpenAPI specifications for serverless and lightweight Node runtimes.
4+
5+
Used to generate validator code based on an OpenAPI spec and offers runtime validation helpers to validate requests objects.
6+
Supports AWS Lambda API Gateway events, lambda-api requests and Express.
57

68
Uses AJV to generate standalone request validators and ships runtime helpers for easy validation.
79

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import type {OpenAPIV3} from "openapi-types"
2+
import {resolveLocalRef} from "./validator-artifacts.js"
3+
4+
function isObjectRecord(value: unknown): value is Record<string, unknown> {
5+
return value !== null && typeof value === "object" && !Array.isArray(value)
6+
}
7+
8+
function isReferenceObject(value: unknown): value is OpenAPIV3.ReferenceObject {
9+
return isObjectRecord(value) && typeof value.$ref === "string"
10+
}
11+
12+
function extractStringLiteralValues(
13+
doc: OpenAPIV3.Document,
14+
schema: unknown,
15+
): string[] {
16+
if (!isObjectRecord(schema)) {
17+
return []
18+
}
19+
20+
const resolved = isReferenceObject(schema)
21+
? resolveLocalRef<Record<string, unknown>>(doc, schema.$ref)
22+
: schema
23+
24+
if (typeof resolved.const === "string") {
25+
return [resolved.const]
26+
}
27+
28+
if (Array.isArray(resolved.enum) && resolved.enum.every((value) => typeof value === "string")) {
29+
return [...resolved.enum]
30+
}
31+
32+
return []
33+
}
34+
35+
function inferDiscriminatorValues(
36+
doc: OpenAPIV3.Document,
37+
schema: OpenAPIV3.SchemaObject | OpenAPIV3.ReferenceObject,
38+
propertyName: string,
39+
): string[] {
40+
const resolved = isReferenceObject(schema)
41+
? resolveLocalRef<Record<string, unknown>>(doc, schema.$ref)
42+
: schema as Record<string, unknown>
43+
44+
const properties = isObjectRecord(resolved.properties) ? resolved.properties : null
45+
const directValues = extractStringLiteralValues(doc, properties?.[propertyName])
46+
if (directValues.length > 0) {
47+
return directValues
48+
}
49+
50+
if (!isReferenceObject(schema)) {
51+
return []
52+
}
53+
54+
const segments = schema.$ref.split("/")
55+
const schemaName = segments.at(-1)
56+
return schemaName ? [schemaName] : []
57+
}
58+
59+
type DiscriminatorBranch = {
60+
tagValues: string[]
61+
schema: OpenAPIV3.SchemaObject | OpenAPIV3.ReferenceObject
62+
}
63+
64+
function getDiscriminatorBranches(
65+
doc: OpenAPIV3.Document,
66+
unionMembers: Array<OpenAPIV3.SchemaObject | OpenAPIV3.ReferenceObject>,
67+
discriminator: OpenAPIV3.DiscriminatorObject,
68+
): DiscriminatorBranch[] {
69+
const mappingEntries = Object.entries(discriminator.mapping ?? {})
70+
const branches: DiscriminatorBranch[] = []
71+
const seenTagValues = new Set<string>()
72+
73+
for (const member of unionMembers) {
74+
const tagValues = new Set<string>()
75+
76+
if (isReferenceObject(member)) {
77+
for (const [tagValue, mappedRef] of mappingEntries) {
78+
if (mappedRef === member.$ref) {
79+
tagValues.add(tagValue)
80+
}
81+
}
82+
}
83+
84+
if (tagValues.size === 0) {
85+
for (const tagValue of inferDiscriminatorValues(doc, member, discriminator.propertyName)) {
86+
tagValues.add(tagValue)
87+
}
88+
}
89+
90+
if (tagValues.size === 0) {
91+
return []
92+
}
93+
94+
for (const tagValue of tagValues) {
95+
if (seenTagValues.has(tagValue)) {
96+
return []
97+
}
98+
99+
seenTagValues.add(tagValue)
100+
}
101+
102+
branches.push({
103+
tagValues: [...tagValues],
104+
schema: member,
105+
})
106+
}
107+
108+
return branches
109+
}
110+
111+
function createDiscriminatorMatcher(
112+
propertyName: string,
113+
tagValues: string[],
114+
): Record<string, unknown> {
115+
return {
116+
type: "object",
117+
required: [propertyName],
118+
properties: {
119+
[propertyName]: tagValues.length === 1 ? {const: tagValues[0]} : {enum: tagValues},
120+
},
121+
}
122+
}
123+
124+
function rewriteDiscriminatorSchema(
125+
doc: OpenAPIV3.Document,
126+
schema: Record<string, unknown>,
127+
): void {
128+
const discriminator = schema.discriminator
129+
if (!isObjectRecord(discriminator) || typeof discriminator.propertyName !== "string") {
130+
return
131+
}
132+
133+
const propertyName = discriminator.propertyName
134+
135+
const oneOf = Array.isArray(schema.oneOf) ? schema.oneOf : null
136+
const anyOf = Array.isArray(schema.anyOf) ? schema.anyOf : null
137+
const unionMembers = oneOf ?? anyOf
138+
139+
if (!unionMembers) {
140+
return
141+
}
142+
143+
const branches = getDiscriminatorBranches(
144+
doc,
145+
unionMembers as Array<OpenAPIV3.SchemaObject | OpenAPIV3.ReferenceObject>,
146+
discriminator as unknown as OpenAPIV3.DiscriminatorObject,
147+
)
148+
149+
if (branches.length === 0) {
150+
return
151+
}
152+
153+
const existingAllOf = Array.isArray(schema.allOf) ? [...schema.allOf] : []
154+
const allowedTagValues = branches.flatMap((branch) => branch.tagValues)
155+
156+
schema.allOf = [
157+
...existingAllOf,
158+
createDiscriminatorMatcher(propertyName, allowedTagValues),
159+
...branches.map((branch) => ({
160+
if: createDiscriminatorMatcher(propertyName, branch.tagValues),
161+
then: branch.schema,
162+
})),
163+
] as Array<OpenAPIV3.SchemaObject | OpenAPIV3.ReferenceObject>
164+
165+
delete schema.discriminator
166+
delete schema.oneOf
167+
delete schema.anyOf
168+
}
169+
170+
export function rewriteOpenApiDiscriminators(doc: OpenAPIV3.Document): void {
171+
const visit = (value: unknown) => {
172+
if (Array.isArray(value)) {
173+
for (const item of value) {
174+
visit(item)
175+
}
176+
return
177+
}
178+
179+
if (!isObjectRecord(value)) {
180+
return
181+
}
182+
183+
for (const child of Object.values(value)) {
184+
visit(child)
185+
}
186+
187+
rewriteDiscriminatorSchema(doc, value)
188+
}
189+
190+
visit(doc)
191+
}

src/generator/validator-artifacts.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import type {OpenAPIV3} from "openapi-types"
22
import {HTTP_METHODS} from "../common/http-methods.js"
33
import {toPascalIdentifier} from "./pascal-identifier.js"
4+
import {rewriteOpenApiDiscriminators} from "./rewrite-open-api-discriminators.js"
45

56
type ParameterLocation = "query" | "path" | "header"
67

@@ -399,6 +400,7 @@ export function prepareValidatorArtifacts(
399400
validators: ValidatorExports
400401
} {
401402
const specForValidation = structuredClone(spec)
403+
rewriteOpenApiDiscriminators(specForValidation)
402404
const validators = registerRequestSchemasAndCollectValidators(specForValidation)
403405

404406
return {

test/discriminator-mapping.test.ts

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import {RequestValidationError, validateApiGatewayEvent} from "../src"
2+
import {loadGeneratedValidators} from "./helpers/load-generated-validator"
3+
4+
describe("discriminator mappings", () => {
5+
let cleanup: () => void = () => {
6+
}
7+
let validateCreateItemRequest: (value: unknown) => boolean
8+
9+
beforeAll(async () => {
10+
const loaded = await loadGeneratedValidators({fixtureName: "discriminator-mapping-api.yaml"})
11+
cleanup = loaded.cleanup
12+
validateCreateItemRequest = loaded.getValidator("validateCreateItemRequest")
13+
})
14+
15+
afterAll(() => {
16+
cleanup()
17+
})
18+
19+
test("accepts the mapped embedded branch even when the branch inherits the discriminator property", () => {
20+
expect(() =>
21+
validateApiGatewayEvent(validateCreateItemRequest, {
22+
body: JSON.stringify({
23+
item: {
24+
type: "EMBEDDED",
25+
code: "embedded-item",
26+
value: 25
27+
},
28+
}),
29+
}),
30+
).not.toThrow()
31+
})
32+
33+
test("only reports errors from the discriminator-selected branch", () => {
34+
let thrownError: unknown
35+
36+
try {
37+
validateApiGatewayEvent(validateCreateItemRequest, {
38+
body: JSON.stringify({
39+
item: {
40+
type: "EMBEDDED",
41+
},
42+
}),
43+
})
44+
} catch (error) {
45+
thrownError = error
46+
}
47+
48+
expect(thrownError).toBeInstanceOf(RequestValidationError)
49+
50+
const validationError = thrownError as RequestValidationError
51+
const missingProperties = validationError.errors
52+
.map((error) => error.params)
53+
.filter((params): params is {missingProperty: string} => typeof params === "object" && params !== null && "missingProperty" in params)
54+
.map((params) => params.missingProperty)
55+
.sort()
56+
57+
expect(missingProperties).toEqual(["code"])
58+
})
59+
})
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
openapi: 3.0.3
2+
info:
3+
title: Discriminator Mapping API
4+
version: 1.0.0
5+
paths:
6+
/items:
7+
post:
8+
operationId: createItem
9+
requestBody:
10+
required: true
11+
content:
12+
application/json:
13+
schema:
14+
type: object
15+
required:
16+
- item
17+
properties:
18+
item:
19+
$ref: "#/components/schemas/ItemResource"
20+
responses:
21+
"201":
22+
description: Created
23+
components:
24+
schemas:
25+
ItemResource:
26+
oneOf:
27+
- $ref: "#/components/schemas/ItemReference"
28+
- $ref: "#/components/schemas/ItemEmbedded"
29+
discriminator:
30+
propertyName: type
31+
mapping:
32+
REFERENCE: "#/components/schemas/ItemReference"
33+
EMBEDDED: "#/components/schemas/ItemEmbedded"
34+
ItemResourceType:
35+
type: string
36+
enum:
37+
- REFERENCE
38+
- EMBEDDED
39+
ItemResourceBase:
40+
type: object
41+
required:
42+
- type
43+
properties:
44+
type:
45+
$ref: "#/components/schemas/ItemResourceType"
46+
ItemReference:
47+
allOf:
48+
- $ref: "#/components/schemas/ItemResourceBase"
49+
- type: object
50+
required:
51+
- id
52+
properties:
53+
id:
54+
type: string
55+
ItemEmbedded:
56+
allOf:
57+
- $ref: "#/components/schemas/ItemResourceBase"
58+
- type: object
59+
required:
60+
- code
61+
properties:
62+
code:
63+
type: string
64+
value:
65+
type: number

test/package-exports.test.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,14 @@ describe("package exports", () => {
2727
const esmOutput = runNode([
2828
"--input-type=module",
2929
"-e",
30-
"Promise.all([import('openapi-request-validation/generator'), import('openapi-request-validation/runtime')]).then(([g,r])=>console.log(typeof g.generateApiValidatorsFromDocument + ' ' + typeof g.generateApiValidatorsFromPath + ' ' + typeof r.validateExpressRequest + ' ' + typeof r.validateApiGatewayProxyEvent))",
30+
"Promise.all([import('openapi-request-validation/generator'), import('openapi-request-validation/runtime')]).then(async ([g,r])=>{const generated=await g.generateApiValidatorsFromPath('./test/fixtures/order-api.yaml'); console.log(typeof g.generateApiValidatorsFromDocument + ' ' + typeof g.generateApiValidatorsFromPath + ' ' + typeof r.validateExpressRequest + ' ' + typeof r.validateApiGatewayProxyEvent + ' ' + generated.validatorCount)})",
3131
])
3232
const cjsOutput = runNode([
3333
"-e",
34-
"const g=require('openapi-request-validation/generator'); const r=require('openapi-request-validation/runtime'); console.log(typeof g.generateApiValidatorsFromDocument + ' ' + typeof g.generateApiValidatorsFromPath + ' ' + typeof r.validateExpressRequest + ' ' + typeof r.validateApiGatewayProxyEvent)",
34+
"const g=require('openapi-request-validation/generator'); const r=require('openapi-request-validation/runtime'); g.generateApiValidatorsFromPath('./test/fixtures/order-api.yaml').then((generated)=>console.log(typeof g.generateApiValidatorsFromDocument + ' ' + typeof g.generateApiValidatorsFromPath + ' ' + typeof r.validateExpressRequest + ' ' + typeof r.validateApiGatewayProxyEvent + ' ' + generated.validatorCount))",
3535
])
3636

37-
expect(esmOutput).toBe("function function function function")
38-
expect(cjsOutput).toBe("function function function function")
37+
expect(esmOutput).toBe("function function function function 1")
38+
expect(cjsOutput).toBe("function function function function 1")
3939
})
4040
})

test/validate-composition-schemas.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,6 @@ describe("OpenAPI composition schemas", () => {
141141
pushToken: "device-token",
142142
}),
143143
}),
144-
).toThrow("must match exactly one schema in oneOf")
144+
).toThrow("must be equal to one of the allowed values")
145145
})
146146
})

0 commit comments

Comments
 (0)