1818import types
1919import warnings
2020
21- from collections .abc import Callable , Iterable , Sequence
21+ from collections .abc import Iterable , Sequence
2222from sys import modules
2323from typing import (
2424 TYPE_CHECKING ,
2727 Optional ,
2828 TypeVar ,
2929 cast ,
30+ overload ,
3031)
3132
3233import numpy as np
3536import pytensor .tensor as pt
3637import scipy .sparse as sps
3738
38- from pytensor .compile import DeepCopyOp , get_mode
39+ from pytensor .compile import DeepCopyOp , Function , get_mode
3940from pytensor .compile .sharedvalue import SharedVariable
4041from pytensor .graph .basic import Constant , Variable , graph_inputs
4142from pytensor .scalar import Cast
@@ -1524,6 +1525,28 @@ def replace_rvs_by_values(
15241525 rvs_to_transforms = self .rvs_to_transforms ,
15251526 )
15261527
1528+ @overload
1529+ def compile_fn (
1530+ self ,
1531+ outs : Variable | Sequence [Variable ],
1532+ * ,
1533+ inputs : Sequence [Variable ] | None = None ,
1534+ mode = None ,
1535+ point_fn : Literal [True ] = True ,
1536+ ** kwargs ,
1537+ ) -> PointFunc : ...
1538+
1539+ @overload
1540+ def compile_fn (
1541+ self ,
1542+ outs : Variable | Sequence [Variable ],
1543+ * ,
1544+ inputs : Sequence [Variable ] | None = None ,
1545+ mode = None ,
1546+ point_fn : Literal [False ],
1547+ ** kwargs ,
1548+ ) -> Function : ...
1549+
15271550 def compile_fn (
15281551 self ,
15291552 outs : Variable | Sequence [Variable ],
@@ -1532,7 +1555,7 @@ def compile_fn(
15321555 mode = None ,
15331556 point_fn : bool = True ,
15341557 ** kwargs ,
1535- ) -> PointFunc | Callable [[ Sequence [ np . ndarray ]], Sequence [ np . ndarray ]] :
1558+ ) -> PointFunc | Function :
15361559 """Compiles an PyTensor function
15371560
15381561 Parameters
@@ -2044,7 +2067,7 @@ def compile_fn(
20442067 point_fn : bool = True ,
20452068 model : Model | None = None ,
20462069 ** kwargs ,
2047- ) -> PointFunc | Callable [[ Sequence [ np . ndarray ]], Sequence [ np . ndarray ]] :
2070+ ) -> PointFunc | Function :
20482071 """Compiles an PyTensor function
20492072
20502073 Parameters
0 commit comments