[ENH] add truncated_mean interface to BaseDistribution#995
[ENH] add truncated_mean interface to BaseDistribution#995patelchaitany wants to merge 1 commit intosktime:mainfrom
Conversation
JiwaniZakir
left a comment
There was a problem hiding this comment.
The sampling-based fallback in _truncated_mean calls self.sample(1) inside a loop of approx_spl_size (default 1000) iterations, which is orders of magnitude slower than calling self.sample(approx_spl_size) once and filtering. The existing _mean fallback at the end of the file does exactly this with _sample_mean, so the pattern is already established — the loop approach here is inconsistent with the rest of the codebase and will be a significant performance bottleneck.
In _energy_x, the added docstring block (around line 1235) describes the truncated-mean identity used in the new code path, but it's inserted into the existing _energy_x docstring rather than being clearly separated from the original parameter/return docs. This makes the docstring structurally confusing — the formula section appears before the "Parameters" block, which is fine, but it references _truncated_mean and cdf without explaining that this only triggers when both are exact capabilities, which is important context for subclass implementors.
The zero_mass guard in the ppf-based approximation (np.abs(cdf_u - cdf_l) < 1e-15) correctly handles degenerate intervals, but the same guard is absent in the _energy_x shortcut path — if a point x lies exactly at a probability-0 region boundary, right_mean or left_mean could return nan and the nan_to_num(..., nan=0.0) silently masks what might be a meaningful divergence rather than a true zero-contribution interval.
fkiraly
left a comment
There was a problem hiding this comment.
Very nice!
Can you please add a test to TestAllDistributions that checks that the method works? I think the method just has to be added to the right places.
| if spl_arr.ndim == 1: | ||
| spl_arr = spl_arr.reshape(-1, 1, 1) | ||
| elif spl_arr.ndim == 2: | ||
| spl_arr = spl_arr.reshape(spl_arr.shape[0], 1, -1) |
There was a problem hiding this comment.
why is this correct? or is it a mistake?
There was a problem hiding this comment.
Yeah, those reshape branches are dead code on my end. sample(1) always returns a DataFrame, so .values gives 2D, stacking gives 3D neither ndim == 1 nor ndim == 2 ever fires. I added them as defensive guards but they're unnecessary. Happy to remove them.
dde2b4d to
da87e95
Compare
Signed-off-by: Chaitany patel <patelchaitany93@gmail.com>
da87e95 to
4c21cf3
Compare
|
@fkiraly, I have added truncated_mean to METHODS_SCALAR plus two new tests one checks that unbounded truncated_mean() matches mean(), the other verifies output format with actual bounds. |
Reference Issues/PRs
Addresses #221.
What does this implement/fix? Explain your changes.
Adds truncated_mean(lower, upper) to BaseDistribution. Returns E[X | lower < X < upper]. Closed-form implementations for Normal, Exponential, Laplace, and Uniform. Other distributions fall back to a ppf-based numerical approximation.
The base class _energy_x now uses truncated means when available: if a distribution has exact truncated_mean and cdf, it computes E[|X - c|] from the energy identity instead of Monte Carlo sampling.
TruncatedDistribution gets a _mean() that calls the inner distribution's truncated_mean, so TruncatedDistribution(Normal(...), lower=0).mean() is exact now.
Changes
base/_base.py- newtruncated_mean/_truncated_meanwith ppf/MC default, updated_energy_xnormal.py- exact_truncated_meanexponential.py- exact_truncated_meanlaplace.py- exact_truncated_meanuniform.py- exact_truncated_meantruncated.py-_mean()via inner distribution'struncated_mean, updated tag logiDoes your contribution introduce a new dependency? If yes, which one?
No
What should a reviewer concentrate their feedback on?
Did you add any tests for the change?
No
Any other comments?
PR checklist
For all contributions
How to: add yourself to the all-contributors file in the
skproroot directory (not theCONTRIBUTORS.md). Common badges:code- fixing a bug, or adding code logic.doc- writing or improving documentation or docstrings.bug- reporting or diagnosing a bug (get this pluscodeif you also fixed the bug in the PR).maintenance- CI, test framework, release.See here for full badge reference
For new estimators
docs/source/api_reference/taskname.rst, follow the pattern.Examplessection.python_dependenciestag and ensureddependency isolation, see the estimator dependencies guide.