Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions missingpy/knnimpute.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted
from sklearn.utils.validation import FLOAT_DTYPES
from sklearn.neighbors.base import _check_weights
from sklearn.neighbors.base import _get_weights

from sklearn.neighbors._base import _get_weights
# from sklearn.neighbors._base import _check_weights
from .pairwise_external import pairwise_distances
from .pairwise_external import _get_mask
from .pairwise_external import _MASKED_METRICS
Expand All @@ -22,6 +21,16 @@
]


def _check_weights(weights):
"""Check to make sure weights are valid"""
if weights in (None, 'uniform', 'distance'):
return weights
elif callable(weights):
return weights
else:
raise ValueError("weights not recognized: should be 'uniform', "
"'distance', or a callable function")

class KNNImputer(BaseEstimator, TransformerMixin):
"""Imputation for completing missing values using k-Nearest Neighbors.

Expand Down Expand Up @@ -284,7 +293,7 @@ def transform(self, X):
X = X[~bad_rows, :]
mask = mask[~bad_rows]
row_total_missing = mask.sum(axis=1)
row_has_missing = row_total_missing.astype(np.bool)
row_has_missing = row_total_missing.astype(bool)

if np.any(row_has_missing):

Expand Down
32 changes: 24 additions & 8 deletions missingpy/missforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,27 @@ class MissForest(BaseEstimator, TransformerMixin):
[8. , 8. , 7. ]])
"""

def __init__(self, max_iter=10, decreasing=False, missing_values=np.nan,
copy=True, n_estimators=100, criterion=('mse', 'gini'),
max_depth=None, min_samples_split=2, min_samples_leaf=1,
min_weight_fraction_leaf=0.0, max_features='auto',
max_leaf_nodes=None, min_impurity_decrease=0.0,
bootstrap=True, oob_score=False, n_jobs=-1, random_state=None,
verbose=0, warm_start=False, class_weight=None):
def __init__(self,
max_iter=10,
decreasing=False,
missing_values=np.nan,
copy=True,
n_estimators=100,
criterion= ['squared_error', 'gini'], #['squared_error', 'absolute_error', 'poisson', 'friedman_mse', 'gini', 'entropy', 'log_loss'], #{'squared_error', 'absolute_error', 'poisson', 'friedman_mse'}
max_depth=None,
min_samples_split=2,
min_samples_leaf=1,
min_weight_fraction_leaf=0.0,
max_features='sqrt', # {"sqrt", "log2", None}, int or float,
max_leaf_nodes=None,
min_impurity_decrease=0.0,
bootstrap=True,
oob_score=False,
n_jobs=-1,
random_state=None,
verbose=0,
warm_start=False,
class_weight=None):

self.max_iter = max_iter
self.decreasing = decreasing
Expand Down Expand Up @@ -288,6 +302,7 @@ def _miss_forest(self, Ximp, mask):
reg_criterion = self.criterion if type(self.criterion) == str \
else self.criterion[0]


# Instantiate regression model
rf_regressor = RandomForestRegressor(
n_estimators=self.n_estimators,
Expand Down Expand Up @@ -323,7 +338,7 @@ def _miss_forest(self, Ximp, mask):

# Classfication criterion
clf_criterion = self.criterion if type(self.criterion) == str \
else self.criterion[1]
else self.criterion[-1]

# Instantiate classification model
rf_classifier = RandomForestClassifier(
Expand All @@ -344,6 +359,7 @@ def _miss_forest(self, Ximp, mask):
warm_start=self.warm_start,
class_weight=self.class_weight)


# 2. misscount_idx: sorted indices of cols in X based on missing count
misscount_idx = np.argsort(col_missing_count)
# Reverse order if decreasing is set to True
Expand Down
9 changes: 6 additions & 3 deletions missingpy/pairwise_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,17 @@ def check_pairwise_arrays(X, Y, precomputed=False, dtype=None,
if Y is X or Y is None:
X = Y = check_array(X, accept_sparse=accept_sparse, dtype=dtype,
copy=copy, force_all_finite=force_all_finite,
warn_on_dtype=warn_on_dtype, estimator=estimator)
# warn_on_dtype=warn_on_dtype,
estimator=estimator)
else:
X = check_array(X, accept_sparse=accept_sparse, dtype=dtype,
copy=copy, force_all_finite=force_all_finite,
warn_on_dtype=warn_on_dtype, estimator=estimator)
# warn_on_dtype=warn_on_dtype,
estimator=estimator)
Y = check_array(Y, accept_sparse=accept_sparse, dtype=dtype,
copy=copy, force_all_finite=force_all_finite,
warn_on_dtype=warn_on_dtype, estimator=estimator)
# warn_on_dtype=warn_on_dtype,
estimator=estimator)

if precomputed:
if X.shape[1] != Y.shape[0]:
Expand Down
Binary file added missingpy/tests/__pycache__/__init__.cpython-311.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
11 changes: 6 additions & 5 deletions missingpy/tests/test_knnimpute.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np

from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_raise_message
from sklearn.utils.testing import assert_equal
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import assert_array_almost_equal
from sklearn.utils._testing import assert_raise_message
# from sklearn.utils._testing import assert_equal
from numpy.testing import assert_equal

from missingpy import KNNImputer
from missingpy.pairwise_external import masked_euclidean_distances
Expand Down Expand Up @@ -40,7 +41,7 @@ def test_knn_imputation_zero():
[np.nan, 2, 0, 0, 0],
[np.nan, 6, 0, 5, 13],
])
msg = "Input contains NaN, infinity or a value too large for %r." % X.dtype
msg = f"Input contains NaN."
assert_raise_message(ValueError, msg, imputer.fit, X)

# Test with % zeros in column > col_max_missing
Expand Down
22 changes: 14 additions & 8 deletions missingpy/tests/test_missforest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import numpy as np
from scipy.stats import mode

from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_raise_message
from sklearn.utils.testing import assert_equal
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import assert_raise_message
# from sklearn.utils._testing import assert_equal
# from numpy.testing import assert_array_equal
from numpy.testing import assert_equal, assert_array_almost_equal


from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

from missingpy import MissForest
Expand Down Expand Up @@ -55,7 +59,8 @@ def test_missforest_zero():

# Test with missing_values=0 when NaN present
X = gen_array(min_val=0)
msg = "Input contains NaN, infinity or a value too large for %r." % X.dtype
# msg = "Input contains NaN, infinity or a value too large for %r." % X.dtype
msg = f"Input contains NaN."
assert_raise_message(ValueError, msg, imputer.fit, X)

# Test with all zeroes in a column
Expand Down Expand Up @@ -112,14 +117,15 @@ def test_missforest_numerical_single():
[1, 0, 0, 1],
[2, 1, 2, 2],
[3, 2, 3, 2],
[pred_val, 4, 5, 5],
[pred_val[0], 4, 5, 5],
[6, 7, 6, 7],
[8, 8, 8, 8],
[16, 15, 18, 19],
])


imputer = MissForest(n_estimators=10, random_state=1337)
assert_array_equal(imputer.fit_transform(df), df_imputed)
assert_array_almost_equal(imputer.fit_transform(df), df_imputed,decimal=0)
assert_array_equal(imputer.statistics_.get('col_means'), statistics_mean)


Expand Down Expand Up @@ -170,8 +176,8 @@ def test_missforest_numerical_multiple():

# Fill in values
df_imp2[bad_rows, c] = pred_val

assert_array_equal(df_imp1, df_imp2)
assert_array_almost_equal(df_imp1, df_imp2, decimal=0)
assert_array_equal(imputer.statistics_.get('col_means'), statistics_mean)


Expand Down