…e edata_mini_3D_missing_values
Small note: After our discussion, I used vertical stacking as the reshaping strategy. I decided not to use the ehrapy decorator.
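For reference, the two reshaping strategies can be sketched with plain NumPy on a toy array. This is only an illustration under assumed toy shapes and made-up variable names (`X_vertical`, `X_horizontal`, and so on), not the ehrapy implementation; it shows how each strategy flattens a `(n_obs, n_var, n_t)` array to 2D and how the original layout is restored.

```python
import numpy as np

# Toy 3D array with shape (n_obs, n_var, n_t); the values are arbitrary.
n_obs, n_var, n_t = 2, 3, 4
X = np.arange(n_obs * n_var * n_t, dtype="float64").reshape(n_obs, n_var, n_t)

# Vertical stacking: every (observation, timepoint) pair becomes a row,
# so a 2D imputer sees n_obs * n_t samples with n_var features.
X_vertical = X.transpose(0, 2, 1).reshape(n_obs * n_t, n_var)

# Horizontal stacking: every observation stays one row, and each
# (variable, timepoint) pair becomes its own feature column.
X_horizontal = X.reshape(n_obs, n_var * n_t)

# Inverse transforms that restore the original (n_obs, n_var, n_t) layout.
X_from_vertical = X_vertical.reshape(n_obs, n_t, n_var).transpose(0, 2, 1)
X_from_horizontal = X_horizontal.reshape(n_obs, n_var, n_t)

print(X_vertical.shape)    # (8, 3)
print(X_horizontal.shape)  # (2, 12)
print(np.array_equal(X, X_from_vertical), np.array_equal(X, X_from_horizontal))  # True True
```

With vertical stacking the feature dimension stays at `n_var`, so neighbors are found between individual timepoints; with horizontal stacking neighbors are found between whole observations, at the cost of `n_var * n_t` feature columns.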
Also, here is an evaluation script comparing the two 3D stacking strategies (horizontal vs. vertical). To assess imputation quality, it computes RMSE and MAE at the imputed positions of a dataset with 40% missingness against the corresponding positions in the complete dataset.
import numpy as np
import ehrdata as ed
import ehrapy as ep
from typing import Iterable
from ehrdata.core.constants import FEATURE_TYPE_KEY, NUMERIC_TAG
def _knn_impute_with_mode(edata, var_names, n_neighbors, layer, temporal_mode):
    edata = edata.copy()
    from fknni import FastKNNImputer

    imputer = FastKNNImputer(n_neighbors=n_neighbors)
    if var_names is None:
        var_names = edata.var_names
    var_indices = edata.var_names.get_indexer(var_names).tolist()
    numerical_var_names = edata.var_names[edata.var[FEATURE_TYPE_KEY] == NUMERIC_TAG]
    numerical_indices = edata.var_names.get_indexer(numerical_var_names).tolist()
    X = edata.X if layer is None else edata.layers[layer]
    var_indices_original = var_indices
    is_3d = False
    if X.ndim == 3:
        is_3d = True
        n_obs, n_vars, n_t = X.shape
        if temporal_mode == "vertical":
            # (n_obs, n_var, n_t) -> (n_obs * n_t, n_var): one row per observation and timepoint
            X = X[:, var_indices, :].astype("float64").transpose(0, 2, 1).reshape(n_obs * n_t, len(var_indices))
        else:  # horizontal
            # (n_obs, n_var, n_t) -> (n_obs, n_var * n_t): one column per variable and timepoint
            X = X[:, var_indices, :].astype("float64").reshape(n_obs, len(var_indices) * n_t)
        # After flattening, every column of X is treated as numerical and is imputed
        numerical_indices = list(range(X.shape[1]))
        var_indices = numerical_indices
    # Numerical columns without any missing value are passed to the imputer as extra context
    complete_numerical_columns = np.array(numerical_indices)[
        ~np.isnan(X[:, numerical_indices]).any(axis=0)
    ].tolist()
    imputer_data_indices = var_indices + [i for i in complete_numerical_columns if i not in var_indices]
    imputer_x = X[:, imputer_data_indices].astype("float64")
    X_imputed = imputer.fit_transform(imputer_x)
    if is_3d:
        # Undo the flattening to get back to shape (n_obs, n_var, n_t)
        if temporal_mode == "vertical":
            X_imputed = X_imputed[:, :len(var_indices_original)].reshape(n_obs, n_t, len(var_indices_original)).transpose(0, 2, 1)
        else:
            X_imputed = X_imputed[:, :len(var_indices_original) * n_t].reshape(n_obs, len(var_indices_original), n_t)
        edata.layers[layer][:, var_indices_original, :] = X_imputed
    else:
        if layer is None:
            edata.X[:, imputer_data_indices] = X_imputed
        else:
            edata.layers[layer][:, imputer_data_indices] = X_imputed
    return edata
edata_complete = ed.dt.ehrdata_blobs(
    n_variables=10, missing_values=0.0, n_observations=100,
    base_timepoints=10, random_state=42, seasonality=True
)
truth = edata_complete.layers["tem_data"]
edata = ed.dt.ehrdata_blobs(
    n_variables=10, missing_values=0.4, n_observations=100,
    base_timepoints=10, random_state=42, seasonality=True
)
ed.infer_feature_types(edata)
mask = np.isnan(edata.layers["tem_data"])
def evaluate(imputed, truth, mask, label):
    diff = imputed[mask] - truth[mask]
    rmse = np.sqrt(np.mean(diff**2))
    mae = np.mean(np.abs(diff))
    scale = truth[mask].max() - truth[mask].min()
    nrmse = rmse / scale
    nmae = mae / scale
    print(f"{label:12s} RMSE: {rmse:.4f}, MAE: {mae:.4f}, NRMSE: {nrmse:.4f}, NMAE: {nmae:.4f}")
for mode in ["vertical", "horizontal"]:
    edata_imputed = _knn_impute_with_mode(edata, var_names=None, n_neighbors=5, layer="tem_data", temporal_mode=mode)
    evaluate(edata_imputed.layers["tem_data"], truth, mask, mode)
Output:
eroell
left a comment
I don't have much more to add. All the earlier discussion points are well described in your PR comments, and I think this is a nice cleanup of this imputation function.
eroell
left a comment
Cool!
Can you, in the documentation,
- show a quick preview where a dataset with missing values, then the imputation, and then the result with the missing values gone are shown? Use either ehrdata blobs or one of the pre-loaded datasets, ideally not requiring any import other than ehrdata and ehrapy.
- add 1-2 sentences on how this works in the 2D vs. 3D case?
No need to re-request review afterwards; when this is done you can merge!
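A NumPy-only sketch of the requested before/after preview and of the 2D vs. 3D handling might look like the following. It rests on loud assumptions: column-mean imputation is swapped in for the KNN imputer so that no extra dependency is needed, the data is random toy data rather than ehrdata blobs, and the helper names `mean_impute_2d` and `impute` are hypothetical. Only the control flow (2D arrays are imputed directly, 3D arrays are vertically stacked, imputed, and reshaped back) mirrors what is discussed in this PR.

```python
import numpy as np


def mean_impute_2d(X):
    """Fill NaNs with column means (a simple stand-in for a KNN imputer)."""
    X = X.copy()
    col_means = np.nanmean(X, axis=0)
    nan_rows, nan_cols = np.where(np.isnan(X))
    X[nan_rows, nan_cols] = col_means[nan_cols]
    return X


def impute(X):
    """Impute a 2D (n_obs, n_var) or 3D (n_obs, n_var, n_t) array."""
    if X.ndim == 2:
        # 2D case: the array already has the samples x features layout.
        return mean_impute_2d(X)
    # 3D case: stack timepoints vertically, impute in 2D, then reshape back.
    n_obs, n_var, n_t = X.shape
    stacked = X.transpose(0, 2, 1).reshape(n_obs * n_t, n_var)
    filled = mean_impute_2d(stacked)
    return filled.reshape(n_obs, n_t, n_var).transpose(0, 2, 1)


# Toy 3D data with roughly 40% of the entries set to NaN.
rng = np.random.default_rng(0)
X3d = rng.normal(size=(20, 4, 5))
X3d[rng.random(X3d.shape) < 0.4] = np.nan
X2d = X3d[:, :, 0]  # a single timepoint serves as the 2D example

for name, X in [("2D", X2d), ("3D", X3d)]:
    X_imputed = impute(X)
    print(name, "missing before:", int(np.isnan(X).sum()),
          "missing after:", int(np.isnan(X_imputed).sum()))
```

In the actual documentation, the stand-in imputer would be replaced by a call to ep.pp.knn_impute() on an EHRData object, with the same three steps shown: missing values present, imputation, missing values gone.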
fixes #947
Description of changes
ep.pp.knn_impute() is extended to support 3D EHRdata.
Technical details
For a 3D array with shape (n_obs, n_var, n_t):