from pytensor.graph import Apply, FunctionGraph, Op, Type, node_rewriter
from pytensor.graph.rewriting.basic import in2out
from pytensor.scalar import constant
-from pytensor.tensor import (
-    NoneConst,
-    add,
-    and_,
-    empty,
-    get_scalar_constant_value,
-    set_subtensor,
-)
+from pytensor.tensor import add, and_, empty, get_scalar_constant_value, set_subtensor
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.shape import Shape_i
from pytensor.tensor.type import DenseTensorType, TensorType
from pytensor.tensor.type_other import NoneTypeT
+from pytensor.typed_list import TypedListType, append, make_empty_list


def validate_loop_update_types(update):
@@ -176,8 +170,7 @@ def __init__(
                    )
                )
            else:
-                # We can't concatenate all types of states, such as RandomTypes
-                self.trace_types.append(NoneConst.type)
+                self.trace_types.append(TypedListType(state_type))

        self.constant_types = [inp.type for inp in update_fg.inputs[self.n_states :]]
        self.n_constants = len(self.constant_types)
@@ -312,10 +305,6 @@ def scan(fn, idx, initial_states, constants, max_iters):
        if fgraph.clients[trace]
    ]

-    # Check that outputs that cannot be converted into sequences (such as RandomTypes) are not being referenced
-    for trace_idx in used_traces_idxs:
-        assert not isinstance(old_states[trace_idx].type, NoneTypeT)
-
    # Inputs to the new Loop
    max_iters = node.inputs[0]
    init_states = node.inputs[1 : 1 + op.n_states]
@@ -324,6 +313,8 @@ def scan(fn, idx, initial_states, constants, max_iters):
            (max_iters, *tuple(init_states[trace_idx].shape)),
            dtype=init_states[trace_idx].dtype,
        )
+        if isinstance(init_states[trace_idx].type, DenseTensorType)
+        else make_empty_list(init_states[trace_idx].type)
        for trace_idx in used_traces_idxs
    ]
    constants = node.inputs[1 + op.n_states :]
@@ -376,6 +367,7 @@ def scan(fn, idx, initial_states, constants, max_iters):
    # Inner traces
    inner_states = update_fg.inputs[: op.n_states]
    inner_traces = [init_trace.type() for init_trace in init_traces]
+
    for s, t in zip(inner_states, inner_traces):
        t.name = "trace"
        if s.name:
@@ -387,6 +379,8 @@ def scan(fn, idx, initial_states, constants, max_iters):
    inner_while_cond, *inner_next_states = update_fg.outputs
    inner_next_traces = [
        set_subtensor(prev_trace[inner_idx], inner_next_states[trace_idx])
+        if isinstance(prev_trace.type, DenseTensorType)
+        else append(prev_trace, inner_next_states[trace_idx])
        for trace_idx, prev_trace in zip(used_traces_idxs, inner_traces)
    ]
    for t in inner_next_traces:
@@ -429,7 +423,7 @@ def scan(fn, idx, initial_states, constants, max_iters):
    replacements = dict(zip(old_states, new_states))
    for trace_idx, new_trace in zip(used_traces_idxs, new_traces):
        # If there is no while condition, the whole trace will be used
-        if op.has_while_condition:
+        if op.has_while_condition and isinstance(new_trace.type, DenseTensorType):
            new_trace = new_trace[:final_idx]
        replacements[old_traces[trace_idx]] = new_trace
    return replacements
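
Taken together, these hunks replace the NoneConst.type placeholder traces with TypedListType traces, so states that cannot be stacked into a dense tensor (such as RandomTypes) are accumulated by appending to a typed list instead of writing into a preallocated tensor with set_subtensor. Below is a minimal sketch of that accumulation pattern only, not the rewrite itself; it assumes the make_empty_list helper imported in this diff and uses RandomGeneratorType purely as an illustrative non-dense state type.

    # Sketch of the typed-list trace pattern used above (assumed API from this branch).
    from pytensor.tensor.random.type import RandomGeneratorType
    from pytensor.typed_list import append, make_empty_list

    state_type = RandomGeneratorType()   # a state type that cannot be stacked into a tensor
    trace = make_empty_list(state_type)  # empty typed-list trace; one entry per iteration
    next_state = state_type()            # symbolic state produced by one loop step
    trace = append(trace, next_state)    # grow the trace instead of set_subtensor into a tensor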