Skip to content

Commit 4d0040b

Browse files
author
Anders Brams
committed
fix: types are always defined before being used
1 parent 5043e14 commit 4d0040b

5 files changed

Lines changed: 126 additions & 30 deletions

File tree

openapi_python/generator/render.py

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -157,48 +157,47 @@ def _alias_dependencies(defn: TypeAliasDef, names: set[str]) -> set[str]:
157157
return dependencies
158158

159159

160-
def _order_aliases(defns: tuple[TypeAliasDef, ...]) -> list[TypeAliasDef]:
161-
by_name = {item.name: item for item in defns}
160+
def _type_dependencies(defn: TypeAliasDef | TypedDictDef, names: set[str]) -> set[str]:
161+
match defn:
162+
case TypeAliasDef():
163+
return _alias_dependencies(defn, names)
164+
case TypedDictDef():
165+
return _typed_dict_dependencies(defn, names)
166+
167+
168+
def _order_type_definitions(
169+
aliases: tuple[TypeAliasDef, ...], typed_dicts: tuple[TypedDictDef, ...]
170+
) -> list[TypeAliasDef | TypedDictDef]:
171+
by_name: dict[str, TypeAliasDef | TypedDictDef] = {
172+
item.name: item for item in (*aliases, *typed_dicts)
173+
}
162174
names = set(by_name)
163-
ordered: list[TypeAliasDef] = []
175+
ordered: list[TypeAliasDef | TypedDictDef] = []
164176
temporary: set[str] = set()
165177
permanent: set[str] = set()
166178

167179
def visit(name: str) -> None:
168180
if name in permanent or name in temporary:
169181
return
170182
temporary.add(name)
171-
for dependency in sorted(_alias_dependencies(by_name[name], names)):
183+
for dependency in sorted(_type_dependencies(by_name[name], names)):
172184
visit(dependency)
173185
temporary.remove(name)
174186
permanent.add(name)
175187
ordered.append(by_name[name])
176188

177189
for name in sorted(by_name):
178190
visit(name)
179-
return ordered
180-
181191

182-
def _order_typeddicts(defns: tuple[TypedDictDef, ...]) -> list[TypedDictDef]:
183-
by_name = {item.name: item for item in defns}
184-
names = set(by_name)
185-
ordered: list[TypedDictDef] = []
186-
temporary: set[str] = set()
187-
permanent: set[str] = set()
192+
return ordered
188193

189-
def visit(name: str) -> None:
190-
if name in permanent or name in temporary:
191-
return
192-
temporary.add(name)
193-
for dependency in sorted(_typed_dict_dependencies(by_name[name], names)):
194-
visit(dependency)
195-
temporary.remove(name)
196-
permanent.add(name)
197-
ordered.append(by_name[name])
198194

199-
for name in sorted(by_name):
200-
visit(name)
201-
return ordered
195+
def _format_type_definition(defn: TypeAliasDef | TypedDictDef) -> str:
196+
match defn:
197+
case TypeAliasDef():
198+
return _format_alias(defn)
199+
case TypedDictDef():
200+
return _format_typeddict(defn)
202201

203202

204203
def _call_parameters(op: OperationDef) -> dict[str, str]:
@@ -273,12 +272,11 @@ def _fallback_method_block(
273272

274273

275274
def _render_types(spec: NormalizedSpec) -> str:
276-
aliases = (*_route_aliases(spec), *_order_aliases(spec.aliases))
277-
blocks = (
278-
[_format_enum(item) for item in spec.enums]
279-
+ [_format_alias(alias) for alias in aliases]
280-
+ [_format_typeddict(item) for item in _order_typeddicts(spec.typed_dicts)]
281-
)
275+
aliases = (*_route_aliases(spec), *spec.aliases)
276+
type_definitions = _order_type_definitions(aliases, spec.typed_dicts)
277+
blocks = [_format_enum(item) for item in spec.enums] + [
278+
_format_type_definition(item) for item in type_definitions
279+
]
282280
return _render_template(
283281
"types.py.j2",
284282
type_blocks="\n".join(blocks).strip() + "\n",
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from __future__ import annotations
2+
3+
import json
4+
from pathlib import Path
5+
6+
from openapi_python.generator import GenerationRequest, generate_client
7+
8+
SPEC = {
9+
"openapi": "3.1.0",
10+
"info": {"title": "Type order", "version": "1.0.0"},
11+
"paths": {
12+
"/owners": {
13+
"get": {
14+
"responses": {
15+
"200": {
16+
"content": {
17+
"application/json": {
18+
"schema": {"$ref": "#/components/schemas/Owner"}
19+
}
20+
}
21+
}
22+
}
23+
}
24+
}
25+
},
26+
"components": {
27+
"schemas": {
28+
"Owner": {
29+
"oneOf": [
30+
{"$ref": "#/components/schemas/Person"},
31+
{"$ref": "#/components/schemas/Team"},
32+
]
33+
},
34+
"Person": {
35+
"type": "object",
36+
"properties": {"name": {"type": "string"}},
37+
"required": ["name"],
38+
},
39+
"Team": {
40+
"type": "object",
41+
"properties": {"slug": {"type": "string"}},
42+
"required": ["slug"],
43+
},
44+
}
45+
},
46+
}
47+
48+
49+
def main() -> None:
50+
generate_client(
51+
GenerationRequest(
52+
output_dir=Path(__file__).parent / "generated",
53+
spec_json=json.dumps(SPEC),
54+
overwrite=True,
55+
)
56+
)
57+
58+
59+
if __name__ == "__main__":
60+
main()
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from __future__ import annotations
2+
3+
from typing import assert_type
4+
5+
from generated.my_client import AsyncClient
6+
from generated.my_client.types import Owner, Person, Team
7+
8+
async_client = AsyncClient(base_url="http://testserver")
9+
10+
11+
async def use_async_client() -> None:
12+
person: Person = {"name": "Ada"}
13+
team: Team = {"slug": "core"}
14+
15+
assert_type(person, Person)
16+
assert_type(team, Team)
17+
assert_type(await async_client.get("/owners")(), Owner)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from __future__ import annotations
2+
3+
from typing import assert_type
4+
5+
from generated.my_client import Client
6+
from generated.my_client.types import Owner, Person, Team
7+
8+
client = Client(base_url="http://testserver")
9+
10+
person: Person = {"name": "Ada"}
11+
team: Team = {"slug": "core"}
12+
13+
assert_type(person, Person)
14+
assert_type(team, Team)
15+
assert_type(client.get("/owners")(), Owner)

tests/test_contracts.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,9 @@ def test_contract_types(contract_dir: Path) -> None:
4646
cwd=contract_dir,
4747
)
4848
assert typecheck.returncode == 0, _output(typecheck)
49+
50+
generated_types = _run(
51+
["ty", "check", "generated/my_client/types.py", "--python-version", "3.12"],
52+
cwd=contract_dir,
53+
)
54+
assert generated_types.returncode == 0, _output(generated_types)

0 commit comments

Comments
 (0)