Skip to content

Commit 29ce019

Browse files
committed
Use coroutines in hir2ir to simplify the logic and increase recursion limit
Currently, we use a `todo_stack` to store the pending HIR calls. When a user-defined function is called, we push the `hir.Call` objects of its body onto the `todo_stack` in reverse order. This helps reduce the Python call stack depth but makes it harder to track the state (i.e., which scope does a pending `hir.Call` belong to?). We also need to do some awkward things, such as pushing special calls to `_set_scope` to reset the scope after the user-defined function call returns. The code would be much more straightforward if we just used recursion. Moreover, the `todo_stack` doesn't actually solve the stack depth problem, since nested blocks are processed using recursion anyway. Because of this, we need to keep our recursion limit fairly small (~50 levels deep). We address these problems by using coroutines (aka async functions) together with a partially software-managed stack. Thus we can write the code in a recursive fashion without worrying about Python's recursion limit. Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent 10f559c commit 29ce019

File tree

7 files changed

+192
-102
lines changed

7 files changed

+192
-102
lines changed

src/cuda/tile/_coroutine_util.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import sys
6+
from contextlib import ExitStack
7+
from dataclasses import dataclass
8+
from typing import Awaitable
9+
10+
11+
# Run a coroutine using a software stack to bypass Python's recursion limit.
12+
# Use resume_after() to break the call chain and push a new frame to the software stack.
13+
def run_coroutine(awaitable: Awaitable):
    """Drive `awaitable` to completion using an explicit (software) stack.

    The iterator obtained from `awaitable.__await__()` is pushed onto a list
    and resumed in a loop. Whenever the iterator on top of the stack yields an
    object, that object is treated as a child awaitable: its `__await__()`
    iterator is pushed and driven next. When a child finishes, its return
    value is sent into (or its exception is thrown into) the iterator below
    it. Because children are driven from this loop rather than from within
    their parent's frame, the interpreter call depth does not grow with the
    number of stack entries.

    Args:
        awaitable: The root awaitable to run.

    Returns:
        The value returned by the root awaitable.

    Raises:
        Exception: Whatever exception escapes the root awaitable.
    """
    ret = None
    exc = None
    stack = []
    with ExitStack() as es:
        try:
            stack.append(awaitable.__await__())
            while stack:
                top = stack[-1]
                try:
                    if exc is None:
                        continuation = top.send(ret)
                    else:
                        # Pass the exception instance only: the three-argument
                        # throw(type, value, traceback) form is deprecated
                        # since Python 3.12. The instance already carries its
                        # traceback in `__traceback__`.
                        continuation = top.throw(exc)
                except StopIteration as s:
                    # `top` finished normally; hand its result to the parent.
                    ret, exc = s.value, None
                    stack.pop()
                except Exception as e:
                    # `top` failed; propagate the exception to the parent.
                    ret, exc = None, e
                    stack.pop()
                else:
                    # `top` yielded a child awaitable; start driving it.
                    ret = exc = None
                    stack.append(continuation.__await__())
            if exc is not None:
                raise exc
            return ret
        finally:
            # Anything still on the stack was abandoned mid-flight (e.g. by a
            # non-`Exception` BaseException). ExitStack runs callbacks in LIFO
            # order, so the innermost iterator is closed first, and it keeps
            # closing the rest even if one `close()` raises.
            for c in stack:
                es.callback(c.close)
43+
44+
# Replace `await foo()` with `await resume_after(foo())` to bypass the recursion limit.
45+
@dataclass
class resume_after:
    """Awaitable wrapper that yields `awaitable` outward instead of awaiting it inline.

    `await resume_after(x)` suspends the current coroutine chain and yields
    `x` to whatever is driving it. `run_coroutine()` reacts to such a yield by
    pushing `x.__await__()` onto its software stack, and later sends the
    result back in, which becomes the value of the `await` expression.
    """

    # The awaitable handed over to the driver loop.
    awaitable: Awaitable

    def __await__(self):
        # Generator-based awaitable: yield the wrapped object, then return
        # whatever value the driver sends back.
        outcome = yield self.awaitable
        return outcome

src/cuda/tile/_ir/ir.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,15 @@ def change_loop_info(self, new_info: ControlFlowInfo):
596596
finally:
597597
self.loop_info = old
598598

599+
@contextmanager
def change_scope(self, new_scope: Scope):
    """Temporarily replace `self.scope` with `new_scope`.

    The previous scope is restored when the `with` block exits, whether it
    exits normally or by raising.
    """
    previous_scope, self.scope = self.scope, new_scope
    try:
        yield
    finally:
        self.scope = previous_scope
607+
599608
def __enter__(self):
600609
assert not self._entered
601610
self._prev_builder = _current_builder.builder

src/cuda/tile/_ir/op_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def wrapper(*args, **kwargs):
5757
return func(*args, **kwargs)
5858
finally:
5959
_current_stub.stub_and_args = old
60-
60+
wrapper._is_coroutine = inspect.iscoroutinefunction(func)
6161
op_implementations[stub] = wrapper
6262
return orig_func
6363

src/cuda/tile/_ir/ops.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def format_var(var):
161161

162162

163163
@impl(hir.loop)
164-
def loop_impl(body: hir.Block, iterable: Var):
164+
async def loop_impl(body: hir.Block, iterable: Var):
165165
from .._passes.hir2ir import dispatch_hir_block
166166

167167
range_ty = require_optional_range_type(iterable)
@@ -171,7 +171,7 @@ def loop_impl(body: hir.Block, iterable: Var):
171171
# have a "break" at the end of the loop body, and no other break/continue statements.
172172
info = ControlFlowInfo((), flatten=True)
173173
with Builder.get_current().change_loop_info(info):
174-
dispatch_hir_block(body)
174+
await dispatch_hir_block(body)
175175
return ()
176176

177177
builder = Builder.get_current()
@@ -207,7 +207,7 @@ def loop_impl(body: hir.Block, iterable: Var):
207207
flat_body_vars = flatten_block_parameters(body_vars)
208208

209209
# Dispatch the body (hir.Block) to populate the new_body (ir.Block) with Operations
210-
dispatch_hir_block(body)
210+
await dispatch_hir_block(body)
211211

212212
# Propagate type information from Continue/Break to body/result phis
213213
for jump_info in loop_info.jumps:
@@ -359,11 +359,11 @@ def _to_string_rhs(self) -> str:
359359
return f"if(cond={self.cond})"
360360

361361

362-
def _flatten_branch(branch: hir.Block) -> tuple[Var, ...]:
362+
async def _flatten_branch(branch: hir.Block) -> tuple[Var, ...]:
363363
from .._passes.hir2ir import dispatch_hir_block
364364
info = ControlFlowInfo((), flatten=True)
365365
with Builder.get_current().change_if_else_info(info):
366-
dispatch_hir_block(branch)
366+
await dispatch_hir_block(branch)
367367
if len(info.jumps) == 0:
368368
return ()
369369
else:
@@ -372,13 +372,13 @@ def _flatten_branch(branch: hir.Block) -> tuple[Var, ...]:
372372

373373

374374
@impl(hir.if_else)
375-
def if_else_impl(cond: Var, then_block: hir.Block, else_block: hir.Block) -> tuple[Var, ...]:
375+
async def if_else_impl(cond: Var, then_block: hir.Block, else_block: hir.Block) -> tuple[Var, ...]:
376376
from .._passes.hir2ir import dispatch_hir_block
377377

378378
require_bool(cond)
379379
if cond.is_constant():
380380
branch_taken = then_block if cond.get_constant() else else_block
381-
return _flatten_branch(branch_taken)
381+
return await _flatten_branch(branch_taken)
382382

383383
# Get the total number of results by adding the number of stored variables.
384384
# Note: we sort the stored variable names to make the order deterministic.
@@ -387,7 +387,7 @@ def if_else_impl(cond: Var, then_block: hir.Block, else_block: hir.Block) -> tup
387387
# Convert the "then" branch from HIR to IR
388388
info = ControlFlowInfo(stored_names)
389389
with nested_block(then_block.name, then_block.loc, if_else_info=info) as new_then_block:
390-
dispatch_hir_block(then_block)
390+
await dispatch_hir_block(then_block)
391391

392392
# If "then" branch doesn't yield, transform our if-else into the following:
393393
# if cond:
@@ -402,11 +402,11 @@ def if_else_impl(cond: Var, then_block: hir.Block, else_block: hir.Block) -> tup
402402
end_branch(())
403403
add_operation(IfElse, (),
404404
cond=cond, then_block=new_then_block, else_block=new_else_block)
405-
return _flatten_branch(else_block)
405+
return await _flatten_branch(else_block)
406406

407407
# Convert the "else" branch from HIR to IR
408408
with nested_block(else_block.name, else_block.loc, if_else_info=info) as new_else_block:
409-
dispatch_hir_block(else_block)
409+
await dispatch_hir_block(else_block)
410410

411411
# Do type/constant propagation
412412
num_results = len(info.jumps[0].outputs)

src/cuda/tile/_passes/hir2ir.py

Lines changed: 57 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,52 @@
11
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# SPDX-License-Identifier: Apache-2.0
4-
54
import inspect
65
import sys
76
from contextlib import contextmanager
8-
from dataclasses import dataclass
9-
from typing import Any, Sequence
7+
from typing import Any
108

119
from .ast2hir import get_function_hir
1210
from .. import TileTypeError
11+
from .._coroutine_util import resume_after, run_coroutine
1312
from .._exception import Loc, TileSyntaxError, TileInternalError, TileError, TileRecursionError
1413
from .._ir import hir, ir
1514
from .._ir.ir import Var, IRContext, Argument, Scope, LocalScope, BoundMethodValue
16-
from .._ir.op_impl import op_implementations, impl
15+
from .._ir.op_impl import op_implementations
1716
from .._ir.ops import loosely_typed_const, assign, end_branch, return_, continue_, \
18-
break_, flatten_block_parameters
17+
break_, flatten_block_parameters, store_var
1918
from .._ir.type import FunctionTy, BoundMethodTy, DTypeConstructor
2019
from .._ir.typing_support import get_signature
2120

2221

23-
MAX_RECURSION_DEPTH = 50
22+
MAX_RECURSION_DEPTH = 1000
2423

2524

2625
def hir2ir(func_hir: hir.Function,
           args: tuple[Argument, ...],
           ir_ctx: IRContext) -> ir.Block:
    """Synchronous entry point: run `_hir2ir_coroutine()` to completion.

    The coroutine is driven by `run_coroutine()`, which uses a software stack
    so that we don't exceed Python's recursion limit. Returns whatever the
    coroutine returns.
    """
    lowering = _hir2ir_coroutine(func_hir, args, ir_ctx)
    return run_coroutine(lowering)
30+
31+
32+
async def _hir2ir_coroutine(func_hir: hir.Function, args: tuple[Argument, ...], ir_ctx: IRContext):
2933
scope = _create_scope(func_hir, ir_ctx, call_site=None)
3034
aggregate_params = [
3135
scope.local.redefine(param_name, param_loc)
3236
for param_name, param_loc in zip(func_hir.param_names, func_hir.param_locs, strict=True)
3337
]
34-
preamble = []
35-
for param_name, var, arg in zip(func_hir.param_names, aggregate_params, args, strict=True):
36-
var.set_type(arg.type)
37-
if arg.is_const:
38-
preamble.append(hir.Call((), hir.store_var, (param_name, arg.const_value), (), var.loc))
3938

4039
with ir.Builder(ir_ctx, func_hir.body.loc, scope) as ir_builder:
41-
flat_params = flatten_block_parameters(aggregate_params)
4240
try:
43-
_dispatch_hir_block_inner(preamble, func_hir.body, ir_builder)
41+
for param_name, var, arg in zip(func_hir.param_names, aggregate_params, args,
42+
strict=True):
43+
var.set_type(arg.type)
44+
if arg.is_const:
45+
var = loosely_typed_const(arg.const_value)
46+
store_var(param_name, var, var.loc)
47+
flat_params = flatten_block_parameters(aggregate_params)
48+
49+
await _dispatch_hir_block_inner(func_hir.body, ir_builder)
4450
except Exception as e:
4551
if 'CUTILEIR' in ir_ctx.tile_ctx.config.log_keys:
4652
highlight_loc = e.loc if hasattr(e, 'loc') else None
@@ -60,54 +66,41 @@ def _create_scope(func_hir: hir.Function, ir_ctx: IRContext, call_site: Loc | No
6066
return Scope(local_scope, func_hir.frozen_globals, call_site)
6167

6268

63-
def dispatch_hir_block(block: hir.Block):
64-
_dispatch_hir_block_inner((), block, ir.Builder.get_current())
65-
66-
67-
@dataclass
68-
class _State:
69-
done: list[hir.Call]
70-
current: hir.Call | None
71-
todo_stack: list[hir.Call]
69+
async def dispatch_hir_block(block: hir.Block, cur_builder: ir.Builder | None = None):
    """Dispatch `block` on `cur_builder`, defaulting to the current builder.

    When `cur_builder` is None, the builder is looked up with
    `ir.Builder.get_current()`. The work itself is delegated to
    `_dispatch_hir_block_inner()`.
    """
    builder = ir.Builder.get_current() if cur_builder is None else cur_builder
    await _dispatch_hir_block_inner(block, builder)
7273

73-
@contextmanager
74-
def next_call(self):
75-
call = self.current = self.todo_stack.pop()
76-
yield call
77-
# Intentionally not in a "finally" block because we want to preserve the state
78-
# for the debug printout in case of an exception.
79-
self.current = None
80-
self.done.append(call)
8174

82-
83-
def _dispatch_hir_block_inner(preamble: Sequence[hir.Call],
84-
block: hir.Block,
85-
builder: ir.Builder):
86-
state = _State([], None, list(reversed(block.calls)) + list(reversed(preamble)))
75+
async def _dispatch_hir_block_inner(block: hir.Block, builder: ir.Builder):
76+
cursor = 0 # Pre-initialize to guarantee it's defined in the `except` block
8777
try:
88-
if not _dispatch_hir_calls(state, builder):
89-
return ()
78+
for cursor, call in enumerate(block.calls):
79+
loc = _add_call_site(call.loc, builder)
80+
with _wrap_exceptions(loc), builder.change_loc(loc):
81+
await _dispatch_call(call, builder)
82+
if builder.is_terminated:
83+
# The current block has been terminated, e.g. by flattening an if-else
84+
# with a constant condition (`if True: break`).
85+
return
86+
cursor = len(block.calls)
87+
9088
result_vars = tuple(_resolve_operand(x) for x in block.results)
9189
loc = _add_call_site(block.jump_loc, builder)
9290
with _wrap_exceptions(loc), builder.change_loc(loc):
9391
_dispatch_hir_jump(block.jump, result_vars)
9492
except Exception:
9593
if 'CUTILEIR' in builder.ir_ctx.tile_ctx.config.log_keys:
9694
hir_params = ", ".join(p.name for p in block.params)
97-
hir_lines = [str(c) for c in state.done]
98-
cur_idx = len(hir_lines)
99-
if state.current is not None:
100-
hir_lines.append(str(state.current))
101-
hir_lines.extend(str(c) for c in reversed(state.todo_stack))
95+
hir_lines = [str(c) for c in block.calls]
10296
hir_lines.append(block.jump_str())
103-
hir_str = "\n".join("{}{}".format("--> " if i == cur_idx else " ", c)
97+
hir_str = "\n".join("{}{}".format("--> " if i == cursor else " ", c)
10498
for i, c in enumerate(hir_lines))
10599
print(f"==== HIR for ^{block.name}({hir_params}) ====\n{hir_str}\n", file=sys.stderr)
106100
raise
107101

108102

109-
def _dispatch_hir_jump(jump: hir.Jump,
110-
block_results: tuple[Var, ...]):
103+
def _dispatch_hir_jump(jump: hir.Jump | None, block_results: tuple[Var, ...]):
111104
match jump:
112105
case hir.Jump.END_BRANCH:
113106
end_branch(block_results)
@@ -120,23 +113,10 @@ def _dispatch_hir_jump(jump: hir.Jump,
120113
case hir.Jump.RETURN:
121114
assert len(block_results) == 1
122115
return_(block_results[0])
116+
case None: pass
123117
case _: assert False
124118

125119

126-
def _dispatch_hir_calls(state: _State, cur_builder: ir.Builder) -> bool:
127-
while len(state.todo_stack) > 0:
128-
with state.next_call() as call:
129-
loc = _add_call_site(call.loc, cur_builder)
130-
with _wrap_exceptions(loc), cur_builder.change_loc(loc):
131-
_dispatch_call(call, cur_builder, state.todo_stack)
132-
if cur_builder.is_terminated:
133-
# The current block has been terminated, e.g. by flattening an if-else
134-
# with a constant condition (`if True: break`). By returning False,
135-
# we signal that the original jump and block results should be ignored.
136-
return False
137-
return True
138-
139-
140120
def _add_call_site(loc: Loc, builder: ir.Builder) -> Loc:
141121
return loc.with_call_site(builder.scope.call_site)
142122

@@ -152,7 +132,7 @@ def _wrap_exceptions(loc: Loc):
152132
raise TileInternalError(str(e)) from e
153133

154134

155-
def _dispatch_call(call: hir.Call, builder: ir.Builder, todo_stack: list[hir.Call]):
135+
async def _dispatch_call(call: hir.Call, builder: ir.Builder):
156136
first_idx = len(builder.ops)
157137
callee_var = _resolve_operand(call.callee)
158138
callee, self_arg = _get_callee_and_self(callee_var)
@@ -161,7 +141,11 @@ def _dispatch_call(call: hir.Call, builder: ir.Builder, todo_stack: list[hir.Cal
161141
arg_list = _bind_args(callee, args, kwargs)
162142

163143
if callee in op_implementations:
164-
result = op_implementations[callee](*arg_list)
144+
impl = op_implementations[callee]
145+
result = impl(*arg_list)
146+
if impl._is_coroutine:
147+
result = await result
148+
165149
if builder.is_terminated:
166150
# The current block has been terminated, e.g. by flattening an if-else
167151
# with a constant condition (`if True: break`). Ignore the `result` in this case.
@@ -203,25 +187,20 @@ def _dispatch_call(call: hir.Call, builder: ir.Builder, todo_stack: list[hir.Cal
203187
" functions are not supported")
204188
callee_hir = get_function_hir(callee, builder.ir_ctx, entry_point=False)
205189

206-
# Since `todo_stack` is a stack, we push things backwards. First, we push identity()
207-
# calls to assign the temporary return values back to the original result variables.
208-
for callee_retval, caller_res in zip(callee_hir.body.results, call.results):
209-
todo_stack.append(hir.Call((caller_res,), hir.identity, (callee_retval,), (), call.loc))
190+
# Activate a fresh Scope.
191+
new_scope = _create_scope(callee_hir, builder.ir_ctx, call_site=builder.loc)
192+
with builder.change_scope(new_scope):
193+
# Call store_var() to bind arguments to parameters.
194+
for arg, param_name, param_loc in zip(arg_list, callee_hir.param_names,
195+
callee_hir.param_locs, strict=True):
196+
store_var(param_name, arg, param_loc)
210197

211-
# Now we create a fresh Scope for the new function and install it on the builder.
212-
# We need to reset the builder back to the old scope when we return.
213-
# For this purpose, we push a call to the special _set_scope stub.
214-
old_scope = builder.scope
215-
todo_stack.append(hir.Call((), _set_scope, (old_scope,), (), call.loc))
216-
builder.scope = _create_scope(callee_hir, builder.ir_ctx, call_site=builder.loc)
198+
# Dispatch the function body. Use resume_after() to break the call stack
199+
# and make sure we stay within the Python's recursion limit.
200+
await resume_after(dispatch_hir_block(callee_hir.body, builder))
217201

218-
# Now push the function body.
219-
todo_stack.extend(reversed(callee_hir.body.calls))
220-
221-
# Finally, call store_var() to bind arguments to parameters.
222-
for arg, param_name, param_loc in zip(arg_list, callee_hir.param_names,
223-
callee_hir.param_locs, strict=True):
224-
todo_stack.append(hir.Call((), hir.store_var, (param_name, arg), (), param_loc))
202+
for callee_retval, caller_res in zip(callee_hir.body.results, call.results):
203+
assign(callee_retval, caller_res)
225204

226205

227206
def _is_freshly_defined(var: Var, builder: ir.Builder, first_idx: int):
@@ -277,13 +256,3 @@ def _bind_args(sig_func, args, kwargs) -> list[Var]:
277256
assert param.default is not param.empty
278257
ret.append(loosely_typed_const(param.default))
279258
return ret
280-
281-
282-
def _set_scope(scope): ...
283-
284-
285-
@impl(_set_scope)
286-
def _set_scope_impl(scope):
287-
assert isinstance(scope, Scope)
288-
builder = ir.Builder.get_current()
289-
builder.scope = scope

0 commit comments

Comments
 (0)