From eb4ff877b84c0a55cc5e40467f4e0babbeb08a37 Mon Sep 17 00:00:00 2001 From: Vishal Kushwaha Date: Wed, 18 Mar 2026 18:17:20 +0530 Subject: [PATCH] fix(survival): fix probability mass loss and alignment in _SksurvAdapter - Added boundary handling by prepending S(0)=1.0 to survival curves. - Appended remaining survival probability (tail mass) to np.inf. - Corrected temporal alignment of mass drops to match event times. - Resolves Issue #958: weights now sum to 1.0. --- skpro/survival/adapters/sksurv.py | 53 +++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/skpro/survival/adapters/sksurv.py b/skpro/survival/adapters/sksurv.py index 9d7b43c1e..cc3caa7d5 100644 --- a/skpro/survival/adapters/sksurv.py +++ b/skpro/survival/adapters/sksurv.py @@ -136,22 +136,49 @@ def _predict_proba(self, X): X = X.astype("float") # sksurv insists on float dtype X = prep_skl_df(X) - # predict on X + # predict on X - shape (n_samples, n_times) sksurv_survf = sksurv_est.predict_survival_function(X, return_array=True) - - times = sksurv_est.unique_times_[:-1] - - nt = len(times) - mi = pd.MultiIndex.from_product([X.index, range(nt)]).swaplevel() - - times_val = np.repeat(times, repeats=len(X)) - times_df = pd.DataFrame(times_val, index=mi, columns=self._y_cols) - - weights = -np.diff(sksurv_survf, axis=1).flatten() - weights_df = pd.Series(weights, index=mi) + times = sksurv_est.unique_times_ + + # 1. Handle Initial Mass (S(0) = 1.0) + # We prepend 1.0 to the survival curves to capture the first drop + ones = np.ones((sksurv_survf.shape[0], 1)) + surv_extended = np.hstack([ones, sksurv_survf]) + + # 2. Calculate Weights via negative difference + # -np.diff captures the 'drop' in survival, which is the probability mass + weights = -np.diff(surv_extended, axis=1) + + # 3. Handle Tail Mass (Censoring/Remaining mass) + # If the survival function doesn't reach 0, the remaining mass + # is assigned to infinity (representing 'at some point in the future') + tail_mass = sksurv_survf[:, -1:] + final_weights = np.hstack([weights, tail_mass]) + + # 4. Align Times + # We append np.inf as the timestamp for the tail mass + final_times = np.append(times, np.inf) + + # 5. Reshape for Empirical distribution + # The Empirical distribution expects (n_samples * n_points) format for spl + n_samples = len(X) + n_points = len(final_times) + + # Create a MultiIndex for the weights and times + mi = pd.MultiIndex.from_product([X.index, range(n_points)]).swaplevel() + + # Flatten weights and repeat times for each sample + weights_flat = final_weights.flatten() + times_repeated = np.tile(final_times, n_samples) + + times_df = pd.DataFrame(times_repeated, index=mi, columns=self._y_cols) + weights_ser = pd.Series(weights_flat, index=mi) dist = Empirical( - spl=times_df, weights=weights_df, index=X.index, columns=self._y_cols + spl=times_df, + weights=weights_ser, + index=X.index, + columns=self._y_cols ) return dist