@@ -92,38 +92,29 @@ def construct_nominal_fgraph(
9292 dict [Variable , Variable ],
9393]:
9494 """Construct an inner-`FunctionGraph` with ordered nominal inputs."""
95- dummy_inputs = []
96- for n , inp in enumerate (inputs ):
97- if (
98- not isinstance (inp , Variable )
99- or isinstance (inp , Constant )
100- or isinstance (inp , SharedVariable )
101- ):
102- raise TypeError (
103- f"Inputs and outputs must be non-Constant/shared Variable instances; got { inp } "
104- )
105-
106- dummy_inputs .append (inp .type ())
95+ implicit_shared_inputs = []
10796
108- dummy_shared_inputs = []
109- shared_inputs = []
97+ dummy_inputs = [inp . type () for inp in inputs ]
98+ dummy_implicit_shared_inputs = []
11099 for var in graph_inputs (outputs , inputs ):
100+ if var in inputs :
101+ continue
111102 if isinstance (var , SharedVariable ):
112- # To correctly support shared variables the inner- graph should
113- # not see them; otherwise, there will be problems with
114- # gradients.
115- # That's why we collect the shared variables and replace them
116- # with dummies.
117- shared_inputs . append ( var )
118- dummy_shared_inputs . append ( var . type ())
119- elif var not in inputs and not isinstance ( var , Constant ):
120- raise MissingInputError ( f"OpFromGraph is missing an input: { var } " )
121-
122- replacements = dict ( zip ( inputs + shared_inputs , dummy_inputs + dummy_shared_inputs ) )
103+ # We allow shared inputs to be added automatically to the graph
104+ implicit_shared_inputs . append ( var )
105+ dummy_implicit_shared_inputs . append ( var . type ())
106+ elif not isinstance ( var , Constant ):
107+ raise MissingInputError ( f"NominalGraph is missing an input: { var } " )
108+
109+ replacements = dict (
110+ zip (
111+ inputs + implicit_shared_inputs , dummy_inputs + dummy_implicit_shared_inputs
112+ )
113+ )
123114
124115 new = rebuild_collect_shared (
125116 cast (Sequence [Variable ], outputs ),
126- inputs = inputs + shared_inputs ,
117+ inputs = inputs + implicit_shared_inputs ,
127118 replace = replacements ,
128119 copy_inputs_over = False ,
129120 )
@@ -133,7 +124,7 @@ def construct_nominal_fgraph(
133124 (clone_d , update_d , update_expr , new_shared_inputs ),
134125 ) = new
135126
136- assert len (local_inputs ) == len (inputs ) + len (shared_inputs )
127+ assert len (local_inputs ) == len (inputs ) + len (implicit_shared_inputs )
137128 assert len (local_outputs ) == len (outputs )
138129 assert not update_d
139130 assert not update_expr
@@ -155,7 +146,7 @@ def construct_nominal_fgraph(
155146 fgraph .clients .pop (inp , None )
156147 fgraph .add_input (nom_inp )
157148
158- return fgraph , shared_inputs , update_d , update_expr
149+ return fgraph , implicit_shared_inputs , update_d , update_expr
159150
160151
161152class OpFromGraph (Op , HasInnerGraph ):
@@ -177,8 +168,6 @@ class OpFromGraph(Op, HasInnerGraph):
177168 - grad() make it support DisconnectedType and the new interface
178169 - add support for NullType and DisconnectedType when R_op supports them
179170 - check how it works with updates.
180- - add test with constant as input or inside the inner graph.
181- - Add support for the GPU? Probably just need an opt to remove transfer
182171 - Add support to pickle this Op.
183172 - Add support/test with random generator
184173 - Add optimization to removing unused inputs/outputs
@@ -310,11 +299,13 @@ def __init__(
310299 self ,
311300 inputs : list [Variable ],
312301 outputs : list [Variable ],
302+ * ,
313303 inline : bool = False ,
314304 lop_overrides : str = "default" ,
315305 grad_overrides : str = "default" ,
316306 rop_overrides : str = "default" ,
317307 connection_pattern : Optional [list [list [bool ]]] = None ,
308+ strict : bool = False ,
318309 name : Optional [str ] = None ,
319310 ** kwargs ,
320311 ):
@@ -399,6 +390,8 @@ def __init__(
399390 must be equal to number of outputs. connection_pattern If not
400391 ``None``, this will be used as the connection_pattern for this
401392 :class:`Op`.
393+ strict: bool, default False
394+ Raise if SharedVariables needed to compute the graph are not provided as explicit inputs.
402395 name
403396 A name for debugging purposes.
404397 kwargs
@@ -424,6 +417,12 @@ def __init__(
424417 inputs , outputs
425418 )
426419
420+ if strict and self .shared_inputs :
421+ raise ValueError (
422+ "All shared variables must be provided as inputs under strict=True. "
423+ f"The following variables were missing { self .shared_inputs } "
424+ )
425+
427426 self .kwargs = kwargs
428427 self .input_types = [inp .type for inp in inputs ]
429428 self .output_types = [out .type for out in outputs ]
0 commit comments