11# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22#
33# SPDX-License-Identifier: Apache-2.0
4-
54import inspect
65import sys
76from contextlib import contextmanager
8- from dataclasses import dataclass
9- from typing import Any , Sequence
7+ from typing import Any
108
119from .ast2hir import get_function_hir
1210from .. import TileTypeError
11+ from .._coroutine_util import resume_after , run_coroutine
1312from .._exception import Loc , TileSyntaxError , TileInternalError , TileError , TileRecursionError
1413from .._ir import hir , ir
1514from .._ir .ir import Var , IRContext , Argument , Scope , LocalScope , BoundMethodValue
16- from .._ir .op_impl import op_implementations , impl
15+ from .._ir .op_impl import op_implementations
1716from .._ir .ops import loosely_typed_const , assign , end_branch , return_ , continue_ , \
18- break_ , flatten_block_parameters
17+ break_ , flatten_block_parameters , store_var
1918from .._ir .type import FunctionTy , BoundMethodTy , DTypeConstructor
2019from .._ir .typing_support import get_signature
2120
2221
23- MAX_RECURSION_DEPTH = 50
22+ MAX_RECURSION_DEPTH = 1000
2423
2524
2625def hir2ir (func_hir : hir .Function ,
2726 args : tuple [Argument , ...],
2827 ir_ctx : IRContext ) -> ir .Block :
28+ # Run as a coroutine using a software stack, so that we don't exceed Python's recursion limit.
29+ return run_coroutine (_hir2ir_coroutine (func_hir , args , ir_ctx ))
30+
31+
32+ async def _hir2ir_coroutine (func_hir : hir .Function , args : tuple [Argument , ...], ir_ctx : IRContext ):
2933 scope = _create_scope (func_hir , ir_ctx , call_site = None )
3034 aggregate_params = [
3135 scope .local .redefine (param_name , param_loc )
3236 for param_name , param_loc in zip (func_hir .param_names , func_hir .param_locs , strict = True )
3337 ]
34- preamble = []
35- for param_name , var , arg in zip (func_hir .param_names , aggregate_params , args , strict = True ):
36- var .set_type (arg .type )
37- if arg .is_const :
38- preamble .append (hir .Call ((), hir .store_var , (param_name , arg .const_value ), (), var .loc ))
3938
4039 with ir .Builder (ir_ctx , func_hir .body .loc , scope ) as ir_builder :
41- flat_params = flatten_block_parameters (aggregate_params )
4240 try :
43- _dispatch_hir_block_inner (preamble , func_hir .body , ir_builder )
41+ for param_name , var , arg in zip (func_hir .param_names , aggregate_params , args ,
42+ strict = True ):
43+ var .set_type (arg .type )
44+ if arg .is_const :
45+ var = loosely_typed_const (arg .const_value )
46+ store_var (param_name , var , var .loc )
47+ flat_params = flatten_block_parameters (aggregate_params )
48+
49+ await _dispatch_hir_block_inner (func_hir .body , ir_builder )
4450 except Exception as e :
4551 if 'CUTILEIR' in ir_ctx .tile_ctx .config .log_keys :
4652 highlight_loc = e .loc if hasattr (e , 'loc' ) else None
@@ -60,54 +66,41 @@ def _create_scope(func_hir: hir.Function, ir_ctx: IRContext, call_site: Loc | No
6066 return Scope (local_scope , func_hir .frozen_globals , call_site )
6167
6268
63- def dispatch_hir_block (block : hir .Block ):
64- _dispatch_hir_block_inner ((), block , ir .Builder .get_current ())
65-
66-
67- @dataclass
68- class _State :
69- done : list [hir .Call ]
70- current : hir .Call | None
71- todo_stack : list [hir .Call ]
69+ async def dispatch_hir_block (block : hir .Block , cur_builder : ir .Builder | None = None ):
70+ if cur_builder is None :
71+ cur_builder = ir .Builder .get_current ()
72+ await _dispatch_hir_block_inner (block , cur_builder )
7273
73- @contextmanager
74- def next_call (self ):
75- call = self .current = self .todo_stack .pop ()
76- yield call
77- # Intentionally not in a "finally" block because we want to preserve the state
78- # for the debug printout in case of an exception.
79- self .current = None
80- self .done .append (call )
8174
82-
83- def _dispatch_hir_block_inner (preamble : Sequence [hir .Call ],
84- block : hir .Block ,
85- builder : ir .Builder ):
86- state = _State ([], None , list (reversed (block .calls )) + list (reversed (preamble )))
75+ async def _dispatch_hir_block_inner (block : hir .Block , builder : ir .Builder ):
76+ cursor = 0 # Pre-initialize to guarantee it's defined in the `except` block
8777 try :
88- if not _dispatch_hir_calls (state , builder ):
89- return ()
78+ for cursor , call in enumerate (block .calls ):
79+ loc = _add_call_site (call .loc , builder )
80+ with _wrap_exceptions (loc ), builder .change_loc (loc ):
81+ await _dispatch_call (call , builder )
82+ if builder .is_terminated :
83+ # The current block has been terminated, e.g. by flattening an if-else
84+ # with a constant condition (`if True: break`).
85+ return
86+ cursor = len (block .calls )
87+
9088 result_vars = tuple (_resolve_operand (x ) for x in block .results )
9189 loc = _add_call_site (block .jump_loc , builder )
9290 with _wrap_exceptions (loc ), builder .change_loc (loc ):
9391 _dispatch_hir_jump (block .jump , result_vars )
9492 except Exception :
9593 if 'CUTILEIR' in builder .ir_ctx .tile_ctx .config .log_keys :
9694 hir_params = ", " .join (p .name for p in block .params )
97- hir_lines = [str (c ) for c in state .done ]
98- cur_idx = len (hir_lines )
99- if state .current is not None :
100- hir_lines .append (str (state .current ))
101- hir_lines .extend (str (c ) for c in reversed (state .todo_stack ))
95+ hir_lines = [str (c ) for c in block .calls ]
10296 hir_lines .append (block .jump_str ())
103- hir_str = "\n " .join ("{}{}" .format ("--> " if i == cur_idx else " " , c )
97+ hir_str = "\n " .join ("{}{}" .format ("--> " if i == cursor else " " , c )
10498 for i , c in enumerate (hir_lines ))
10599 print (f"==== HIR for ^{ block .name } ({ hir_params } ) ====\n { hir_str } \n " , file = sys .stderr )
106100 raise
107101
108102
109- def _dispatch_hir_jump (jump : hir .Jump ,
110- block_results : tuple [Var , ...]):
103+ def _dispatch_hir_jump (jump : hir .Jump | None , block_results : tuple [Var , ...]):
111104 match jump :
112105 case hir .Jump .END_BRANCH :
113106 end_branch (block_results )
@@ -120,23 +113,10 @@ def _dispatch_hir_jump(jump: hir.Jump,
120113 case hir .Jump .RETURN :
121114 assert len (block_results ) == 1
122115 return_ (block_results [0 ])
116+ case None : pass
123117 case _: assert False
124118
125119
126- def _dispatch_hir_calls (state : _State , cur_builder : ir .Builder ) -> bool :
127- while len (state .todo_stack ) > 0 :
128- with state .next_call () as call :
129- loc = _add_call_site (call .loc , cur_builder )
130- with _wrap_exceptions (loc ), cur_builder .change_loc (loc ):
131- _dispatch_call (call , cur_builder , state .todo_stack )
132- if cur_builder .is_terminated :
133- # The current block has been terminated, e.g. by flattening an if-else
134- # with a constant condition (`if True: break`). By returning False,
135- # we signal that the original jump and block results should be ignored.
136- return False
137- return True
138-
139-
140120def _add_call_site (loc : Loc , builder : ir .Builder ) -> Loc :
141121 return loc .with_call_site (builder .scope .call_site )
142122
@@ -152,7 +132,7 @@ def _wrap_exceptions(loc: Loc):
152132 raise TileInternalError (str (e )) from e
153133
154134
155- def _dispatch_call (call : hir .Call , builder : ir .Builder , todo_stack : list [ hir . Call ] ):
135+ async def _dispatch_call (call : hir .Call , builder : ir .Builder ):
156136 first_idx = len (builder .ops )
157137 callee_var = _resolve_operand (call .callee )
158138 callee , self_arg = _get_callee_and_self (callee_var )
@@ -161,7 +141,11 @@ def _dispatch_call(call: hir.Call, builder: ir.Builder, todo_stack: list[hir.Cal
161141 arg_list = _bind_args (callee , args , kwargs )
162142
163143 if callee in op_implementations :
164- result = op_implementations [callee ](* arg_list )
144+ impl = op_implementations [callee ]
145+ result = impl (* arg_list )
146+ if impl ._is_coroutine :
147+ result = await result
148+
165149 if builder .is_terminated :
166150 # The current block has been terminated, e.g. by flattening an if-else
167151 # with a constant condition (`if True: break`). Ignore the `result` in this case.
@@ -203,25 +187,20 @@ def _dispatch_call(call: hir.Call, builder: ir.Builder, todo_stack: list[hir.Cal
203187 " functions are not supported" )
204188 callee_hir = get_function_hir (callee , builder .ir_ctx , entry_point = False )
205189
206- # Since `todo_stack` is a stack, we push things backwards. First, we push identity()
207- # calls to assign the temporary return values back to the original result variables.
208- for callee_retval , caller_res in zip (callee_hir .body .results , call .results ):
209- todo_stack .append (hir .Call ((caller_res ,), hir .identity , (callee_retval ,), (), call .loc ))
190+ # Activate a fresh Scope.
191+ new_scope = _create_scope (callee_hir , builder .ir_ctx , call_site = builder .loc )
192+ with builder .change_scope (new_scope ):
193+ # Call store_var() to bind arguments to parameters.
194+ for arg , param_name , param_loc in zip (arg_list , callee_hir .param_names ,
195+ callee_hir .param_locs , strict = True ):
196+ store_var (param_name , arg , param_loc )
210197
211- # Now we create a fresh Scope for the new function and install it on the builder.
212- # We need to reset the builder back to the old scope when we return.
213- # For this purpose, we push a call to the special _set_scope stub.
214- old_scope = builder .scope
215- todo_stack .append (hir .Call ((), _set_scope , (old_scope ,), (), call .loc ))
216- builder .scope = _create_scope (callee_hir , builder .ir_ctx , call_site = builder .loc )
198+ # Dispatch the function body. Use resume_after() to break the call stack
199+ # and make sure we stay within the Python's recursion limit.
200+ await resume_after (dispatch_hir_block (callee_hir .body , builder ))
217201
218- # Now push the function body.
219- todo_stack .extend (reversed (callee_hir .body .calls ))
220-
221- # Finally, call store_var() to bind arguments to parameters.
222- for arg , param_name , param_loc in zip (arg_list , callee_hir .param_names ,
223- callee_hir .param_locs , strict = True ):
224- todo_stack .append (hir .Call ((), hir .store_var , (param_name , arg ), (), param_loc ))
202+ for callee_retval , caller_res in zip (callee_hir .body .results , call .results ):
203+ assign (callee_retval , caller_res )
225204
226205
227206def _is_freshly_defined (var : Var , builder : ir .Builder , first_idx : int ):
@@ -277,13 +256,3 @@ def _bind_args(sig_func, args, kwargs) -> list[Var]:
277256 assert param .default is not param .empty
278257 ret .append (loosely_typed_const (param .default ))
279258 return ret
280-
281-
282- def _set_scope (scope ): ...
283-
284-
285- @impl (_set_scope )
286- def _set_scope_impl (scope ):
287- assert isinstance (scope , Scope )
288- builder = ir .Builder .get_current ()
289- builder .scope = scope
0 commit comments