@@ -211,7 +211,7 @@ def test_neg(self):
211211 self .assertIsInstance (v , Var )
212212 self .assertEqual (["X" ], v .parent .input_names )
213213 s = str (v )
214- self .assertEqual ("X:FLOAT" , s )
214+ self .assertEqual ("X:FLOAT:[] " , s )
215215 onx = start ().vin ("X" ).Neg ().rename ("Y" ).vout ().to_onnx ()
216216 self .assertIsInstance (onx , ModelProto )
217217 ref = ReferenceEvaluator (onx )
@@ -510,7 +510,23 @@ def ah(self):
510510 expected = (a > 0 ).astype (int ).astype (np .float32 ).reshape ((- 1 , 1 ))
511511 self .assertEqualArray (expected , got )
512512
513+ def test_input_shape (self ):
514+ kernel = (np .arange (9 ) + 1 ).reshape (3 , 3 ).astype (np .float32 )
515+ model = (
516+ start ()
517+ .vin ("X" , shape = [None , None ])
518+ .cst (kernel [np .newaxis , np .newaxis , ...])
519+ .rename ("W" )
520+ .bring ("X" , "W" )
521+ .Conv (pads = [1 , 1 , 1 , 1 ])
522+ .rename ("Y" )
523+ .vout (shape = [])
524+ .to_onnx ()
525+ )
526+ i = str (model .graph .input [0 ]).replace ("\n " , "" ).replace (" " , "" )
527+ self .assertNotIn ("shape{}" , i )
528+
513529
514530if __name__ == "__main__" :
515- TestLightApi ().test_domain ()
531+ TestLightApi ().test_add ()
516532 unittest .main (verbosity = 2 )
0 commit comments