diff --git a/src/fasterrisk/binarization_util.py b/src/fasterrisk/binarization_util.py index a081736..ac9b9f8 100644 --- a/src/fasterrisk/binarization_util.py +++ b/src/fasterrisk/binarization_util.py @@ -142,7 +142,10 @@ def fit(self, df: pd.DataFrame) -> None: for col_idx in range(len(self.cols)): col = self.cols[col_idx] col_value = df[col] - + + # Initialize tmp_num_thresholds *before* checking for NaNs + tmp_num_thresholds = self.max_num_thresholds_per_feature + if col_value.isnull().sum() > 0: tmp_num_thresholds -= 1 binarizers.append({ # need to keep track of NaN for every column @@ -267,4 +270,4 @@ def transform(self, df: pd.DataFrame) -> tuple: def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame: '''fit and transform on same dataframe''' self.fit(df) - return self.transform(df) \ No newline at end of file + return self.transform(df)