Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions missingpy/knnimpute.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from .pairwise_external import _get_mask
from .pairwise_external import _MASKED_METRICS

from .utils import is_nan

__all__ = [
'KNNImputer',
]
Expand Down Expand Up @@ -193,8 +195,7 @@ def fit(self, X, y=None):
"""

# Check data integrity and calling arguments
force_all_finite = False if self.missing_values in ["NaN",
np.nan] else True
force_all_finite = not is_nan(self.missing_values)
if not force_all_finite:
if self.metric not in _MASKED_METRICS and not callable(
self.metric):
Expand Down Expand Up @@ -250,8 +251,7 @@ def transform(self, X):
"""

check_is_fitted(self, ["fitted_X_", "statistics_"])
force_all_finite = False if self.missing_values in ["NaN",
np.nan] else True
force_all_finite = not is_nan(self.missing_values)
X = check_array(X, accept_sparse=False, dtype=FLOAT_DTYPES,
force_all_finite=force_all_finite, copy=self.copy)

Expand Down
9 changes: 4 additions & 5 deletions missingpy/missforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

from .pairwise_external import _get_mask

from .utils import is_nan

__all__ = [
'MissForest',
]
Expand Down Expand Up @@ -434,9 +436,7 @@ def fit(self, X, y=None, cat_vars=None):
"""

# Check data integrity and calling arguments
force_all_finite = False if self.missing_values in ["NaN",
np.nan] else True

force_all_finite = not is_nan(self.missing_values)
X = check_array(X, accept_sparse=False, dtype=np.float64,
force_all_finite=force_all_finite, copy=self.copy)

Expand Down Expand Up @@ -499,8 +499,7 @@ def transform(self, X):
check_is_fitted(self, ["cat_vars_", "num_vars_", "statistics_"])

# Check data integrity
force_all_finite = False if self.missing_values in ["NaN",
np.nan] else True
force_all_finite = not is_nan(self.missing_values)
X = check_array(X, accept_sparse=False, dtype=np.float64,
force_all_finite=force_all_finite, copy=self.copy)

Expand Down
4 changes: 2 additions & 2 deletions missingpy/pairwise_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@
from sklearn.metrics.pairwise import _parallel_pairwise
from sklearn.utils import check_array

from .utils import masked_euclidean_distances
from .utils import is_nan, masked_euclidean_distances

_MASKED_METRICS = ['masked_euclidean']
_VALID_METRICS += ['masked_euclidean']


def _get_mask(X, value_to_mask):
    """Compute the boolean mask X == missing_values.

    Parameters
    ----------
    X : array-like
        Data in which to locate the missing-value placeholder.
    value_to_mask : object
        Placeholder that marks a missing entry. The string "NaN" or a
        floating-point NaN selects NaN entries.

    Returns
    -------
    mask : boolean array
        True where ``X`` holds the placeholder.
    """
    # NaN never compares equal to itself, so NaN placeholders need np.isnan
    # instead of ==. The check goes through is_nan rather than calling
    # np.isnan on the placeholder directly, which raises TypeError for
    # non-numeric placeholders.
    if is_nan(value_to_mask):
        return np.isnan(X)
    return X == value_to_mask
Expand Down
4 changes: 3 additions & 1 deletion missingpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import numpy as np

def is_nan(n):
    """Return True if ``n`` denotes a NaN missing-value placeholder.

    Parameters
    ----------
    n : object
        Candidate placeholder: typically the string "NaN", a float, an
        integer, or a NumPy scalar.

    Returns
    -------
    bool
        True when ``n`` is the string "NaN" or a floating-point NaN
        (Python ``float`` or any NumPy floating scalar); False otherwise.
    """
    if isinstance(n, str):
        return n == "NaN"
    # np.float32 / np.float16 do not subclass Python's float, so np.floating
    # is checked explicitly. The isinstance guard keeps np.isnan from being
    # called on types it cannot handle (None, arbitrary objects).
    return isinstance(n, (float, np.floating)) and bool(np.isnan(n))

def masked_euclidean_distances(X, Y=None, squared=False,
missing_values="NaN", copy=True):
Expand Down Expand Up @@ -92,7 +94,7 @@ def masked_euclidean_distances(X, Y=None, squared=False,
raise ValueError("One or more rows only contain missing values.")

# else:
if missing_values not in ["NaN", np.nan] and (
if not is_nan(missing_values) and (
np.any(np.isnan(X)) or (Y is not X and np.any(np.isnan(Y)))):
raise ValueError(
"NaN values present but missing_value = {0}".format(
Expand Down