Skip to content

Commit f1e836c

Browse files
committed
Improve error handling and type checks in preprocessing
Enhanced input validation and error messages in curvilinear and unstructured scikit-learn preprocessors, replacing assertions with explicit exceptions. Added missing backend error in Warped2DInterp. Fixed gradient caching bug in MultivaluedInterp by using correct argument key. Ensured interp_kwargs is copied before update in UnstructuredInterp.
1 parent ec427e0 commit f1e836c

5 files changed

Lines changed: 37 additions & 25 deletions

File tree

src/multinterp/curvilinear/_scikit_learn.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -106,21 +106,26 @@ def __init__(
106106

107107
feature = pp_options.get("feature", None)
108108

109-
if feature and isinstance(feature, str):
110-
degree = pp_options.get("degree", 3)
111-
assert isinstance(degree, int), "Degree must be an integer."
112-
if feature.startswith("pol"):
113-
pipeline.insert(0, PolynomialFeatures(degree))
114-
elif feature.startswith("spl"):
115-
n_knots = pp_options.get("n_knots", 5)
116-
assert isinstance(n_knots, int), "n_knots must be an integer."
117-
pipeline.insert(0, SplineTransformer(n_knots=n_knots, degree=degree))
118-
else:
119-
msg = f"Feature {feature} not recognized."
120-
raise AttributeError(msg)
109+
if not feature or not isinstance(feature, str):
110+
msg = f"Feature must be a string ('pol' or 'spl'), got {feature!r}."
111+
raise ValueError(msg)
112+
113+
degree = pp_options.get("degree", 3)
114+
if not isinstance(degree, int):
115+
msg = "Degree must be an integer."
116+
raise TypeError(msg)
117+
118+
if feature.startswith("pol"):
119+
pipeline.insert(0, PolynomialFeatures(degree))
120+
elif feature.startswith("spl"):
121+
n_knots = pp_options.get("n_knots", 5)
122+
if not isinstance(n_knots, int):
123+
msg = "n_knots must be an integer."
124+
raise TypeError(msg)
125+
pipeline.insert(0, SplineTransformer(n_knots=n_knots, degree=degree))
121126
else:
122-
msg = f"Feature {feature} not recognized."
123-
raise AttributeError(msg)
127+
msg = f"Feature {feature!r} not recognized. Use 'pol' or 'spl'."
128+
raise ValueError(msg)
124129

125130
if std:
126131
pipeline.insert(0, StandardScaler())

src/multinterp/curvilinear/_warped.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ def __call__(self, *args, axis=1):
4141
output = self._interp_piecewise(args, axis)
4242
elif self.backend == "numba":
4343
output = self._backend_numba(args, axis)
44+
else:
45+
msg = f"Backend {self.backend!r} not supported for Warped2DInterp."
46+
raise NotImplementedError(msg)
4447

4548
return output
4649

src/multinterp/rectilinear/_multi.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -280,15 +280,16 @@ def diff(self, axis=None, argnum=None, edge_order=1):
280280
options=self.mc_kwargs,
281281
)
282282

283-
grad = self._gradient.get((axis, arg))
283+
grad = self._gradient.get((axis, argnum))
284284
if grad is None:
285-
self._gradient[(axis, arg)] = get_grad(
285+
self._gradient[(axis, argnum)] = get_grad(
286286
self.values,
287-
self.grids[axis],
288-
axis=axis,
287+
self.grids[argnum],
288+
axis=argnum,
289289
edge_order=edge_order,
290290
backend=self.backend,
291291
)
292+
grad = self._gradient[(axis, argnum)]
292293

293294
return MultivariateInterp(
294295
grad,

src/multinterp/unstructured/_scikit_learn.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,20 @@ def __init__(
9999

100100
if feature and isinstance(feature, str):
101101
degree = options.get("degree", 3)
102-
assert isinstance(degree, int), "Degree must be an integer."
102+
if not isinstance(degree, int):
103+
msg = "Degree must be an integer."
104+
raise TypeError(msg)
103105
if feature.startswith("pol"):
104106
pipeline.insert(0, PolynomialFeatures(degree))
105107
elif feature.startswith("spl"):
106108
n_knots = options.get("n_knots", 5)
107-
assert isinstance(n_knots, int), "n_knots must be an integer."
109+
if not isinstance(n_knots, int):
110+
msg = "n_knots must be an integer."
111+
raise TypeError(msg)
108112
pipeline.insert(0, SplineTransformer(n_knots=n_knots, degree=degree))
109113
else:
110-
msg = f"Feature {feature} not recognized."
111-
raise AttributeError(msg)
114+
msg = f"Feature {feature!r} not recognized. Use 'pol' or 'spl'."
115+
raise ValueError(msg)
112116

113117
if std:
114118
pipeline.insert(0, StandardScaler())

src/multinterp/unstructured/_scipy.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,9 @@ def __init__(
9696
msg = f"Unknown interpolation method {method} for {self.ndim} dimensional data."
9797
raise ValueError(msg)
9898

99-
self.interp_kwargs = interp_kwargs
99+
self.interp_kwargs = interp_kwargs.copy()
100100
if options:
101-
self.interp_kwargs.copy()
102-
intersection = interp_kwargs.keys() & options.keys()
101+
intersection = self.interp_kwargs.keys() & options.keys()
103102
self.interp_kwargs.update({key: options[key] for key in intersection})
104103

105104
self.interpolator = interpolator_class(

0 commit comments

Comments
 (0)