Skip to content

Commit f4ef92c

Browse files
MilesCranmerMilesCranmerBot
authored andcommitted
Merge pull request MilesCranmer#1058 from MilesCranmer/fix-656
fix: torch export with constant arguments
1 parent d42159f commit f4ef92c

2 files changed

Lines changed: 29 additions & 9 deletions

File tree

pysr/export_torch.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
9797
self._torch_func = lambda: self._value
9898
self._args = ()
9999
elif issubclass(expr.func, sympy.Rational):
100-
# This is some fraction fixed in the operator.
101-
self._value = float(expr)
100+
# Includes Integer, since Integer is a subclass of Rational
101+
self.register_buffer("_value", torch.tensor(float(expr)))
102102
self._torch_func = lambda: self._value
103103
self._args = ()
104104
elif issubclass(expr.func, sympy.UnevaluatedExpr):
@@ -111,15 +111,9 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
111111
self.register_buffer("_value", torch.tensor(float(expr.args[0])))
112112
self._torch_func = lambda: self._value
113113
self._args = ()
114-
elif issubclass(expr.func, sympy.Integer):
115-
# Can get here if expr is one of the Integer special cases,
116-
# e.g. NegativeOne
117-
self._value = int(expr)
118-
self._torch_func = lambda: self._value
119-
self._args = ()
120114
elif issubclass(expr.func, sympy.NumberSymbol):
121115
# Can get here from exp(1) or exact pi
122-
self._value = float(expr)
116+
self.register_buffer("_value", torch.tensor(float(expr)))
123117
self._torch_func = lambda: self._value
124118
self._args = ()
125119
elif issubclass(expr.func, sympy.Symbol):

pysr/test/test_torch.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,32 @@ def test_issue_656(self):
192192
decimal=3,
193193
)
194194

195+
def test_constant_arguments(self):
196+
# Test that functions with constant arguments work correctly
197+
# Regression test for https://github.com/MilesCranmer/PySR/issues/656
198+
test_cases = [
199+
(pysr.export_sympy.pysr2sympy("sqrt(2)"), np.sqrt(2)),
200+
(sympy.exp(2), np.exp(2)),
201+
(sympy.log(4), np.log(4)),
202+
(sympy.sin(1), np.sin(1)),
203+
]
204+
205+
for expr, expected in test_cases:
206+
m = pysr.export_torch.sympy2torch(expr, [])
207+
result = m(self.torch.randn(10, 1))
208+
np.testing.assert_almost_equal(result.item(), expected, decimal=3)
209+
210+
# Test with variables: sqrt(2) * x
211+
x = sympy.symbols("x")
212+
expr = sympy.sqrt(2) * x
213+
m = pysr.export_torch.sympy2torch(expr, [x])
214+
X = np.random.randn(10, 1)
215+
np.testing.assert_almost_equal(
216+
m(self.torch.tensor(X)).detach().numpy().flatten(),
217+
np.sqrt(2) * X[:, 0],
218+
decimal=3,
219+
)
220+
195221
def test_feature_selection_custom_operators(self):
196222
rstate = np.random.RandomState(0)
197223
X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})

0 commit comments

Comments
 (0)