@@ -352,7 +352,14 @@ def __setstate__(self, state):
352352 if isinstance (optimizer , str ) or optimizer is None :
353353 optimizer = predefined_optimizers [optimizer ]
354354 if isinstance (optimizer , RewriteDatabaseQuery ):
355+ # TODO: From the __init__ signature this should always be the case
356+ # But some tests and internal logic allow passing a GraphRewriter directly as optimizer
357+ # Cleanup!
355358 self .provided_optimizer = optimizer
359+ if r := linker .required_rewrites :
360+ optimizer = optimizer .including (* r )
361+ if r := linker .incompatible_rewrites :
362+ optimizer = optimizer .excluding (* r )
356363 self ._optimizer = optimizer
357364 self .call_time = 0
358365 self .fn_time = 0
@@ -365,14 +372,13 @@ def __str__(self):
365372 f"optdb={ self .optdb } )"
366373 )
367374
368- def __get_optimizer (self ):
375+ @property
376+ def optimizer (self ):
369377 if isinstance (self ._optimizer , RewriteDatabaseQuery ):
370378 return self .optdb .query (self ._optimizer )
371379 else :
372380 return self ._optimizer
373381
374- optimizer = property (__get_optimizer )
375-
376382 def get_linker_optimizer (self , linker , optimizer ):
377383 if isinstance (linker , str ) or linker is None :
378384 linker = predefined_linkers [linker ]
@@ -466,61 +472,21 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
466472
467473NUMBA = Mode (
468474 NumbaLinker (),
469- RewriteDatabaseQuery (
470- include = ["fast_run" , "numba" ],
471- exclude = [
472- "cxx_only" ,
473- "BlasOpt" ,
474- "local_careduce_fusion" ,
475- "scan_save_mem_prealloc" ,
476- ],
477- ),
475+ RewriteDatabaseQuery (include = ["fast_run" , "numba" ]),
478476)
479477
480478JAX = Mode (
481479 JAXLinker (),
482- RewriteDatabaseQuery (
483- include = ["fast_run" , "jax" ],
484- exclude = [
485- "cxx_only" ,
486- "BlasOpt" ,
487- "fusion" ,
488- "inplace" ,
489- "scan_save_mem_prealloc" ,
490- # There are specific variants for the LU decompositions supported by JAX
491- "reuse_lu_decomposition_multiple_solves" ,
492- "scan_split_non_sequence_lu_decomposition_solve" ,
493- ],
494- ),
480+ RewriteDatabaseQuery (include = ["fast_run" , "jax" ]),
495481)
496482PYTORCH = Mode (
497483 PytorchLinker (),
498- RewriteDatabaseQuery (
499- include = ["fast_run" ],
500- exclude = [
501- "cxx_only" ,
502- "BlasOpt" ,
503- "fusion" ,
504- "inplace" ,
505- "scan_save_mem_prealloc" ,
506- "reuse_lu_decomposition_multiple_solves" ,
507- "scan_split_non_sequence_lu_decomposition_solve" ,
508- ],
509- ),
484+ RewriteDatabaseQuery (include = ["fast_run" ]),
510485)
511486
512487MLX = Mode (
513488 MLXLinker (),
514- RewriteDatabaseQuery (
515- include = ["fast_run" ],
516- exclude = [
517- "cxx_only" ,
518- "BlasOpt" ,
519- "fusion" ,
520- "inplace" ,
521- "scan_save_mem_prealloc" ,
522- ],
523- ),
489+ RewriteDatabaseQuery (include = ["fast_run" ]),
524490)
525491
526492
0 commit comments