Skip to content

Commit e274e0d

Browse files
committed
Allow linkers to specify required/incompatible rewrites
1 parent f5bf2af commit e274e0d

File tree

8 files changed

+72
-50
lines changed

8 files changed

+72
-50
lines changed

pytensor/compile/debugmode.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,7 +1331,11 @@ def printstuff(self):
13311331
# the external requirements of the .linker attribute of a mode
13321332
# 1) it's a class instance
13331333
# 2) it a has a .clone() method
1334+
# 3) it has required_rewrites and incompatible_rewrites class attributes
13341335
class _DummyLinker:
1336+
required_rewrites = ()
1337+
incompatible_rewrites = ()
1338+
13351339
# This is not a real linker anyway
13361340
def clone(self, allow_gc=None):
13371341
return self

pytensor/compile/mode.py

Lines changed: 13 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,14 @@ def __setstate__(self, state):
352352
if isinstance(optimizer, str) or optimizer is None:
353353
optimizer = predefined_optimizers[optimizer]
354354
if isinstance(optimizer, RewriteDatabaseQuery):
355+
# TODO: From the __init__ signature this should always be the case
356+
# But some tests and internal logic allow passing a GraphRewriter directly as optimizer
357+
# Cleanup!
355358
self.provided_optimizer = optimizer
359+
if r := linker.required_rewrites:
360+
optimizer = optimizer.including(*r)
361+
if r := linker.incompatible_rewrites:
362+
optimizer = optimizer.excluding(*r)
356363
self._optimizer = optimizer
357364
self.call_time = 0
358365
self.fn_time = 0
@@ -365,14 +372,13 @@ def __str__(self):
365372
f"optdb={self.optdb})"
366373
)
367374

368-
def __get_optimizer(self):
375+
@property
376+
def optimizer(self):
369377
if isinstance(self._optimizer, RewriteDatabaseQuery):
370378
return self.optdb.query(self._optimizer)
371379
else:
372380
return self._optimizer
373381

374-
optimizer = property(__get_optimizer)
375-
376382
def get_linker_optimizer(self, linker, optimizer):
377383
if isinstance(linker, str) or linker is None:
378384
linker = predefined_linkers[linker]
@@ -466,61 +472,21 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
466472

467473
NUMBA = Mode(
468474
NumbaLinker(),
469-
RewriteDatabaseQuery(
470-
include=["fast_run", "numba"],
471-
exclude=[
472-
"cxx_only",
473-
"BlasOpt",
474-
"local_careduce_fusion",
475-
"scan_save_mem_prealloc",
476-
],
477-
),
475+
RewriteDatabaseQuery(include=["fast_run", "numba"]),
478476
)
479477

480478
JAX = Mode(
481479
JAXLinker(),
482-
RewriteDatabaseQuery(
483-
include=["fast_run", "jax"],
484-
exclude=[
485-
"cxx_only",
486-
"BlasOpt",
487-
"fusion",
488-
"inplace",
489-
"scan_save_mem_prealloc",
490-
# There are specific variants for the LU decompositions supported by JAX
491-
"reuse_lu_decomposition_multiple_solves",
492-
"scan_split_non_sequence_lu_decomposition_solve",
493-
],
494-
),
480+
RewriteDatabaseQuery(include=["fast_run", "jax"]),
495481
)
496482
PYTORCH = Mode(
497483
PytorchLinker(),
498-
RewriteDatabaseQuery(
499-
include=["fast_run"],
500-
exclude=[
501-
"cxx_only",
502-
"BlasOpt",
503-
"fusion",
504-
"inplace",
505-
"scan_save_mem_prealloc",
506-
"reuse_lu_decomposition_multiple_solves",
507-
"scan_split_non_sequence_lu_decomposition_solve",
508-
],
509-
),
484+
RewriteDatabaseQuery(include=["fast_run"]),
510485
)
511486

512487
MLX = Mode(
513488
MLXLinker(),
514-
RewriteDatabaseQuery(
515-
include=["fast_run"],
516-
exclude=[
517-
"cxx_only",
518-
"BlasOpt",
519-
"fusion",
520-
"inplace",
521-
"scan_save_mem_prealloc",
522-
],
523-
),
489+
RewriteDatabaseQuery(include=["fast_run"]),
524490
)
525491

526492

pytensor/link/basic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ class Linker(ABC):
157157
the FunctionGraph.
158158
"""
159159

160+
required_rewrites: tuple[str, ...] = ("minimum_compile",)
161+
incompatible_rewrites: tuple[str, ...] = ()
162+
160163
def __init__(
161164
self,
162165
*,

pytensor/link/jax/linker.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,22 @@
99
class JAXLinker(JITLinker):
1010
"""A `Linker` that JIT-compiles NumPy-based operations using JAX."""
1111

12+
required_rewrites = (
13+
"minimum_compile",
14+
"jax",
15+
) # TODO: Distinguish between optional "jax" and "minimum_compile_jax"
16+
incompatible_rewrites = (
17+
"cxx",
18+
"BlasOpt",
19+
"local_careduce_fusion",
20+
"scan_save_mem_prealloc",
21+
# JAX does it his own inplace optimization
22+
"inplace",
23+
# There are specific variants for the LU decompositions supported by JAX
24+
"reuse_lu_decomposition_multiple_solves",
25+
"scan_split_non_sequence_lu_decomposition_solve",
26+
)
27+
1228
scalar_shape_inputs: tuple[int, ...]
1329

1430
def __init__(self, *args, **kwargs):

pytensor/link/mlx/linker.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@
44
class MLXLinker(JITLinker):
55
"""A `Linker` that JIT-compiles NumPy-based operations using Apple's MLX."""
66

7+
incompatible_rewrites = (
8+
"cxx_only",
9+
"BlasOpt",
10+
"fusion",
11+
"inplace",
12+
"scan_save_mem_prealloc",
13+
)
14+
715
def __init__(self, use_compile=True, *args, **kwargs):
816
super().__init__(*args, **kwargs)
917
self.gen_functors = []

pytensor/link/numba/linker.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,17 @@
22

33

44
class NumbaLinker(JITLinker):
5+
required_rewrites = (
6+
"minimum_compile",
7+
"numba",
8+
) # TODO: Distinguish between optional "numba" and "minimum_compile_numba"
9+
incompatible_rewrites = (
10+
"cxx",
11+
"BlasOpt",
12+
"local_careduce_fusion",
13+
"scan_save_mem_prealloc",
14+
)
15+
516
"""A `Linker` that JIT-compiles NumPy-based operations using Numba."""
617

718
def fgraph_convert(self, fgraph, **kwargs):

pytensor/link/pytorch/linker.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,16 @@
55
class PytorchLinker(JITLinker):
66
"""A `Linker` that compiles NumPy-based operations using torch.compile."""
77

8+
incompatible_rewrites = (
9+
"cxx_only",
10+
"BlasOpt",
11+
"fusion",
12+
"inplace",
13+
"scan_save_mem_prealloc",
14+
"reuse_lu_decomposition_multiple_solves",
15+
"scan_split_non_sequence_lu_decomposition_solve",
16+
)
17+
818
def __init__(self, *args, **kwargs):
919
super().__init__(*args, **kwargs)
1020
self.gen_functors = []

tests/compile/test_mode.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,15 @@ def test_NoOutputFromInplace():
5555

5656

5757
def test_including():
58-
mode = Mode(optimizer="merge")
59-
assert set(mode._optimizer.include) == {"merge"}
58+
mode = Mode(linker="py", optimizer="merge")
59+
assert set(mode._optimizer.include) == {"minimum_compile", "merge"}
6060

6161
new_mode = mode.including("fast_compile")
62-
assert set(new_mode._optimizer.include) == {"merge", "fast_compile"}
62+
assert set(new_mode._optimizer.include) == {
63+
"minimum_compile",
64+
"merge",
65+
"fast_compile",
66+
}
6367

6468

6569
class TestBunchOfModes:

0 commit comments

Comments
 (0)