74 changes: 8 additions & 66 deletions README.md
@@ -11,12 +11,16 @@ If you prefer an R implementation of this package, have a look at [loreplotr](ht

## Why use lorepy ?

Lorepy offers distinct advantages over traditional methods like stacked bar plots. By employing a linear model, Lorepy
captures overall trends across the entire feature range. It avoids arbitrary cut-offs and segmentation, enabling the
visualization of uncertainty throughout the data range.

You can find examples of the Iris data visualized using stacked bar plots [here](https://github.com/raeslab/lorepy/blob/main/docs/lorepy_vs_bar_plots.md) for comparison.

## How lorepy works

For details on the model mechanics and how to interpret loreplots, see [How lorepy works](https://github.com/raeslab/lorepy/blob/main/docs/how_lorepy_works.md).

## Installation

Lorepy can be installed using pip using the command below.
@@ -184,7 +188,7 @@ plt.show()

From loreplots it isn't possible to assess how certain we are of the prevalence of each group across the range. To
provide a view into this there is a function ```uncertainty_plot```, which can be used as shown below. This will use
```resampling``` (or ```random subsampling```) to determine the 50% and 95% intervals of predicted values and show these in a
multi-panel plot with one plot per category.

```python
from lorepy import uncertainty_plot

uncertainty_plot(data=iris_df, x="sepal width (cm)", y="species")
plt.show()
```

@@ -205,69 +209,7 @@ This also supports custom colors, ranges and classifiers. More examples are avai

### Feature Importance Analysis

For details on permutation-based feature importance analysis, see [Feature Importance Analysis](https://github.com/raeslab/lorepy/blob/main/docs/feature_importance.md).

## Development

67 changes: 67 additions & 0 deletions docs/feature_importance.md
@@ -0,0 +1,67 @@
# Feature Importance Analysis

Lorepy provides statistical assessment of how strongly your x-feature is associated with the class distribution using the `feature_importance` function. This uses **permutation-based feature importance** with **log loss (cross-entropy)** as the scoring metric to test whether the relationship you see in your loreplot is statistically significant. Log loss evaluates the full predicted probability distribution rather than just hard class predictions, making it well-suited for lorepy's probability-based visualizations.
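A quick illustration of why log loss suits probability-based plots (a sketch with made-up numbers, not lorepy code): two models can make the same hard predictions yet differ in how well-calibrated their probabilities are, and log loss distinguishes them.

```python
import numpy as np
from sklearn.metrics import log_loss

# Both models predict the correct class (argmax = class 0) for both samples,
# so accuracy cannot tell them apart -- but log loss rewards the model whose
# probability mass is concentrated on the true class.
y_true = [0, 0]
confident = [[0.9, 0.05, 0.05], [0.9, 0.05, 0.05]]
hesitant = [[0.4, 0.3, 0.3], [0.4, 0.3, 0.3]]

# log loss = -mean(log p_assigned_to_true_class); lower is better
print(log_loss(y_true, confident, labels=[0, 1, 2]))  # ~0.105
print(log_loss(y_true, hesitant, labels=[0, 1, 2]))   # ~0.916
```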

## How it Works

The function uses a robust resampling approach combined with sklearn's optimized permutation importance:

1. **Bootstrap/Random Subsampling**: Creates multiple subsamples of your data (default: 100 iterations)
2. **Permutation Importance**: For each subsample, uses sklearn's `permutation_importance` with proper cross-validation to avoid data leakage
3. **Feature Shuffling**: Randomly permutes the x-feature values while keeping confounders intact
4. **Performance Assessment**: Measures log loss increase using statistically sound train/test splits
5. **Statistical Summary**: Aggregates results across all iterations to provide confidence intervals and significance testing

This approach works with **any sklearn classifier** (LogisticRegression, SVM, RandomForest, etc.) and properly handles confounders by keeping them constant during shuffling. The implementation uses sklearn's battle-tested permutation importance algorithm for reliable, unbiased results.

```python
from lorepy import feature_importance

# Basic usage
stats = feature_importance(data=iris_df, x="sepal width (cm)", y="species", iterations=100)
print(stats['interpretation'])
# Output: "Feature importance: 0.2019 ± 0.0433. Positive in 100.0% of iterations (p=0.0000)"
```
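The five steps above can be sketched directly with scikit-learn. This is an illustrative sketch, not lorepy's actual implementation: the dataset, iteration count, and split parameters are assumptions made for the example.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.inspection import permutation_importance
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

iris = load_iris()
X = iris.data[:, [1]]  # sepal width as the single x-feature
y = iris.target

rng = np.random.default_rng(42)
importances = []
for _ in range(10):  # lorepy defaults to 100 iterations
    # 1. Bootstrap: resample the data with replacement
    idx = rng.integers(0, len(y), size=len(y))
    # 2-4. Fit on a train split; measure the log-loss increase when the
    #      feature is shuffled, evaluated on a held-out test split
    X_tr, X_te, y_tr, y_te = train_test_split(
        X[idx], y[idx], test_size=0.25, random_state=0
    )
    clf = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
    result = permutation_importance(
        clf, X_te, y_te, scoring="neg_log_loss", n_repeats=5, random_state=0
    )
    importances.append(result.importances_mean[0])

# 5. Aggregate across iterations
print(f"mean importance: {np.mean(importances):.3f}")
```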

## Understanding the Output

The function returns a dictionary with the following key statistics:

- **`mean_importance`**: Average log loss increase when x-feature is shuffled (higher = more important)
- **`std_importance`**: Standard deviation of importance across iterations
- **`importance_95ci_low/high`**: 95% confidence interval for the importance estimate
- **`mean_validation_log_loss`**: Mean log loss on the validation data across iterations (lower = better)
- **`std_validation_log_loss`**: Standard deviation of the validation log loss
- **`mean_permuted_log_loss`**: Mean log loss on the permuted data across iterations (values above the validation log loss indicate the feature carries signal)
- **`std_permuted_log_loss`**: Standard deviation of the permuted log loss
- **`proportion_positive`**: Fraction of iterations where importance > 0 (feature helps prediction)
- **`proportion_negative`**: Fraction of iterations where importance < 0 (feature hurts prediction)
- **`p_value`**: Empirical p-value for statistical significance (< 0.05 typically considered significant)
- **`interpretation`**: Human-readable summary of the results
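As a rough illustration, the summary statistics above can be derived from the per-iteration importance values as follows. This is a sketch under assumptions: the importance values are invented, and the two-sided empirical p-value formula shown is a common convention, not necessarily lorepy's exact definition.

```python
import numpy as np

# Hypothetical per-iteration importance values (log-loss increases)
importances = np.array([0.21, 0.18, 0.25, 0.19, -0.01, 0.22, 0.20, 0.17])

mean_importance = importances.mean()
std_importance = importances.std(ddof=1)
# Percentile-based 95% confidence interval over iterations
ci_low, ci_high = np.percentile(importances, [2.5, 97.5])
proportion_positive = (importances > 0).mean()
proportion_negative = (importances < 0).mean()
# Two-sided empirical p-value: twice the minority-side fraction
p_value = 2 * min(proportion_positive, proportion_negative)

print(f"{mean_importance:.4f} ± {std_importance:.4f}, p={p_value:.4f}")
```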

## Advanced Usage

```python
from sklearn.svm import SVC

# With confounders and custom classifier
stats = feature_importance(
    data=data,
    x="age",
    y="disease",
    confounders=[("bmi", 25), ("sex", "female")],  # Control for these variables
    clf=SVC(probability=True),  # Use SVM instead of logistic regression
    mode="random_subsampling",  # Use random subsampling instead of bootstrap
    iterations=200,  # More iterations for precision
)

print(f"P-value: {stats['p_value']:.4f}")
print(f"95% CI: [{stats['importance_95ci_low']:.3f}, {stats['importance_95ci_high']:.3f}]")
```

## Interpretation Guidelines

- **Strong Association**: `p_value < 0.01`, `proportion_positive > 95%`
- **Moderate Association**: `p_value < 0.05`, `proportion_positive > 80%`
- **Weak/No Association**: `p_value > 0.05`, confidence interval includes zero
- **Negative Association**: `proportion_negative > proportion_positive` (unusual but possible)
24 changes: 24 additions & 0 deletions docs/how_lorepy_works.md
@@ -0,0 +1,24 @@
# How lorepy works

Under the hood, lorepy fits scikit-learn's default `LogisticRegression()` model (L2 regularization, `C=1.0`) unless you
pass a custom classifier. For a multiclass outcome with *K* classes, the fitted model can be written as coefficient
vectors **β**\_k and intercepts β\_{k0}, with class probabilities:

*P(Y=k | x) = exp(**β**\_k · x + β\_{k0}) / Σ\_j exp(**β**\_j · x + β\_{j0})*

These probabilities sum to one at each *x*, so they can be drawn directly as a stacked area chart. Lorepy evaluates
the fitted model at 200 evenly spaced points across the observed range of the x-feature, yielding smooth curves for the
estimated class composition as a function of *x*. If confounders are provided, they are included during fitting and
then fixed to user-specified reference values at prediction time, so the displayed curves are **conditional on those
reference values** (not marginal averages over the confounder distribution). Sample dots are generated by drawing each
point's *y*-coordinate uniformly within the predicted probability interval of its true class, which visualizes class
membership while preserving the stacked-probability interpretation.
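The mechanics above can be sketched in a few lines of scikit-learn and matplotlib. This is illustrative only: lorepy adds the sample dots, confounder handling, and styling on top of this core idea.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

iris = load_iris()
X = iris.data[:, [1]]  # single x-feature: sepal width
y = iris.target

# Default LogisticRegression, as lorepy uses unless a classifier is passed
model = LogisticRegression().fit(X, y)

# Evaluate the fitted model at 200 evenly spaced points across the range
grid = np.linspace(X.min(), X.max(), 200).reshape(-1, 1)
proba = model.predict_proba(grid)  # shape (200, 3); each row sums to one

# Because probabilities sum to one, they stack directly into area bands
plt.stackplot(grid.ravel(), proba.T, labels=iris.target_names)
plt.xlabel("sepal width (cm)")
plt.ylabel("P(class | x)")
plt.legend()
plt.show()
```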

Concretely, the height of each colored band at a given *x*-value represents the model's estimated proportion of that
class: a band spanning 60% of the y-axis means the model estimates that class accounts for 60% of observations at that
point along the feature. As *x* increases, bands that widen indicate classes with a growing estimated proportion, while
narrowing bands indicate classes becoming rarer. A class that dominates the plot across the full range has a
consistently high estimated proportion regardless of the feature value, whereas a sharp crossover between two bands
pinpoints where one class overtakes another. Because the bands always sum to one, the plot naturally encodes a zero-sum
trade-off: one class can only grow in estimated proportion at the expense of others, making it straightforward to read
both absolute and relative shifts directly from the visualization.
12 changes: 8 additions & 4 deletions example_uncertainty.py
@@ -23,20 +23,24 @@
print(stats)

stats = feature_importance(
    data=iris_df,
    x="sepal width (cm)",
    y="species",
    iterations=100,
    mode="random_subsampling",
)
print(stats)


# Using random subsampling instead of resampling to assess uncertainty
uncertainty_plot(
    data=iris_df,
    x="sepal width (cm)",
    y="species",
    iterations=100,
    subsampling_fraction=0.8,
)
plt.savefig("./docs/img/uncertainty_random_subsampling.png", dpi=150)
plt.show()

# Uncertainty plot with custom colors