Skip to content

Commit 37cd9a2

Browse files
author
Anders Brams
committed
feat: reduced generated file size ~50%
1 parent 4d0040b commit 37cd9a2

7 files changed

Lines changed: 182 additions & 55 deletions

File tree

.github/workflows/qa.yml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,21 @@ jobs:
8282
uv run python scripts/benchmark_generate.py compare --baseline
8383
.benchmark/base.json --candidate .benchmark/current.json
8484
--max-regression 0.02
85+
86+
- name: Measure generated file sizes (base)
87+
run: >
88+
uv run python scripts/benchmark_generate.py size --package-path
89+
.benchmark/base --spec tests/performance/nautobot.json.gz
90+
--output .benchmark/base-size.json
91+
92+
- name: Measure generated file sizes (current)
93+
run: >
94+
uv run python scripts/benchmark_generate.py size --package-path .
95+
--spec tests/performance/nautobot.json.gz
96+
--output .benchmark/current-size.json
97+
98+
- name: Check generated file size regression
99+
run: >
100+
uv run python scripts/benchmark_generate.py compare-size --baseline
101+
.benchmark/base-size.json --candidate .benchmark/current-size.json
102+
--max-regression 0.02

openapi_python/generator/render.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,22 @@ def _field_annotation(field: FieldDef) -> str:
5454
return annotation
5555

5656

57+
def _class_field_annotation(field: FieldDef, total_optional: bool) -> str:
58+
annotation = _render_annotation(field.annotation)
59+
if not field.required and not total_optional:
60+
annotation = f"NotRequired[{annotation}]"
61+
return annotation
62+
63+
64+
def _supports_typeddict_class_syntax(defn: TypedDictDef) -> bool:
65+
return all(
66+
field.name.isidentifier()
67+
and not keyword.iskeyword(field.name)
68+
and not field.name.startswith("__")
69+
for field in defn.fields
70+
)
71+
72+
5773
_TEMPLATE_DIR = Path(__file__).with_name("templates")
5874
_JINJA_ENV = Environment(
5975
loader=FileSystemLoader(_TEMPLATE_DIR),
@@ -64,6 +80,7 @@ def _field_annotation(field: FieldDef) -> str:
6480
_JINJA_ENV.filters["repr"] = repr
6581
_JINJA_ENV.filters["annotation"] = _render_annotation
6682
_JINJA_ENV.filters["field_annotation"] = _field_annotation
83+
_JINJA_ENV.filters["class_field_annotation"] = _class_field_annotation
6784

6885

6986
def _render_template(name: str, **context: object) -> str:
@@ -76,7 +93,15 @@ def _indent(text: str, spaces: int = 4) -> str:
7693

7794

7895
def _format_typeddict(defn: TypedDictDef) -> str:
79-
return _render_template("typeddict.py.j2", defn=defn)
96+
total_optional = bool(defn.fields) and all(
97+
not field.required for field in defn.fields
98+
)
99+
return _render_template(
100+
"typeddict.py.j2",
101+
defn=defn,
102+
class_syntax=_supports_typeddict_class_syntax(defn),
103+
total_optional=total_optional,
104+
)
80105

81106

82107
def _format_alias(alias: TypeAliasDef) -> str:
@@ -249,22 +274,13 @@ def _method_overload_line(op: OperationDef, *, is_async: bool = False) -> str:
249274
)
250275

251276

252-
def _method_dispatch_line(op: OperationDef, *, is_async: bool = False) -> str:
253-
return _render_template(
254-
"method_dispatch.py.j2",
255-
op=op,
256-
is_async=is_async,
257-
)
258-
259-
260277
def _fallback_method_block(
261-
method: str, overloads: list[str], dispatch: list[str], *, is_async: bool = False
278+
method: str, overloads: list[str], *, is_async: bool = False
262279
) -> str:
263280
return _render_template(
264281
"method_block.py.j2",
265282
method=method,
266283
overloads="\n".join(overloads),
267-
dispatch_block="\n\n ".join(dispatch),
268284
callable_return="Awaitable[Any]" if is_async else "object",
269285
call_return="Any" if is_async else "object",
270286
is_async=is_async,
@@ -330,27 +346,20 @@ def _render_client(spec: NormalizedSpec, *, transport_mode: str) -> str:
330346
async_protocols: list[str] = []
331347
method_overloads: dict[str, list[str]] = {}
332348
async_method_overloads: dict[str, list[str]] = {}
333-
method_dispatch: dict[str, list[str]] = {}
334-
async_method_dispatch: dict[str, list[str]] = {}
335349
for op in spec.operations:
336350
protocols.append(_protocol_block(op))
337351
async_protocols.append(_protocol_block(op, is_async=True))
338352
method_overloads.setdefault(op.method, []).append(_method_overload_line(op))
339353
async_method_overloads.setdefault(op.method, []).append(
340354
_method_overload_line(op, is_async=True)
341355
)
342-
method_dispatch.setdefault(op.method, []).append(_method_dispatch_line(op))
343-
async_method_dispatch.setdefault(op.method, []).append(
344-
_method_dispatch_line(op, is_async=True)
345-
)
346356

347357
method_blocks: list[str] = []
348358
for method in sorted(method_overloads):
349359
method_blocks.append(
350360
_fallback_method_block(
351361
method,
352362
method_overloads[method],
353-
method_dispatch.get(method, []),
354363
)
355364
)
356365
async_method_blocks: list[str] = []
@@ -359,7 +368,6 @@ def _render_client(spec: NormalizedSpec, *, transport_mode: str) -> str:
359368
_fallback_method_block(
360369
method,
361370
async_method_overloads[method],
362-
async_method_dispatch.get(method, []),
363371
is_async=True,
364372
)
365373
)

openapi_python/generator/templates/method_block.py.j2

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22
@overload
33
def {{ method }}(self, route: str) -> Callable[..., {{ callable_return }}]: ...
44
def {{ method }}(self, route: str) -> Callable[..., {{ callable_return }}]:
5-
{{ dispatch_block }}
6-
{{ "async " if is_async else "" }}def _call(*, params: dict[str, object] | None = None, query: dict[str, object] | None = None, headers: dict[str, object] | None = None, body: object | None = None) -> {{ call_return }}:
7-
return {{ "await " if is_async else "" }}self._transport.request(
8-
method={{ method|repr }},
9-
route=route,
10-
base_url=self._base_url,
11-
params=params,
12-
query=query,
13-
headers=headers,
14-
body=body,
15-
)
16-
return _call
5+
{{ "async " if is_async else "" }}def _call(*, params: dict[str, object] | None = None, query: dict[str, object] | None = None, headers: dict[str, object] | None = None, body: object | None = None) -> {{ call_return }}:
6+
return {{ "await " if is_async else "" }}self._transport.request(
7+
method={{ method|repr }},
8+
route=route,
9+
base_url=self._base_url,
10+
params=params,
11+
query=query,
12+
headers=headers,
13+
body=body,
14+
)
15+
return _call

openapi_python/generator/templates/typeddict.py.j2

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1-
{% if not defn.fields -%}
1+
{% if class_syntax -%}
2+
class {{ defn.name }}(TypedDict{% if total_optional %}, total=False{% endif %}):
3+
{% if not defn.fields %}
4+
pass
5+
{% else %}
6+
{% for field in defn.fields %}
7+
{{ field.name }}: {{ field | class_field_annotation(total_optional) }}
8+
{% endfor %}
9+
{% endif %}
10+
{% elif not defn.fields -%}
211
{{ defn.name }} = TypedDict({{ defn.name | repr }}, {})
312
{% else -%}
413
{{ defn.name }} = TypedDict(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ dev = [
4242
"python-multipart>=0.0.20",
4343
"ruff>=0.9.10",
4444
"twine>=6.1.0",
45-
"ty>=0.0.34",
45+
"ty>=0.0.37",
4646
"uvicorn>=0.34.0",
4747
]
4848

scripts/benchmark_generate.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,32 @@ def _run_once(
5050
shutil.rmtree(output_dir, ignore_errors=True)
5151

5252

53+
def _run_size_once(
54+
*,
55+
generate_client: Any,
56+
generation_request: Any,
57+
spec_json: str,
58+
package_name: str,
59+
) -> tuple[dict[str, int], Any]:
60+
output_dir = Path(tempfile.mkdtemp(prefix="openapi-python-size-"))
61+
try:
62+
result = generate_client(
63+
generation_request(
64+
output_dir=output_dir,
65+
spec_json=spec_json,
66+
package_name=package_name,
67+
overwrite=True,
68+
)
69+
)
70+
file_sizes = {
71+
path.relative_to(output_dir).as_posix(): path.stat().st_size
72+
for path in result.written_files
73+
}
74+
return file_sizes, result
75+
finally:
76+
shutil.rmtree(output_dir, ignore_errors=True)
77+
78+
5379
def run_benchmark(args: argparse.Namespace) -> int:
5480
spec_json = _load_spec(args.spec)
5581
generation_request, generate_client = _load_generator(args.package_path)
@@ -93,6 +119,35 @@ def run_benchmark(args: argparse.Namespace) -> int:
93119
return 0
94120

95121

122+
def run_size_benchmark(args: argparse.Namespace) -> int:
123+
spec_json = _load_spec(args.spec)
124+
generation_request, generate_client = _load_generator(args.package_path)
125+
126+
file_sizes, result = _run_size_once(
127+
generate_client=generate_client,
128+
generation_request=generation_request,
129+
spec_json=spec_json,
130+
package_name=args.package,
131+
)
132+
133+
payload = {
134+
"files_bytes": file_sizes,
135+
"operations": result.operations,
136+
"total_bytes": sum(file_sizes.values()),
137+
"type_definitions": result.type_definitions,
138+
}
139+
140+
encoded = json.dumps(payload, indent=2, sort_keys=True)
141+
if args.output:
142+
args.output.write_text(encoded + "\n", encoding="utf-8")
143+
print(encoded)
144+
return 0
145+
146+
147+
def _format_bytes(value: float) -> str:
148+
return f"{value:,.0f} bytes"
149+
150+
96151
def compare_benchmarks(args: argparse.Namespace) -> int:
97152
baseline = json.loads(args.baseline.read_text(encoding="utf-8"))
98153
candidate = json.loads(args.candidate.read_text(encoding="utf-8"))
@@ -116,6 +171,29 @@ def compare_benchmarks(args: argparse.Namespace) -> int:
116171
return 0
117172

118173

174+
def compare_size_benchmarks(args: argparse.Namespace) -> int:
175+
baseline = json.loads(args.baseline.read_text(encoding="utf-8"))
176+
candidate = json.loads(args.candidate.read_text(encoding="utf-8"))
177+
178+
baseline_bytes = float(baseline["total_bytes"])
179+
candidate_bytes = float(candidate["total_bytes"])
180+
allowed_bytes = baseline_bytes * (1 + args.max_regression)
181+
change = (candidate_bytes - baseline_bytes) / baseline_bytes
182+
183+
print(f"baseline total: {_format_bytes(baseline_bytes)}")
184+
print(f"candidate total: {_format_bytes(candidate_bytes)}")
185+
print(f"change: {change:+.2%}")
186+
print(f"limit: +{args.max_regression:.2%}")
187+
188+
if candidate_bytes > allowed_bytes:
189+
print(
190+
"generated file size regressed beyond the configured limit",
191+
file=sys.stderr,
192+
)
193+
return 1
194+
return 0
195+
196+
119197
def _build_parser() -> argparse.ArgumentParser:
120198
parser = argparse.ArgumentParser(
121199
prog="benchmark_generate.py",
@@ -132,12 +210,27 @@ def _build_parser() -> argparse.ArgumentParser:
132210
run.add_argument("--output", type=Path)
133211
run.set_defaults(func=run_benchmark)
134212

213+
size = subcommands.add_parser("size", help="Measure generated file sizes")
214+
size.add_argument("--spec", type=Path, required=True)
215+
size.add_argument("--package-path", type=Path, default=Path.cwd())
216+
size.add_argument("--package", default="my_client")
217+
size.add_argument("--output", type=Path)
218+
size.set_defaults(func=run_size_benchmark)
219+
135220
compare = subcommands.add_parser("compare", help="Compare two benchmark results")
136221
compare.add_argument("--baseline", type=Path, required=True)
137222
compare.add_argument("--candidate", type=Path, required=True)
138223
compare.add_argument("--max-regression", type=float, default=0.02)
139224
compare.set_defaults(func=compare_benchmarks)
140225

226+
compare_size = subcommands.add_parser(
227+
"compare-size", help="Compare two generated file size results"
228+
)
229+
compare_size.add_argument("--baseline", type=Path, required=True)
230+
compare_size.add_argument("--candidate", type=Path, required=True)
231+
compare_size.add_argument("--max-regression", type=float, default=0.02)
232+
compare_size.set_defaults(func=compare_size_benchmarks)
233+
141234
return parser
142235

143236

0 commit comments

Comments
 (0)