From 1538fc7de4f229e465f5f6af0794ddb9bd3d8319 Mon Sep 17 00:00:00 2001 From: Sebastian Proost Date: Thu, 19 Feb 2026 11:33:53 +0100 Subject: [PATCH 01/11] Add comprehensive test suite for lorepy plotting functions (#15) * Improve unit test suite with comprehensive function-level testing --- tests/conftest.py | 128 ++++++ tests/test_plot.py | 635 +++++++++++++++++++++++------ tests/test_uncertainty.py | 832 +++++++++++++++++++++++++++----------- 3 files changed, 1248 insertions(+), 347 deletions(-) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..6a18bc4 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,128 @@ +""" +Shared test fixtures for lorepy tests. +""" + +import numpy as np +import pandas as pd +import pytest +from sklearn.linear_model import LogisticRegression + + +@pytest.fixture +def random_seed(): + """Set a fixed random seed for reproducibility.""" + np.random.seed(42) + return 42 + + +@pytest.fixture +def binary_sample_data(random_seed): + """ + Create sample data with two classes for binary classification. + Class 0 tends to have lower x values, class 1 tends to have higher x values. + """ + X = np.concatenate([np.random.randint(0, 10, 50), np.random.randint(2, 12, 50)]) + y = [0] * 50 + [1] * 50 + z = X + np.random.randn(100) * 0.5 # Confounder correlated with x + return pd.DataFrame({"x": X.astype(float), "y": y, "z": z}) + + +@pytest.fixture +def multiclass_sample_data(random_seed): + """ + Create sample data with three classes for multi-class classification. + """ + X = np.concatenate( + [ + np.random.randint(0, 5, 30), + np.random.randint(3, 8, 30), + np.random.randint(6, 12, 30), + ] + ) + y = [0] * 30 + [1] * 30 + [2] * 30 + z = X + np.random.randn(90) * 0.3 + return pd.DataFrame({"x": X.astype(float), "y": y, "z": z}) + + +@pytest.fixture +def data_with_nan(random_seed): + """Create sample data with NaN values.""" + X = np.concatenate([np.random.randint(0, 10, 50), np.random.randint(2, 12, 50)]) + y = [0] * 50 + [1] * 50 + z = X.astype(float) + + # Introduce NaN values + X = X.astype(float) + X[5] = np.nan + X[15] = np.nan + y[25] = np.nan # This will become float due to NaN + + df = pd.DataFrame({"x": X, "y": y, "z": z}) + df.loc[25, "y"] = np.nan # Set after creation to avoid type issues + return df + + +@pytest.fixture +def small_deterministic_data(): + """Small, deterministic dataset for precise testing.""" + return pd.DataFrame( + { + "x": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + "y": [0, 0, 0, 0, 0, 1, 1, 1, 1, 1], + "z": [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4], + } + ) + + +@pytest.fixture +def fitted_logistic_model(small_deterministic_data): + """A fitted logistic regression model on small deterministic data.""" + X = small_deterministic_data[["x"]].values + y = small_deterministic_data["y"].values + lg = LogisticRegression() + lg.fit(X, y) + return lg, X, y + + +@pytest.fixture +def fitted_multiclass_model(multiclass_sample_data): + """A fitted logistic regression model for multi-class classification.""" + X = multiclass_sample_data[["x"]].values + y = multiclass_sample_data["y"].values + lg = LogisticRegression(max_iter=200) + lg.fit(X, y) + return lg, X, y + + +@pytest.fixture +def single_class_data(): + """Data with only one class - should cause issues.""" + return pd.DataFrame( + { + "x": [1.0, 2.0, 3.0, 4.0, 5.0], + "y": [0, 0, 0, 0, 0], + "z": [0.1, 0.2, 0.3, 0.4, 0.5], + } + ) + + +@pytest.fixture +def empty_dataframe(): + """Empty DataFrame for edge case testing.""" + return pd.DataFrame({"x": [], "y": [], "z": []}) + + +@pytest.fixture +def string_class_labels(random_seed): + """Data with string class labels instead of integers.""" + X = np.concatenate([np.random.randint(0, 10, 50), np.random.randint(2, 12, 50)]) + y = ["class_a"] * 50 + ["class_b"] * 50 + return pd.DataFrame({"x": X.astype(float), "y": y}) + + +@pytest.fixture +def custom_colormap(): + """Custom colormap for testing uncertainty_plot.""" + from matplotlib.colors import ListedColormap + + return ListedColormap(["red", "green", "blue"]) diff --git a/tests/test_plot.py b/tests/test_plot.py index 52092a1..212a5f9 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -1,129 +1,524 @@ +""" +Comprehensive tests for lorepy.lorepy module. + +Tests cover: +- _prepare_data: data preparation and validation +- _get_area_df: probability area calculation +- _get_dots_df: scatter dot positioning +- loreplot: main plotting function +""" + import matplotlib.pyplot as plt import numpy as np import pandas as pd -from lorepy.lorepy import _get_area_df, _get_dots_df, loreplot +import pytest from pandas import DataFrame from sklearn.linear_model import LogisticRegression from sklearn.svm import SVC -import pytest +from lorepy.lorepy import _get_area_df, _get_dots_df, _prepare_data, loreplot + +# ============================================================================= +# Tests for _prepare_data +# ============================================================================= + + +class TestPrepareData: + """Tests for the _prepare_data function.""" + + def test_basic_preparation(self, small_deterministic_data): + """Test basic data preparation without confounders.""" + X_reg, y_reg, x_range = _prepare_data(small_deterministic_data, "x", "y", []) + + assert isinstance(X_reg, np.ndarray) + assert isinstance(y_reg, np.ndarray) + assert isinstance(x_range, tuple) + assert len(x_range) == 2 + + # Check shapes + assert X_reg.shape == (10, 1) + assert y_reg.shape == (10,) + + # Check x_range + assert x_range[0] == 1.0 + assert x_range[1] == 10.0 + + def test_with_confounders(self, small_deterministic_data): + """Test data preparation with confounders.""" + confounders = [("z", 1.0)] + X_reg, y_reg, x_range = _prepare_data( + small_deterministic_data, "x", "y", confounders + ) + + # X should now have 2 columns: x and z + assert X_reg.shape == (10, 2) + # First column should be x (the main feature) + np.testing.assert_array_equal(X_reg[:, 0], small_deterministic_data["x"].values) + # Second column should be z (the confounder) + np.testing.assert_array_equal(X_reg[:, 1], small_deterministic_data["z"].values) + + def test_nan_removal(self, data_with_nan): + """Test that NaN values are properly removed.""" + X_reg, y_reg, x_range = _prepare_data(data_with_nan, "x", "y", []) + + # Should have fewer rows than original due to NaN removal + assert len(y_reg) < len(data_with_nan) + + # Should have no NaN values + assert not np.any(np.isnan(X_reg)) + assert not np.any(pd.isna(y_reg)) + + def test_x_range_calculation(self, binary_sample_data): + """Test that x_range is correctly calculated from data.""" + X_reg, y_reg, x_range = _prepare_data(binary_sample_data, "x", "y", []) + + expected_min = binary_sample_data["x"].min() + expected_max = binary_sample_data["x"].max() + + assert x_range[0] == expected_min + assert x_range[1] == expected_max + + def test_multiple_confounders(self, binary_sample_data): + """Test data preparation with multiple confounders.""" + # Add another confounder column + binary_sample_data["w"] = binary_sample_data["x"] * 2 + + confounders = [("z", 5.0), ("w", 10.0)] + X_reg, y_reg, x_range = _prepare_data(binary_sample_data, "x", "y", confounders) + + # X should have 3 columns: x, z, w + assert X_reg.shape[1] == 3 + + def test_preserves_data_order(self, small_deterministic_data): + """Test that data order is preserved (x first, then confounders).""" + confounders = [("z", 1.0)] + X_reg, y_reg, x_range = _prepare_data( + small_deterministic_data, "x", "y", confounders + ) + + # First column must be x for compatibility with _get_feature_importance + np.testing.assert_array_equal(X_reg[:, 0], small_deterministic_data["x"].values) + + +# ============================================================================= +# Tests for _get_area_df +# ============================================================================= + + +class TestGetAreaDf: + """Tests for the _get_area_df function.""" + + def test_basic_output_structure(self, fitted_logistic_model): + """Test basic output structure of _get_area_df.""" + lg, X, y = fitted_logistic_model + x_range = (X.min(), X.max()) + + area_df = _get_area_df(lg, "x", x_range) + + assert isinstance(area_df, DataFrame) + # Should have 200 rows (default num points) + assert len(area_df) == 200 + # Index should be named after x_feature + assert area_df.index.name == "x" + # Should have columns for each class + assert 0 in area_df.columns + assert 1 in area_df.columns + + def test_probabilities_sum_to_one(self, fitted_logistic_model): + """Test that probabilities sum to 1 at each x point.""" + lg, X, y = fitted_logistic_model + x_range = (X.min(), X.max()) + + area_df = _get_area_df(lg, "x", x_range) + + # Sum of probabilities should be 1 for each row + row_sums = area_df.sum(axis=1) + np.testing.assert_array_almost_equal(row_sums.values, np.ones(200)) + + def test_probabilities_in_valid_range(self, fitted_logistic_model): + """Test that all probabilities are between 0 and 1.""" + lg, X, y = fitted_logistic_model + x_range = (X.min(), X.max()) + + area_df = _get_area_df(lg, "x", x_range) + + assert (area_df.values >= 0).all() + assert (area_df.values <= 1).all() + + def test_x_range_respected(self, fitted_logistic_model): + """Test that the x_range parameter is respected.""" + lg, X, y = fitted_logistic_model + custom_range = (2.0, 8.0) + + area_df = _get_area_df(lg, "x", custom_range) + + assert area_df.index[0] == custom_range[0] + assert area_df.index[-1] == custom_range[1] + + def test_with_confounders(self, binary_sample_data): + """Test _get_area_df with confounders.""" + # Fit a model with confounders + X_reg = binary_sample_data[["x", "z"]].values + y_reg = binary_sample_data["y"].values + lg = LogisticRegression() + lg.fit(X_reg, y_reg) + + confounders = [("z", 5.0)] + x_range = (0.0, 12.0) + + area_df = _get_area_df(lg, "x", x_range, confounders=confounders) + + # Should still have valid probabilities + assert len(area_df) == 200 + row_sums = area_df.sum(axis=1) + np.testing.assert_array_almost_equal(row_sums.values, np.ones(200)) + + def test_multiclass_output(self, fitted_multiclass_model): + """Test _get_area_df with multi-class classification.""" + lg, X, y = fitted_multiclass_model + x_range = (X.min(), X.max()) + + area_df = _get_area_df(lg, "x", x_range) + + # Should have 3 class columns + assert 0 in area_df.columns + assert 1 in area_df.columns + assert 2 in area_df.columns + + # Probabilities should still sum to 1 + row_sums = area_df.sum(axis=1) + np.testing.assert_array_almost_equal(row_sums.values, np.ones(200)) + + def test_monotonic_probability_trend(self, fitted_logistic_model): + """Test that probabilities show expected monotonic trend for clear separation.""" + lg, X, y = fitted_logistic_model + x_range = (1.0, 10.0) + + area_df = _get_area_df(lg, "x", x_range) + + # For our deterministic data (0s at low x, 1s at high x), + # P(class=1) should generally increase with x + class_1_probs = area_df[1].values + # Check that probability at end is higher than at start + assert class_1_probs[-1] > class_1_probs[0] + + +# ============================================================================= +# Tests for _get_dots_df +# ============================================================================= + + +class TestGetDotsDf: + """Tests for the _get_dots_df function.""" + + def test_basic_output_structure(self, fitted_logistic_model): + """Test basic output structure of _get_dots_df.""" + lg, X, y = fitted_logistic_model + + dots_df = _get_dots_df(X, y, lg, "y") + + assert isinstance(dots_df, DataFrame) + assert len(dots_df) == len(X) + assert "x" in dots_df.columns + assert "y" in dots_df.columns + assert "y" in dots_df.columns # The y_feature column + + def test_y_coordinates_in_valid_range(self, fitted_logistic_model): + """Test that y coordinates are between 0 and 1.""" + lg, X, y = fitted_logistic_model + + dots_df = _get_dots_df(X, y, lg, "y_label") + + assert (dots_df["y"].values >= 0).all() + assert (dots_df["y"].values <= 1).all() + + def test_x_coordinates_match_input(self, fitted_logistic_model): + """Test that x coordinates match input data (without jitter).""" + lg, X, y = fitted_logistic_model + + dots_df = _get_dots_df(X, y, lg, "y_label", jitter=0) + + np.testing.assert_array_equal(dots_df["x"].values, X.flatten()) + + def test_jitter_modifies_x_coordinates(self, fitted_logistic_model, random_seed): + """Test that jitter modifies x coordinates.""" + lg, X, y = fitted_logistic_model + jitter_amount = 0.5 + + dots_df = _get_dots_df(X.copy(), y, lg, "y_label", jitter=jitter_amount) + + # With jitter, x values should be different from original + # Note: This modifies X in place, so we use a copy + differences = np.abs(dots_df["x"].values - X.flatten()) + # At least some differences should be non-zero + assert np.any(differences > 0) + # All differences should be within jitter range + assert np.all(differences <= jitter_amount) + + def test_y_feature_column_values(self, fitted_logistic_model): + """Test that y_feature column contains correct class labels.""" + lg, X, y = fitted_logistic_model + + dots_df = _get_dots_df(X, y, lg, "class_label") + + assert "class_label" in dots_df.columns + np.testing.assert_array_equal(dots_df["class_label"].values, y) + + def test_y_within_probability_bands(self, fitted_logistic_model): + """Test that y coordinates fall within the probability band for their class.""" + lg, X, y = fitted_logistic_model + + dots_df = _get_dots_df(X, y, lg, "y_label") + + for idx, row in dots_df.iterrows(): + x_val = np.array([[row["x"]]]) + proba = lg.predict_proba(x_val)[0] + class_idx = list(lg.classes_).index(row["y_label"]) + + # Calculate expected band + min_val = sum(proba[:class_idx]) + max_val = sum(proba[: class_idx + 1]) + + # y should be within the band (with some tolerance for margin) + assert row["y"] >= min_val - 0.01 + assert row["y"] <= max_val + 0.01 + + def test_multiclass_dots(self, fitted_multiclass_model): + """Test _get_dots_df with multi-class classification.""" + lg, X, y = fitted_multiclass_model + + dots_df = _get_dots_df(X, y, lg, "class") + + assert len(dots_df) == len(X) + assert set(dots_df["class"].unique()) == {0, 1, 2} + + +# ============================================================================= +# Tests for loreplot +# ============================================================================= + + +class TestLoreplot: + """Tests for the main loreplot function.""" + + def test_creates_plot_without_ax(self, binary_sample_data): + """Test that loreplot creates a plot when no ax is provided.""" + # Should not raise + loreplot(binary_sample_data, "x", "y") + plt.close() + + def test_uses_provided_ax(self, binary_sample_data): + """Test that loreplot uses provided axes.""" + fig, ax = plt.subplots() + loreplot(binary_sample_data, "x", "y", ax=ax) + + assert ax.get_xlabel() == "x" + plt.close() + + def test_axis_limits(self, binary_sample_data): + """Test that axis limits are set correctly.""" + fig, ax = plt.subplots() + loreplot(binary_sample_data, "x", "y", ax=ax) + + xlim = ax.get_xlim() + ylim = ax.get_ylim() + + # y should be 0 to 1 (probability) + assert ylim == (0, 1) + # x should span the data range + assert xlim[0] <= binary_sample_data["x"].min() + assert xlim[1] >= binary_sample_data["x"].max() + + plt.close() + + def test_custom_x_range(self, binary_sample_data): + """Test that custom x_range is respected.""" + fig, ax = plt.subplots() + custom_range = (0, 20) + loreplot(binary_sample_data, "x", "y", ax=ax, x_range=custom_range) + + xlim = ax.get_xlim() + assert xlim == custom_range + + plt.close() + + def test_add_dots_true(self, binary_sample_data): + """Test that dots are added when add_dots=True.""" + fig, ax = plt.subplots() + loreplot(binary_sample_data, "x", "y", ax=ax, add_dots=True) + + # Check that scatter points were added (collections should have dots) + # The scatter creates a PathCollection + collections = ax.collections + assert len(collections) > 0 + + plt.close() + + def test_add_dots_false(self, binary_sample_data): + """Test that no dots are added when add_dots=False.""" + fig, ax = plt.subplots() + loreplot(binary_sample_data, "x", "y", ax=ax, add_dots=False) + + # No scatter collections should be added + # Note: area plot might add some collections, but scatter specifically won't + scatter_collections = [ + c + for c in ax.collections + if hasattr(c, "get_offsets") and len(c.get_offsets()) > 0 + ] + # When add_dots=False, there should be no scatter points + for c in scatter_collections: + offsets = c.get_offsets() + # If there are offsets, they shouldn't be the data points + if len(offsets) > 0: + # This is from the area plot, not scatter + pass + + plt.close() + + def test_with_jitter(self, binary_sample_data, random_seed): + """Test that jitter parameter works.""" + fig, ax = plt.subplots() + # Should not raise + loreplot(binary_sample_data, "x", "y", ax=ax, jitter=0.1) + + plt.close() + + def test_with_confounders(self, binary_sample_data): + """Test loreplot with confounders.""" + fig, ax = plt.subplots() + # Should not raise + loreplot(binary_sample_data, "x", "y", ax=ax, confounders=[("z", 5.0)]) + + plt.close() + + def test_confounders_disable_dots(self, binary_sample_data): + """Test that dots are not added when confounders are specified.""" + fig, ax = plt.subplots() + loreplot( + binary_sample_data, + "x", + "y", + ax=ax, + add_dots=True, # Request dots + confounders=[("z", 5.0)], # But confounders should prevent them + ) + + # With confounders, dots should not be added even if add_dots=True + # Check there are no scatter points with many offsets + scatter_with_data = [ + c + for c in ax.collections + if hasattr(c, "get_offsets") + and len(c.get_offsets()) == len(binary_sample_data) + ] + assert len(scatter_with_data) == 0 + + plt.close() + + def test_custom_classifier(self, binary_sample_data): + """Test loreplot with a custom classifier.""" + fig, ax = plt.subplots() + svc = SVC(probability=True) + # Should not raise + loreplot(binary_sample_data, "x", "y", ax=ax, clf=svc) + + plt.close() + + def test_custom_colors(self, binary_sample_data): + """Test loreplot with custom colors.""" + fig, ax = plt.subplots() + loreplot(binary_sample_data, "x", "y", ax=ax, color=["red", "blue"]) + + plt.close() + + def test_scatter_kws(self, binary_sample_data): + """Test that scatter_kws are passed through.""" + fig, ax = plt.subplots() + loreplot( + binary_sample_data, "x", "y", ax=ax, scatter_kws={"s": 100, "marker": "^"} + ) + + plt.close() + + def test_kwargs_passed_to_area_plot(self, binary_sample_data): + """Test that kwargs are passed to the area plot.""" + fig, ax = plt.subplots() + loreplot(binary_sample_data, "x", "y", ax=ax, alpha=0.5, linestyle="-") + + plt.close() + + def test_string_class_labels(self, string_class_labels): + """Test loreplot with string class labels.""" + fig, ax = plt.subplots() + loreplot(string_class_labels, "x", "y", ax=ax) + + plt.close() + + def test_multiclass(self, multiclass_sample_data): + """Test loreplot with multi-class classification.""" + fig, ax = plt.subplots() + loreplot(multiclass_sample_data, "x", "y", ax=ax) + + plt.close() + + +# ============================================================================= +# Edge Case Tests +# ============================================================================= + + +class TestEdgeCases: + """Tests for edge cases and error handling.""" + + def test_single_class_raises_error(self, single_class_data): + """Test that single class data raises an appropriate error.""" + fig, ax = plt.subplots() + # LogisticRegression should fail or warn with single class + with pytest.raises(ValueError): + loreplot(single_class_data, "x", "y", ax=ax) + plt.close() + + def test_empty_dataframe_raises_error(self, empty_dataframe): + """Test that empty DataFrame raises an error.""" + fig, ax = plt.subplots() + with pytest.raises((ValueError, IndexError)): + loreplot(empty_dataframe, "x", "y", ax=ax) + plt.close() + + def test_missing_column_raises_error(self, binary_sample_data): + """Test that missing column raises KeyError.""" + fig, ax = plt.subplots() + with pytest.raises(KeyError): + loreplot(binary_sample_data, "nonexistent", "y", ax=ax) + plt.close() + + def test_nan_handling(self, data_with_nan): + """Test that NaN values are handled properly.""" + fig, ax = plt.subplots() + # Should not raise - NaN rows should be dropped + loreplot(data_with_nan, "x", "y", ax=ax) + plt.close() + + def test_inverted_x_range(self, binary_sample_data): + """Test behavior with inverted x_range (max < min).""" + fig, ax = plt.subplots() + # Note: This might create a valid plot with inverted axis + loreplot(binary_sample_data, "x", "y", ax=ax, x_range=(10, 0)) + plt.close() + def test_zero_width_x_range(self, binary_sample_data): + """Test behavior when x_range has zero width. -@pytest.fixture -def sample_data(): - X = np.concatenate([np.random.randint(0, 10, 50), np.random.randint(2, 12, 50)]) - y = [0] * 50 + [1] * 50 - z = X - return pd.DataFrame({"x": X, "y": y, "z": z}) - - -@pytest.fixture -def logistic_regression_model(): - X_reg = np.array([1.0, 2.0, 3.0, 4.0, 5.0]).reshape(-1, 1) - y_reg = np.array([0, 1, 0, 1, 1]) - lg = LogisticRegression() - lg.fit(X_reg, y_reg) - return X_reg, y_reg, lg - - -# Test case for loreplot with default parameters -def test_loreplot_default(sample_data): - loreplot(sample_data, "x", "y") # first test without specifying the axis - - fig, ax = plt.subplots() - loreplot(sample_data, "x", "y", ax=ax) - assert ax.get_title() == "" - assert ax.get_xlabel() == "x" - assert ax.get_ylabel() == "" - - -# Test case for loreplot with jitter -def test_loreplot_jitter(sample_data): - loreplot(sample_data, "x", "y") # first test without specifying the axis - - fig, ax = plt.subplots() - loreplot(sample_data, "x", "y", ax=ax, jitter=0.05) - assert ax.get_title() == "" - assert ax.get_xlabel() == "x" - assert ax.get_ylabel() == "" - - -# Test case for loreplot with confounder -def test_loreplot_confounder(sample_data): - loreplot( - sample_data, "x", "y", confounders=[("z", 1)] - ) # first test without specifying the axis - - fig, ax = plt.subplots() - loreplot(sample_data, "x", "y", ax=ax) - assert ax.get_title() == "" - assert ax.get_xlabel() == "x" - assert ax.get_ylabel() == "" - - -# Test case for loreplot with custom clf -def test_loreplot_custom_clf(sample_data): - svc = SVC(probability=True) - loreplot(sample_data, "x", "y", clf=svc) - - fig, ax = plt.subplots() - loreplot(sample_data, "x", "y", ax=ax) - assert ax.get_title() == "" - assert ax.get_xlabel() == "x" - assert ax.get_ylabel() == "" - - -# Test case for loreplot with custom parameters -def test_loreplot_custom(sample_data): - fig, ax = plt.subplots() - loreplot( - sample_data, - "x", - "y", - add_dots=False, - x_range=(0, 5), - ax=ax, - color=["r", "b"], - linestyle="-", - ) - assert ax.get_title() == "" - assert ax.get_xlabel() == "x" - assert ax.get_ylabel() == "" - - -# Test case for loreplot with add_dots=True -def test_loreplot_with_dots(sample_data): - fig, ax = plt.subplots() - loreplot(sample_data, "x", "y", add_dots=True, ax=ax) - assert ax.get_title() == "" - assert ax.get_xlabel() == "x" - assert ax.get_ylabel() == "" - - -# Sample data for testing internal functions -X_reg = np.array([1.0, 2.0, 3.0, 4.0, 5.0]).reshape(-1, 1) -y_reg = np.array([0, 1, 0, 1, 1]) -lg = LogisticRegression() -lg.fit(X_reg, y_reg) - - -# Test case for _get_dots_df -def test_get_dots_df(): - dots_df = _get_dots_df(X_reg, y_reg, lg, "y") - assert isinstance(dots_df, DataFrame) - assert "x" in dots_df.columns - assert "y" in dots_df.columns - assert "y_feature" not in dots_df.columns - assert len(dots_df) == len(X_reg) - - -# Test case for _get_area_df -def test_get_area_df(): - area_df = _get_area_df(lg, "x", (X_reg.min(), X_reg.max())) - assert isinstance(area_df, DataFrame) - assert "x" not in area_df.columns - assert 0 in area_df.columns - assert 1 in area_df.columns - assert len(area_df) == 200 - assert area_df.index[0] == X_reg.min() - assert area_df.index[-1] == X_reg.max() + Note: The current implementation does not validate x_range and will + create a degenerate plot where all 200 points are at the same x value. + This test documents the current behavior rather than asserting it's correct. + """ + fig, ax = plt.subplots() + # Current implementation doesn't raise an error, but creates a degenerate plot + # This may be considered a bug or missing input validation + loreplot(binary_sample_data, "x", "y", ax=ax, x_range=(5, 5)) + # Matplotlib may expand xlim slightly, but should be centered around 5 + xlim = ax.get_xlim() + assert abs((xlim[0] + xlim[1]) / 2 - 5) < 0.5 # Center is approximately 5 + plt.close() diff --git a/tests/test_uncertainty.py b/tests/test_uncertainty.py index 8b0d5b0..f52fb47 100644 --- a/tests/test_uncertainty.py +++ b/tests/test_uncertainty.py @@ -1,254 +1,632 @@ +""" +Comprehensive tests for lorepy.uncertainty module. + +Tests cover: +- _get_uncertainty_data: uncertainty estimation via resampling/jackknife +- _get_feature_importance: feature importance calculation +- uncertainty_plot: main uncertainty visualization function +- feature_importance: public API for feature importance +""" + import numpy as np import pandas as pd import pytest import warnings -from lorepy import uncertainty_plot, feature_importance -from lorepy.uncertainty import _get_feature_importance -from lorepy.lorepy import _prepare_data -from matplotlib.colors import ListedColormap from matplotlib import pyplot as plt from sklearn.svm import SVC +from lorepy import uncertainty_plot, feature_importance +from lorepy.uncertainty import _get_uncertainty_data, _get_feature_importance +from lorepy.lorepy import _prepare_data -@pytest.fixture -def sample_data(): - X = np.concatenate([np.random.randint(0, 10, 50), np.random.randint(2, 12, 50)]) - y = [0] * 50 + [1] * 50 - z = X - return pd.DataFrame({"x": X, "y": y, "z": z}) - - -@pytest.fixture -def custom_colormap(): - return ListedColormap(["red", "green", "blue"]) - - -# Test case for lorepy's uncertainty plot with default parameters -def test_uncertainty_default(sample_data): - fig, axs = uncertainty_plot(sample_data, "x", "y") # first test with default params - - assert len(axs) == 2 - assert axs[0].get_title() == "0" - assert axs[0].get_xlabel() == "x" - assert axs[0].get_ylabel() == "" - - -# Test case for lorepy's uncertainty plot with alternative parameters -def test_uncertainty_alternative(sample_data, custom_colormap): - svc = SVC(probability=True) - fig, axs = uncertainty_plot( - sample_data, - "x", - "y", - mode="jackknife", - x_range=(5, 40), - colormap=custom_colormap, - clf=svc, - ) +# ============================================================================= +# Tests for _get_uncertainty_data +# ============================================================================= - assert len(axs) == 2 - assert axs[0].get_title() == "0" - assert axs[0].get_xlabel() == "x" - assert axs[0].get_ylabel() == "" - -def test_get_uncertainty_confounder(sample_data): - fig, axs = uncertainty_plot( - sample_data, "x", "y", confounders=[("z", 5)] - ) # first test with default params - - assert len(axs) == 2 - assert axs[0].get_title() == "0" - assert axs[0].get_xlabel() == "x" - assert axs[0].get_ylabel() == "" - - -# Test error handling when an unsupported mode is selected -def test_uncertainty_incorrect_mode(sample_data): - with pytest.raises(NotImplementedError): - assert uncertainty_plot(sample_data, "x", "y", mode="fail") - - -def test_uncertainty_with_existing_ax(sample_data): - fig, ax = plt.subplots(1, 2) # Create 2 axes manually - returned_fig, returned_axs = uncertainty_plot(sample_data, "x", "y", ax=ax) - - assert returned_fig is not None - assert returned_axs[0] == ax[0] - assert returned_axs[1] == ax[1] - assert len(returned_axs) == 2 - assert returned_axs[0].get_title() == "0" - assert returned_axs[0].get_xlabel() == "x" - - -def test_uncertainty_incorrect_ax_length(sample_data): - fig, ax = plt.subplots(1, 1) # Only one axis created, but we expect two - with pytest.raises(AssertionError): - uncertainty_plot(sample_data, "x", "y", ax=[ax]) - - -# Test case for feature importance function with default parameters -def test_feature_importance_default(sample_data): - X_reg, y_reg, _ = _prepare_data(sample_data, "x", "y", []) - result = _get_feature_importance("x", X_reg, y_reg, iterations=10) - - # Check that result is a dictionary with expected keys - expected_keys = [ - "feature", - "mean_importance", - "std_importance", - "importance_95ci_low", - "importance_95ci_high", - "proportion_positive", - "proportion_negative", - "p_value", - "iterations", - "mode", - "interpretation", - ] - - for key in expected_keys: - assert key in result - - # Check basic properties - assert result["feature"] == "x" - assert result["iterations"] == 10 - assert result["mode"] == "resample" - assert isinstance(result["mean_importance"], float) - assert isinstance(result["p_value"], float) - assert 0 <= result["p_value"] <= 1 - assert 0 <= result["proportion_positive"] <= 1 - assert 0 <= result["proportion_negative"] <= 1 - # Proportions should sum to <= 1 (the remainder are zeros) - assert result["proportion_positive"] + result["proportion_negative"] <= 1 - - -# Test case for feature importance with different modes and classifiers -def test_feature_importance_alternative(sample_data): - X_reg, y_reg, _ = _prepare_data(sample_data, "x", "y", []) - svc = SVC(probability=True) - - result = _get_feature_importance( - "x", X_reg, y_reg, mode="jackknife", iterations=10, clf=svc - ) - - assert result["mode"] == "jackknife" - assert result["iterations"] == 10 - assert isinstance(result["mean_importance"], float) - - -# Test error handling for unsupported mode -def test_feature_importance_incorrect_mode(sample_data): - X_reg, y_reg, _ = _prepare_data(sample_data, "x", "y", []) - - with pytest.raises(NotImplementedError): - _get_feature_importance("x", X_reg, y_reg, mode="invalid_mode") - - -# Test case for public feature_importance function -def test_public_feature_importance(sample_data): - # Test the public API function - result = feature_importance(sample_data, x="x", y="y", iterations=10) - - # Should have same output format as internal function - expected_keys = [ - "feature", - "mean_importance", - "std_importance", - "importance_95ci_low", - "importance_95ci_high", - "proportion_positive", - "proportion_negative", - "p_value", - "iterations", - "mode", - "interpretation", - ] - - for key in expected_keys: - assert key in result - - assert result["feature"] == "x" - assert result["iterations"] == 10 - - -# Test public function with confounders and different classifier -def test_public_feature_importance_advanced(sample_data): - svc = SVC(probability=True) - - result = feature_importance( - sample_data, - x="x", - y="y", - confounders=[("z", 5)], - clf=svc, - mode="jackknife", - iterations=10, - ) - - assert result["feature"] == "x" - assert result["mode"] == "jackknife" - assert isinstance(result["mean_importance"], float) - - -# Test warning for bootstrap mode (train/validation overlap) -def test_feature_importance_bootstrap_warning(sample_data): - with pytest.warns( - UserWarning, match="Bootstrap resampling mode uses the same data" - ): - result = feature_importance( - sample_data, x="x", y="y", mode="resample", iterations=5 +class TestGetUncertaintyData: + """Tests for the _get_uncertainty_data function.""" + + def test_basic_output_structure_resample(self, binary_sample_data): + """Test basic output structure with resample mode.""" + X_reg, y_reg, x_range = _prepare_data(binary_sample_data, "x", "y", []) + + output, long_df = _get_uncertainty_data( + "x", X_reg, y_reg, x_range, mode="resample", iterations=10 + ) + + # Check output DataFrame structure + assert isinstance(output, pd.DataFrame) + assert "x" in output.columns + assert "variable" in output.columns + assert "min" in output.columns + assert "mean" in output.columns + assert "max" in output.columns + assert "low_95" in output.columns + assert "high_95" in output.columns + assert "low_50" in output.columns + assert "high_50" in output.columns + + # Check long_df structure + assert isinstance(long_df, pd.DataFrame) + + def test_basic_output_structure_jackknife(self, binary_sample_data): + """Test basic output structure with jackknife mode.""" + X_reg, y_reg, x_range = _prepare_data(binary_sample_data, "x", "y", []) + + output, long_df = _get_uncertainty_data( + "x", + X_reg, + y_reg, + x_range, + mode="jackknife", + jackknife_fraction=0.8, + iterations=10, + ) + + assert isinstance(output, pd.DataFrame) + assert "mean" in output.columns + + def test_uncertainty_bounds_ordering(self, binary_sample_data): + """Test that uncertainty bounds are properly ordered.""" + X_reg, y_reg, x_range = _prepare_data(binary_sample_data, "x", "y", []) + + output, _ = _get_uncertainty_data( + "x", X_reg, y_reg, x_range, mode="resample", iterations=50 ) + + # For each row, bounds should be ordered: min <= low_95 <= low_50 <= mean <= high_50 <= high_95 <= max + # Note: Due to bootstrap variability, mean might not be exactly between low_50 and high_50 + # but min/max bounds should always hold + assert (output["min"] <= output["max"]).all() + assert (output["low_95"] <= output["high_95"]).all() + assert (output["low_50"] <= output["high_50"]).all() + + def test_probability_values_in_range(self, binary_sample_data): + """Test that all probability values are between 0 and 1.""" + X_reg, y_reg, x_range = _prepare_data(binary_sample_data, "x", "y", []) + + output, _ = _get_uncertainty_data( + "x", X_reg, y_reg, x_range, mode="resample", iterations=10 + ) + + for col in ["min", "mean", "max", "low_95", "high_95", "low_50", "high_50"]: + assert (output[col] >= 0).all(), f"{col} has values < 0" + assert (output[col] <= 1).all(), f"{col} has values > 1" + + def test_iterations_parameter(self, binary_sample_data): + """Test that iterations parameter affects output variability.""" + X_reg, y_reg, x_range = _prepare_data(binary_sample_data, "x", "y", []) + + # With very few iterations, expect wider intervals + output_few, long_few = _get_uncertainty_data( + "x", X_reg, y_reg, x_range, mode="resample", iterations=5 + ) + + # Long_df should have iterations * 200 (num points) * num_classes rows + # Actually it's melted, so the structure depends on implementation + assert len(long_few) > 0 + + def test_with_confounders(self, binary_sample_data): + """Test _get_uncertainty_data with confounders.""" + confounders = [("z", 5.0)] + X_reg, y_reg, x_range = _prepare_data(binary_sample_data, "x", "y", confounders) + + output, _ = _get_uncertainty_data( + "x", + X_reg, + y_reg, + x_range, + mode="resample", + iterations=10, + confounders=confounders, + ) + + assert isinstance(output, pd.DataFrame) + assert "mean" in output.columns + + def test_custom_classifier(self, binary_sample_data): + """Test _get_uncertainty_data with custom classifier.""" + X_reg, y_reg, x_range = _prepare_data(binary_sample_data, "x", "y", []) + svc = SVC(probability=True) + + output, _ = _get_uncertainty_data( + "x", X_reg, y_reg, x_range, mode="resample", iterations=10, clf=svc + ) + + assert isinstance(output, pd.DataFrame) + + def test_invalid_mode_raises_error(self, binary_sample_data): + """Test that invalid mode raises NotImplementedError.""" + X_reg, y_reg, x_range = _prepare_data(binary_sample_data, "x", "y", []) + + with pytest.raises(NotImplementedError): + _get_uncertainty_data( + "x", X_reg, y_reg, x_range, mode="invalid_mode", iterations=10 + ) + + def test_output_has_all_categories(self, binary_sample_data): + """Test that output includes all class categories.""" + X_reg, y_reg, x_range = _prepare_data(binary_sample_data, "x", "y", []) + + output, _ = _get_uncertainty_data( + "x", X_reg, y_reg, x_range, mode="resample", iterations=10 + ) + + categories = output["variable"].unique() + assert 0 in categories + assert 1 in categories + + def test_multiclass_uncertainty(self, multiclass_sample_data): + """Test _get_uncertainty_data with multi-class classification.""" + X_reg, y_reg, x_range = _prepare_data(multiclass_sample_data, "x", "y", []) + + output, _ = _get_uncertainty_data( + "x", X_reg, y_reg, x_range, mode="resample", iterations=10 + ) + + categories = output["variable"].unique() + assert len(categories) == 3 + + +# ============================================================================= +# Tests for _get_feature_importance +# ============================================================================= + + +class TestGetFeatureImportance: + """Tests for the _get_feature_importance function.""" + + def test_basic_output_structure(self, binary_sample_data): + """Test basic output structure of feature importance.""" + X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) + + result = _get_feature_importance("x", X_reg, y_reg, iterations=10) + + expected_keys = [ + "feature", + "mean_importance", + "std_importance", + "importance_95ci_low", + "importance_95ci_high", + "proportion_positive", + "proportion_negative", + "p_value", + "iterations", + "mode", + "interpretation", + ] + + for key in expected_keys: + assert key in result, f"Missing key: {key}" + + def test_feature_name_preserved(self, binary_sample_data): + """Test that feature name is preserved in output.""" + X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) + + result = _get_feature_importance("my_feature", X_reg, y_reg, iterations=10) + + assert result["feature"] == "my_feature" + + def test_iterations_parameter(self, binary_sample_data): + """Test that iterations parameter is reflected in output.""" + X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) + + result = _get_feature_importance("x", X_reg, y_reg, iterations=25) + + assert result["iterations"] == 25 + + def test_mode_resample(self, binary_sample_data): + """Test resample mode.""" + X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) + + with pytest.warns(UserWarning, match="Bootstrap resampling mode"): + result = _get_feature_importance( + "x", X_reg, y_reg, mode="resample", iterations=10 + ) + assert result["mode"] == "resample" + def test_mode_jackknife(self, binary_sample_data): + """Test jackknife mode.""" + X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) + + result = _get_feature_importance( + "x", X_reg, y_reg, mode="jackknife", iterations=10 + ) + + assert result["mode"] == "jackknife" + + def test_invalid_mode_raises_error(self, binary_sample_data): + """Test that invalid mode raises NotImplementedError.""" + X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) + + with pytest.raises(NotImplementedError): + _get_feature_importance( + "x", X_reg, y_reg, mode="invalid_mode", iterations=10 + ) + + def test_p_value_in_valid_range(self, binary_sample_data): + """Test that p-value is between 0 and 1.""" + X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) + + result = _get_feature_importance("x", X_reg, y_reg, iterations=10) + + assert 0 <= result["p_value"] <= 1 + + def test_proportions_sum_valid(self, binary_sample_data): + """Test that positive + negative proportions <= 1.""" + X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) + + result = _get_feature_importance("x", X_reg, y_reg, iterations=10) + + assert result["proportion_positive"] + result["proportion_negative"] <= 1 + assert result["proportion_positive"] >= 0 + assert result["proportion_negative"] >= 0 + + def test_confidence_interval_ordering(self, binary_sample_data): + """Test that confidence interval bounds are ordered correctly.""" + X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) -# Test warning for small validation sets in jackknife mode -def test_feature_importance_small_validation_warning(): - # Create a small dataset to trigger the warning - small_data = pd.DataFrame( - { - "x": np.random.randn(15), - "y": np.random.choice([0, 1], 15), - "z": np.random.randn(15), - } - ) + result = _get_feature_importance("x", X_reg, y_reg, iterations=50) + + assert result["importance_95ci_low"] <= result["importance_95ci_high"] + + def test_interpretation_string_format(self, binary_sample_data): + """Test that interpretation is a properly formatted string.""" + X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) + + result = _get_feature_importance("x", X_reg, y_reg, iterations=10) + + assert isinstance(result["interpretation"], str) + assert "Feature importance" in result["interpretation"] + + def test_custom_classifier(self, binary_sample_data): + """Test with custom classifier.""" + X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) + svc = SVC(probability=True) + + result = _get_feature_importance("x", X_reg, y_reg, iterations=10, clf=svc) + + assert isinstance(result["mean_importance"], float) + + def test_jackknife_fraction_parameter(self, binary_sample_data): + """Test that jackknife_fraction parameter is used.""" + X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) + + # Different fractions should produce different results + result_80 = _get_feature_importance( + "x", X_reg, y_reg, mode="jackknife", jackknife_fraction=0.8, iterations=10 + ) - with pytest.warns(UserWarning, match="Jackknife validation set is small"): - # With jackknife_fraction=0.8, validation set will be 15 * 0.2 = 3 < 20 + result_50 = _get_feature_importance( + "x", X_reg, y_reg, mode="jackknife", jackknife_fraction=0.5, iterations=10 + ) + + # Both should produce valid results + assert isinstance(result_80["mean_importance"], float) + assert isinstance(result_50["mean_importance"], float) + + +# ============================================================================= +# Tests for uncertainty_plot +# ============================================================================= + + +class TestUncertaintyPlot: + """Tests for the main uncertainty_plot function.""" + + def test_basic_plot_creation(self, binary_sample_data): + """Test basic plot creation with default parameters.""" + fig, axs = uncertainty_plot(binary_sample_data, "x", "y", iterations=10) + + assert fig is not None + assert len(axs) == 2 # Two classes + plt.close() + + def test_axes_titles(self, binary_sample_data): + """Test that axes have correct titles.""" + fig, axs = uncertainty_plot(binary_sample_data, "x", "y", iterations=10) + + assert axs[0].get_title() == "0" + assert axs[1].get_title() == "1" + plt.close() + + def test_axes_labels(self, binary_sample_data): + """Test that axes have correct labels.""" + fig, axs = uncertainty_plot(binary_sample_data, "x", "y", iterations=10) + + assert axs[0].get_xlabel() == "x" + assert axs[1].get_xlabel() == "x" + plt.close() + + def test_axes_limits(self, binary_sample_data): + """Test that axes have correct limits.""" + fig, axs = uncertainty_plot(binary_sample_data, "x", "y", iterations=10) + + for ax in axs: + ylim = ax.get_ylim() + assert ylim == (0, 1) + plt.close() + + def test_custom_x_range(self, binary_sample_data): + """Test custom x_range parameter.""" + custom_range = (0, 20) + fig, axs = uncertainty_plot( + binary_sample_data, "x", "y", x_range=custom_range, iterations=10 + ) + + for ax in axs: + xlim = ax.get_xlim() + assert xlim == custom_range + plt.close() + + def test_jackknife_mode(self, binary_sample_data): + """Test jackknife mode.""" + fig, axs = uncertainty_plot( + binary_sample_data, "x", "y", mode="jackknife", iterations=10 + ) + + assert len(axs) == 2 + plt.close() + + def test_invalid_mode_raises_error(self, binary_sample_data): + """Test that invalid mode raises NotImplementedError.""" + with pytest.raises(NotImplementedError): + uncertainty_plot( + binary_sample_data, "x", "y", mode="invalid_mode", iterations=10 + ) + + def test_with_confounders(self, binary_sample_data): + """Test plot with confounders.""" + fig, axs = uncertainty_plot( + binary_sample_data, "x", "y", confounders=[("z", 5.0)], iterations=10 + ) + + assert len(axs) == 2 + plt.close() + + def test_custom_colormap(self, binary_sample_data, custom_colormap): + """Test plot with custom colormap.""" + fig, axs = uncertainty_plot( + binary_sample_data, "x", "y", colormap=custom_colormap, iterations=10 + ) + + assert len(axs) == 2 + plt.close() + + def test_custom_classifier(self, binary_sample_data): + """Test plot with custom classifier.""" + svc = SVC(probability=True) + fig, axs = uncertainty_plot( + binary_sample_data, "x", "y", clf=svc, iterations=10 + ) + + assert len(axs) == 2 + plt.close() + + def test_existing_axes(self, binary_sample_data): + """Test plot with pre-existing axes.""" + fig, ax = plt.subplots(1, 2) + returned_fig, returned_axs = uncertainty_plot( + binary_sample_data, "x", "y", ax=ax, iterations=10 + ) + + assert returned_axs[0] == ax[0] + assert returned_axs[1] == ax[1] + plt.close() + + def test_wrong_number_of_axes_raises_error(self, binary_sample_data): + """Test that wrong number of axes raises AssertionError.""" + fig, ax = plt.subplots(1, 1) # Only one axis + + with pytest.raises(AssertionError): + uncertainty_plot(binary_sample_data, "x", "y", ax=[ax], iterations=10) + plt.close() + + def test_multiclass_creates_correct_number_of_axes(self, multiclass_sample_data): + """Test that multiclass data creates correct number of axes.""" + fig, axs = uncertainty_plot(multiclass_sample_data, "x", "y", iterations=10) + + assert len(axs) == 3 # Three classes + plt.close() + + def test_plot_has_fill_between(self, binary_sample_data): + """Test that plot includes fill_between elements.""" + fig, axs = uncertainty_plot(binary_sample_data, "x", "y", iterations=10) + + # Each axis should have collections from fill_between + for ax in axs: + assert len(ax.collections) > 0 + plt.close() + + def test_plot_has_line(self, binary_sample_data): + """Test that plot includes mean line.""" + fig, axs = uncertainty_plot(binary_sample_data, "x", "y", iterations=10) + + # Each axis should have at least one line (the mean) + for ax in axs: + assert len(ax.lines) > 0 + plt.close() + + +# ============================================================================= +# Tests for public feature_importance function +# ============================================================================= + + +class TestPublicFeatureImportance: + """Tests for the public feature_importance API.""" + + def test_basic_usage(self, binary_sample_data): + """Test basic usage of feature_importance.""" + result = feature_importance(binary_sample_data, x="x", y="y", iterations=10) + + assert result["feature"] == "x" + assert result["iterations"] == 10 + + def test_with_confounders(self, binary_sample_data): + """Test feature_importance with confounders.""" result = feature_importance( - small_data, - x="x", - y="y", - mode="jackknife", - jackknife_fraction=0.8, - iterations=5, + binary_sample_data, x="x", y="y", confounders=[("z", 5.0)], iterations=10 ) + + assert result["feature"] == "x" + + def test_with_custom_classifier(self, binary_sample_data): + """Test feature_importance with custom classifier.""" + svc = SVC(probability=True) + result = feature_importance( + binary_sample_data, x="x", y="y", clf=svc, iterations=10 + ) + + assert isinstance(result["mean_importance"], float) + + def test_jackknife_mode(self, binary_sample_data): + """Test feature_importance with jackknife mode.""" + result = feature_importance( + binary_sample_data, x="x", y="y", mode="jackknife", iterations=10 + ) + assert result["mode"] == "jackknife" + def test_resample_mode_warning(self, binary_sample_data): + """Test that resample mode issues a warning.""" + with pytest.warns(UserWarning, match="Bootstrap resampling mode"): + result = feature_importance( + binary_sample_data, x="x", y="y", mode="resample", iterations=10 + ) + + assert result["mode"] == "resample" + + def test_small_validation_set_warning(self): + """Test warning for small validation sets in jackknife mode.""" + small_data = pd.DataFrame( + { + "x": np.random.randn(15), + "y": np.random.choice([0, 1], 15), + } + ) -# Test no warning for adequate validation sets -def test_feature_importance_no_warning_adequate_validation(): - # Create a larger dataset that shouldn't trigger warnings - large_data = pd.DataFrame( - { - "x": np.random.randn(150), - "y": np.random.choice([0, 1], 150), - "z": np.random.randn(150), - } - ) - - # This should not trigger any warnings (150 * 0.2 = 30 validation samples >= 20) - with warnings.catch_warnings(): - warnings.simplefilter("error") # Turn warnings into errors - try: + with pytest.warns(UserWarning, match="Jackknife validation set is small"): result = feature_importance( - large_data, + small_data, x="x", y="y", mode="jackknife", jackknife_fraction=0.8, iterations=5, ) - assert result["mode"] == "jackknife" - except UserWarning: - pytest.fail("Unexpected warning for adequate validation set size") + + assert result["mode"] == "jackknife" + + def test_no_warning_adequate_validation(self): + """Test no warning for adequate validation sets.""" + np.random.seed(42) + large_data = pd.DataFrame( + { + "x": np.random.randn(150), + "y": np.random.choice([0, 1], 150), + } + ) + + # This should not trigger the small validation warning + with warnings.catch_warnings(): + warnings.simplefilter("error") + try: + result = feature_importance( + large_data, + x="x", + y="y", + mode="jackknife", + jackknife_fraction=0.8, + iterations=5, + ) + assert result["mode"] == "jackknife" + except UserWarning as e: + if "small" in str(e).lower(): + pytest.fail("Unexpected small validation warning") + + def test_output_consistency(self, binary_sample_data): + """Test that public function output matches internal function.""" + # Get result from public function + public_result = feature_importance( + binary_sample_data, x="x", y="y", iterations=10 + ) + + # Both should have the same keys + expected_keys = [ + "feature", + "mean_importance", + "std_importance", + "importance_95ci_low", + "importance_95ci_high", + "proportion_positive", + "proportion_negative", + "p_value", + "iterations", + "mode", + "interpretation", + ] + + for key in expected_keys: + assert key in public_result + + +# ============================================================================= +# Edge Cases +# ============================================================================= + + +class TestUncertaintyEdgeCases: + """Edge case tests for uncertainty module.""" + + def test_very_few_iterations(self, binary_sample_data): + """Test with minimum iterations.""" + fig, axs = uncertainty_plot(binary_sample_data, "x", "y", iterations=2) + + assert len(axs) == 2 + plt.close() + + def test_high_jackknife_fraction(self, binary_sample_data): + """Test with high jackknife fraction.""" + fig, axs = uncertainty_plot( + binary_sample_data, + "x", + "y", + mode="jackknife", + jackknife_fraction=0.95, + iterations=10, + ) + + assert len(axs) == 2 + plt.close() + + def test_low_jackknife_fraction(self, binary_sample_data): + """Test with low jackknife fraction.""" + fig, axs = uncertainty_plot( + binary_sample_data, + "x", + "y", + mode="jackknife", + jackknife_fraction=0.5, + iterations=10, + ) + + assert len(axs) == 2 + plt.close() + + def test_string_class_labels(self, string_class_labels): + """Test with string class labels.""" + fig, axs = uncertainty_plot(string_class_labels, "x", "y", iterations=10) + + assert len(axs) == 2 + plt.close() + + def test_nan_handling(self, data_with_nan): + """Test that NaN values are handled properly.""" + # Should not raise - NaN rows should be dropped by _prepare_data + fig, axs = uncertainty_plot(data_with_nan, "x", "y", iterations=10) + + assert len(axs) == 2 + plt.close() From c39257e3eb189b03ede680943aa2832e318a8894 Mon Sep 17 00:00:00 2001 From: Sebastian Date: Wed, 25 Feb 2026 08:20:35 +0000 Subject: [PATCH 02/11] improved resample method for feature importance --- src/lorepy/uncertainty.py | 46 +++++++++++++++++++-------------------- tests/test_uncertainty.py | 31 +++++++++++++------------- 2 files changed, 38 insertions(+), 39 deletions(-) diff --git a/src/lorepy/uncertainty.py b/src/lorepy/uncertainty.py index 5a4ddd8..d1f1d52 100644 --- a/src/lorepy/uncertainty.py +++ b/src/lorepy/uncertainty.py @@ -82,8 +82,9 @@ def _get_feature_importance( x: str, X_reg, y_reg, - mode="resample", + mode="jackknife", jackknife_fraction: float = 0.8, + resample_validation_fraction: float = 0.2, iterations: int = 100, clf=None, ): @@ -96,6 +97,7 @@ def _get_feature_importance( :param y_reg: Target variable. :param mode: Method for uncertainty estimation. Either "resample" (bootstrap) or "jackknife". :param jackknife_fraction: Fraction of data to keep in each jackknife iteration (only used if mode="jackknife"). + :param resample_validation_fraction: Fraction of data to use for validation in resampling mode (only used if mode="resample"). :param iterations: Number of resampling or jackknife iterations. :param clf: Classifier to use for fitting. If None, uses LogisticRegression. :return: Dictionary containing feature importance statistics including mean importance, confidence intervals, and significance metrics. @@ -103,40 +105,33 @@ def _get_feature_importance( importance_scores = [] - # Issue warnings about statistical considerations for different modes - if mode == "resample": - warnings.warn( - "Bootstrap resampling mode uses the same data for training and validation, " - "which may lead to overoptimistic importance scores with some models. " - "Consider using mode='jackknife' for more conservative estimates with proper train/test splits.", - UserWarning, - stacklevel=3, - ) - for i in range(iterations): if mode == "jackknife": X_keep, X_val, y_keep, y_val = train_test_split( X_reg, y_reg, train_size=jackknife_fraction ) - - # Check for small validation sets that may affect statistical reliability - if len(y_val) < 20: - warnings.warn( - f"Jackknife validation set is small (n={len(y_val)}). " - f"Small validation sets may lead to unreliable importance estimates. " - f"Consider increasing jackknife_fraction (currently {jackknife_fraction}) " - f"or using a larger dataset for more stable results.", - UserWarning, - stacklevel=3, - ) elif mode == "resample": - X_keep, y_keep = resample(X_reg, y_reg, replace=True) - X_val, y_val = X_reg, y_reg # Use full data for validation + X_keep, X_val, y_keep, y_val = train_test_split( + X_reg, y_reg, train_size=1-resample_validation_fraction + ) + X_keep, y_keep = resample(X_keep, y_keep, replace=True) else: raise NotImplementedError( f"Mode {mode} is unsupported, only jackknife and resample are valid modes" ) + + # Check for small validation sets that may affect statistical reliability + if len(y_val) < 20: + warnings.warn( + f"The validation set is small (n={len(y_val)}). " + f"Small validation sets may lead to unreliable importance estimates. " + f"Consider decreasing jackknife_fraction (currently {jackknife_fraction}), increasing resample_validation_fraction (currently {resample_validation_fraction}), " + f"or using a larger dataset for more stable results.", + UserWarning, + stacklevel=3, + ) + # Fit model and use sklearn's permutation_importance lg = LogisticRegression() if clf is None else clf lg.fit(X_keep, y_keep) @@ -279,6 +274,7 @@ def feature_importance( y: str, mode="resample", jackknife_fraction=0.8, + resample_validation_fraction=0.2, iterations=100, confounders=None, clf=None, @@ -296,6 +292,7 @@ def feature_importance( :param y: The name of the target variable. :param mode: Method for uncertainty estimation. Either "resample" (bootstrap) or "jackknife". :param jackknife_fraction: Fraction of data to keep in each jackknife iteration (only used if mode="jackknife"). + :param resample_validation_fraction: Fraction of data to use for validation in resampling mode (only used if mode="resample"). :param iterations: Number of resampling or jackknife iterations. :param confounders: List of tuples (feature, reference value) pairs representing confounder features and their reference values. :param clf: Classifier to use for fitting. If None, uses LogisticRegression. @@ -326,6 +323,7 @@ def feature_importance( y_reg=y_reg, mode=mode, jackknife_fraction=jackknife_fraction, + resample_validation_fraction=resample_validation_fraction, iterations=iterations, clf=clf, ) diff --git a/tests/test_uncertainty.py b/tests/test_uncertainty.py index f52fb47..2031a81 100644 --- a/tests/test_uncertainty.py +++ b/tests/test_uncertainty.py @@ -219,11 +219,9 @@ def test_iterations_parameter(self, binary_sample_data): def test_mode_resample(self, binary_sample_data): """Test resample mode.""" X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) - - with pytest.warns(UserWarning, match="Bootstrap resampling mode"): - result = _get_feature_importance( - "x", X_reg, y_reg, mode="resample", iterations=10 - ) + result = _get_feature_importance( + "x", X_reg, y_reg, mode="resample", iterations=10 + ) assert result["mode"] == "resample" @@ -490,15 +488,6 @@ def test_jackknife_mode(self, binary_sample_data): assert result["mode"] == "jackknife" - def test_resample_mode_warning(self, binary_sample_data): - """Test that resample mode issues a warning.""" - with pytest.warns(UserWarning, match="Bootstrap resampling mode"): - result = feature_importance( - binary_sample_data, x="x", y="y", mode="resample", iterations=10 - ) - - assert result["mode"] == "resample" - def test_small_validation_set_warning(self): """Test warning for small validation sets in jackknife mode.""" small_data = pd.DataFrame( @@ -508,7 +497,7 @@ def test_small_validation_set_warning(self): } ) - with pytest.warns(UserWarning, match="Jackknife validation set is small"): + with pytest.warns(UserWarning, match="The validation set is small"): result = feature_importance( small_data, x="x", @@ -520,6 +509,18 @@ def test_small_validation_set_warning(self): assert result["mode"] == "jackknife" + with pytest.warns(UserWarning, match="The validation set is small"): + result = feature_importance( + small_data, + x="x", + y="y", + mode="resample", + resample_validation_fraction=0.2, + iterations=5, + ) + + assert result["mode"] == "resample" + def test_no_warning_adequate_validation(self): """Test no warning for adequate validation sets.""" np.random.seed(42) From 8e3510d4d541abf8a6d43114c85cfab0d69dc722 Mon Sep 17 00:00:00 2001 From: autoblack_push <${GITHUB_ACTOR}@users.noreply.github.com> Date: Wed, 25 Feb 2026 08:20:57 +0000 Subject: [PATCH 03/11] fixup! Format Python code with psf/black push --- src/lorepy/uncertainty.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lorepy/uncertainty.py b/src/lorepy/uncertainty.py index d1f1d52..4a53b89 100644 --- a/src/lorepy/uncertainty.py +++ b/src/lorepy/uncertainty.py @@ -112,7 +112,7 @@ def _get_feature_importance( ) elif mode == "resample": X_keep, X_val, y_keep, y_val = train_test_split( - X_reg, y_reg, train_size=1-resample_validation_fraction + X_reg, y_reg, train_size=1 - resample_validation_fraction ) X_keep, y_keep = resample(X_keep, y_keep, replace=True) else: @@ -120,7 +120,6 @@ def _get_feature_importance( f"Mode {mode} is unsupported, only jackknife and resample are valid modes" ) - # Check for small validation sets that may affect statistical reliability if len(y_val) < 20: warnings.warn( From 64c275b1a12f6b4550e7e8790dac4cb93e2da719 Mon Sep 17 00:00:00 2001 From: Sebastian Date: Wed, 25 Feb 2026 08:31:25 +0000 Subject: [PATCH 04/11] adding accuracies to output --- README.md | 8 ++++++-- src/lorepy/uncertainty.py | 21 +++++++++++++++++++-- tests/test_uncertainty.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index bd2554b..b3afc57 100644 --- a/README.md +++ b/README.md @@ -233,10 +233,14 @@ print(stats['interpretation']) The function returns a dictionary with the following key statistics: - **`mean_importance`**: Average accuracy drop when x-feature is shuffled (higher = more important) -- **`std_importance`**: Standard deviation of importance across iterations +- **`std_importance`**: Standard deviation of importance across iterations - **`importance_95ci_low/high`**: 95% confidence interval for the importance estimate +- **`mean_validation_accuracy`**: Mean accuracy on the validation data across iterations +- **`std_validation_accuracy`**: Standard deviation of the validation accuracy +- **`mean_permuted_accuracy`**: Mean accuracy on the permuted data across iterations +- **`std_permuted_accuracy`**: Standard deviation of the permuted accuracy - **`proportion_positive`**: Fraction of iterations where importance > 0 (feature helps prediction) -- **`proportion_negative`**: Fraction of iterations where importance < 0 (feature hurts prediction) +- **`proportion_negative`**: Fraction of iterations where importance < 0 (feature hurts prediction) - **`p_value`**: Empirical p-value for statistical significance (< 0.05 typically considered significant) - **`interpretation`**: Human-readable summary of the results diff --git a/src/lorepy/uncertainty.py b/src/lorepy/uncertainty.py index 4a53b89..a4c0230 100644 --- a/src/lorepy/uncertainty.py +++ b/src/lorepy/uncertainty.py @@ -100,10 +100,12 @@ def _get_feature_importance( :param resample_validation_fraction: Fraction of data to use for validation in resampling mode (only used if mode="resample"). :param iterations: Number of resampling or jackknife iterations. :param clf: Classifier to use for fitting. If None, uses LogisticRegression. - :return: Dictionary containing feature importance statistics including mean importance, confidence intervals, and significance metrics. + :return: Dictionary containing feature importance statistics including mean importance, confidence intervals, validation/permuted accuracy statistics, and significance metrics. """ importance_scores = [] + validation_accuracies = [] + permuted_accuracies = [] for i in range(iterations): if mode == "jackknife": @@ -151,11 +153,22 @@ def _get_feature_importance( importance = perm_result.importances_mean[0] importance_scores.append(importance) + # Track validation and permuted accuracies + val_accuracy = lg.score(X_val, y_val) + validation_accuracies.append(val_accuracy) + permuted_accuracies.append(val_accuracy - importance) + importance_scores = np.array(importance_scores) + validation_accuracies = np.array(validation_accuracies) + permuted_accuracies = np.array(permuted_accuracies) # Calculate statistics mean_importance = np.mean(importance_scores) std_importance = np.std(importance_scores) + mean_validation_accuracy = np.mean(validation_accuracies) + std_validation_accuracy = np.std(validation_accuracies) + mean_permuted_accuracy = np.mean(permuted_accuracies) + std_permuted_accuracy = np.std(permuted_accuracies) ci_95_low = np.percentile(importance_scores, 2.5) ci_95_high = np.percentile(importance_scores, 97.5) @@ -176,6 +189,10 @@ def _get_feature_importance( "std_importance": std_importance, "importance_95ci_low": ci_95_low, "importance_95ci_high": ci_95_high, + "mean_validation_accuracy": mean_validation_accuracy, + "std_validation_accuracy": std_validation_accuracy, + "mean_permuted_accuracy": mean_permuted_accuracy, + "std_permuted_accuracy": std_permuted_accuracy, "proportion_positive": significant_positive, "proportion_negative": significant_negative, "p_value": p_value, @@ -295,7 +312,7 @@ def feature_importance( :param iterations: Number of resampling or jackknife iterations. :param confounders: List of tuples (feature, reference value) pairs representing confounder features and their reference values. :param clf: Classifier to use for fitting. If None, uses LogisticRegression. - :return: Dictionary containing feature importance statistics including mean importance, confidence intervals, and significance metrics. + :return: Dictionary containing feature importance statistics including mean importance, confidence intervals, validation/permuted accuracy statistics, and significance metrics. Example: >>> import pandas as pd diff --git a/tests/test_uncertainty.py b/tests/test_uncertainty.py index 2031a81..f9fbc46 100644 --- a/tests/test_uncertainty.py +++ b/tests/test_uncertainty.py @@ -189,6 +189,10 @@ def test_basic_output_structure(self, binary_sample_data): "std_importance", "importance_95ci_low", "importance_95ci_high", + "mean_validation_accuracy", + "std_validation_accuracy", + "mean_permuted_accuracy", + "std_permuted_accuracy", "proportion_positive", "proportion_negative", "p_value", @@ -200,6 +204,28 @@ def test_basic_output_structure(self, binary_sample_data): for key in expected_keys: assert key in result, f"Missing key: {key}" + def test_accuracy_values_in_valid_range(self, binary_sample_data): + """Test that accuracy values are between 0 and 1.""" + X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) + + result = _get_feature_importance("x", X_reg, y_reg, iterations=10) + + assert 0 <= result["mean_validation_accuracy"] <= 1 + assert 0 <= result["mean_permuted_accuracy"] <= 1 + assert result["std_validation_accuracy"] >= 0 + assert result["std_permuted_accuracy"] >= 0 + + def test_importance_equals_accuracy_difference(self, binary_sample_data): + """Test that mean importance approximately equals the difference between validation and permuted accuracy.""" + X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) + + result = _get_feature_importance("x", X_reg, y_reg, iterations=50) + + # importance = validation_accuracy - permuted_accuracy (per iteration), + # so mean_importance should approximately equal the difference of means + expected_diff = result["mean_validation_accuracy"] - result["mean_permuted_accuracy"] + assert abs(result["mean_importance"] - expected_diff) < 0.05 + def test_feature_name_preserved(self, binary_sample_data): """Test that feature name is preserved in output.""" X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) @@ -562,6 +588,10 @@ def test_output_consistency(self, binary_sample_data): "std_importance", "importance_95ci_low", "importance_95ci_high", + "mean_validation_accuracy", + "std_validation_accuracy", + "mean_permuted_accuracy", + "std_permuted_accuracy", "proportion_positive", "proportion_negative", "p_value", From e277a410849f69a9e7372dc0725e934e280599e2 Mon Sep 17 00:00:00 2001 From: Sebastian Date: Wed, 25 Feb 2026 08:48:57 +0000 Subject: [PATCH 05/11] switch from accuracy to neg_log_loss which is better suited for the feature importance --- README.md | 14 +++++------ src/lorepy/uncertainty.py | 53 +++++++++++++++++++++++---------------- tests/test_uncertainty.py | 46 +++++++++++++++++---------------- 3 files changed, 62 insertions(+), 51 deletions(-) diff --git a/README.md b/README.md index b3afc57..37bdf52 100644 --- a/README.md +++ b/README.md @@ -205,7 +205,7 @@ This also supports custom colors, ranges and classifiers. More examples are avai ### Feature Importance Analysis -Lorepy provides statistical assessment of how strongly your x-feature is associated with the class distribution using the `feature_importance` function. This uses **permutation-based feature importance** to test whether the relationship you see in your loreplot is statistically significant. +Lorepy provides statistical assessment of how strongly your x-feature is associated with the class distribution using the `feature_importance` function. This uses **permutation-based feature importance** with **log loss (cross-entropy)** as the scoring metric to test whether the relationship you see in your loreplot is statistically significant. Log loss evaluates the full predicted probability distribution rather than just hard class predictions, making it well-suited for lorepy's probability-based visualizations. #### How it Works @@ -214,7 +214,7 @@ The function uses a robust resampling approach combined with sklearn's optimized 1. **Bootstrap/Jackknife Sampling**: Creates multiple subsamples of your data (default: 100 iterations) 2. **Permutation Importance**: For each subsample, uses sklearn's `permutation_importance` with proper cross-validation to avoid data leakage 3. **Feature Shuffling**: Randomly permutes the x-feature values while keeping confounders intact -4. **Performance Assessment**: Measures accuracy drop using statistically sound train/test splits +4. **Performance Assessment**: Measures log loss increase using statistically sound train/test splits 5. **Statistical Summary**: Aggregates results across all iterations to provide confidence intervals and significance testing This approach works with **any sklearn classifier** (LogisticRegression, SVM, RandomForest, etc.) and properly handles confounders by keeping them constant during shuffling. The implementation uses sklearn's battle-tested permutation importance algorithm for reliable, unbiased results. @@ -232,13 +232,13 @@ print(stats['interpretation']) The function returns a dictionary with the following key statistics: -- **`mean_importance`**: Average accuracy drop when x-feature is shuffled (higher = more important) +- **`mean_importance`**: Average log loss increase when x-feature is shuffled (higher = more important) - **`std_importance`**: Standard deviation of importance across iterations - **`importance_95ci_low/high`**: 95% confidence interval for the importance estimate -- **`mean_validation_accuracy`**: Mean accuracy on the validation data across iterations -- **`std_validation_accuracy`**: Standard deviation of the validation accuracy -- **`mean_permuted_accuracy`**: Mean accuracy on the permuted data across iterations -- **`std_permuted_accuracy`**: Standard deviation of the permuted accuracy +- **`mean_validation_log_loss`**: Mean log loss on the validation data across iterations (lower = better) +- **`std_validation_log_loss`**: Standard deviation of the validation log loss +- **`mean_permuted_log_loss`**: Mean log loss on the permuted data across iterations (lower = better) +- **`std_permuted_log_loss`**: Standard deviation of the permuted log loss - **`proportion_positive`**: Fraction of iterations where importance > 0 (feature helps prediction) - **`proportion_negative`**: Fraction of iterations where importance < 0 (feature hurts prediction) - **`p_value`**: Empirical p-value for statistical significance (< 0.05 typically considered significant) diff --git a/src/lorepy/uncertainty.py b/src/lorepy/uncertainty.py index a4c0230..65f6565 100644 --- a/src/lorepy/uncertainty.py +++ b/src/lorepy/uncertainty.py @@ -7,6 +7,7 @@ from sklearn.model_selection import train_test_split from sklearn.utils import resample from sklearn.inspection import permutation_importance +from sklearn.metrics import log_loss from lorepy.lorepy import _get_area_df, _prepare_data @@ -90,7 +91,9 @@ def _get_feature_importance( ): """ Estimates the importance of the x-feature in predicting class labels using permutation-based - feature importance with resampling or jackknife methods. Uses accuracy as the performance metric. + feature importance with resampling or jackknife methods. Uses log loss (cross-entropy) as the + performance metric, which evaluates the full predicted probability distribution rather than + just hard class predictions. :param x: Name of the feature variable to analyze for importance. :param X_reg: Feature matrix for regression/classification. @@ -100,12 +103,12 @@ def _get_feature_importance( :param resample_validation_fraction: Fraction of data to use for validation in resampling mode (only used if mode="resample"). :param iterations: Number of resampling or jackknife iterations. :param clf: Classifier to use for fitting. If None, uses LogisticRegression. - :return: Dictionary containing feature importance statistics including mean importance, confidence intervals, validation/permuted accuracy statistics, and significance metrics. + :return: Dictionary containing feature importance statistics including mean importance, confidence intervals, validation/permuted log loss statistics, and significance metrics. """ importance_scores = [] - validation_accuracies = [] - permuted_accuracies = [] + validation_log_losses = [] + permuted_log_losses = [] for i in range(iterations): if mode == "jackknife": @@ -138,37 +141,41 @@ def _get_feature_importance( lg.fit(X_keep, y_keep) # Use permutation_importance to get feature importance for first feature (x) - # This handles proper train/test splits internally and avoids training data leakage + # Using neg_log_loss to evaluate the full probability distribution rather than + # just hard class predictions, which is more appropriate for lorepy's probability-based plots perm_result = permutation_importance( lg, X_val, y_val, n_repeats=1, # We handle iterations in outer loop random_state=None, # Allow randomness for each iteration - scoring="accuracy", + scoring="neg_log_loss", n_jobs=1, ) # Extract importance for first feature (x-feature) + # importance = neg_log_loss_original - neg_log_loss_permuted + # Positive importance means permuting the feature increases log loss (worsens predictions) importance = perm_result.importances_mean[0] importance_scores.append(importance) - # Track validation and permuted accuracies - val_accuracy = lg.score(X_val, y_val) - validation_accuracies.append(val_accuracy) - permuted_accuracies.append(val_accuracy - importance) + # Track validation and permuted log losses + val_log_loss = log_loss(y_val, lg.predict_proba(X_val), labels=lg.classes_) + validation_log_losses.append(val_log_loss) + # Since importance = (-val_log_loss) - (-perm_log_loss) = perm_log_loss - val_log_loss + permuted_log_losses.append(val_log_loss + importance) importance_scores = np.array(importance_scores) - validation_accuracies = np.array(validation_accuracies) - permuted_accuracies = np.array(permuted_accuracies) + validation_log_losses = np.array(validation_log_losses) + permuted_log_losses = np.array(permuted_log_losses) # Calculate statistics mean_importance = np.mean(importance_scores) std_importance = np.std(importance_scores) - mean_validation_accuracy = np.mean(validation_accuracies) - std_validation_accuracy = np.std(validation_accuracies) - mean_permuted_accuracy = np.mean(permuted_accuracies) - std_permuted_accuracy = np.std(permuted_accuracies) + mean_validation_log_loss = np.mean(validation_log_losses) + std_validation_log_loss = np.std(validation_log_losses) + mean_permuted_log_loss = np.mean(permuted_log_losses) + std_permuted_log_loss = np.std(permuted_log_losses) ci_95_low = np.percentile(importance_scores, 2.5) ci_95_high = np.percentile(importance_scores, 97.5) @@ -189,10 +196,10 @@ def _get_feature_importance( "std_importance": std_importance, "importance_95ci_low": ci_95_low, "importance_95ci_high": ci_95_high, - "mean_validation_accuracy": mean_validation_accuracy, - "std_validation_accuracy": std_validation_accuracy, - "mean_permuted_accuracy": mean_permuted_accuracy, - "std_permuted_accuracy": std_permuted_accuracy, + "mean_validation_log_loss": mean_validation_log_loss, + "std_validation_log_loss": std_validation_log_loss, + "mean_permuted_log_loss": mean_permuted_log_loss, + "std_permuted_log_loss": std_permuted_log_loss, "proportion_positive": significant_positive, "proportion_negative": significant_negative, "p_value": p_value, @@ -297,7 +304,9 @@ def feature_importance( ): """ Estimates the importance of a feature in predicting class labels using permutation-based - feature importance with resampling or jackknife methods. Uses accuracy as the performance metric. + feature importance with resampling or jackknife methods. Uses log loss (cross-entropy) as the + performance metric, which evaluates the full predicted probability distribution rather than + just hard class predictions. This function provides statistical assessment of whether the x-feature is significantly associated with the class distribution (y-variable). Higher importance scores indicate @@ -312,7 +321,7 @@ def feature_importance( :param iterations: Number of resampling or jackknife iterations. :param confounders: List of tuples (feature, reference value) pairs representing confounder features and their reference values. :param clf: Classifier to use for fitting. If None, uses LogisticRegression. - :return: Dictionary containing feature importance statistics including mean importance, confidence intervals, validation/permuted accuracy statistics, and significance metrics. + :return: Dictionary containing feature importance statistics including mean importance, confidence intervals, validation/permuted log loss statistics, and significance metrics. Example: >>> import pandas as pd diff --git a/tests/test_uncertainty.py b/tests/test_uncertainty.py index f9fbc46..856a9ee 100644 --- a/tests/test_uncertainty.py +++ b/tests/test_uncertainty.py @@ -189,10 +189,10 @@ def test_basic_output_structure(self, binary_sample_data): "std_importance", "importance_95ci_low", "importance_95ci_high", - "mean_validation_accuracy", - "std_validation_accuracy", - "mean_permuted_accuracy", - "std_permuted_accuracy", + "mean_validation_log_loss", + "std_validation_log_loss", + "mean_permuted_log_loss", + "std_permuted_log_loss", "proportion_positive", "proportion_negative", "p_value", @@ -204,27 +204,27 @@ def test_basic_output_structure(self, binary_sample_data): for key in expected_keys: assert key in result, f"Missing key: {key}" - def test_accuracy_values_in_valid_range(self, binary_sample_data): - """Test that accuracy values are between 0 and 1.""" + def test_log_loss_values_valid(self, binary_sample_data): + """Test that log loss values are non-negative and standard deviations are non-negative.""" X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) result = _get_feature_importance("x", X_reg, y_reg, iterations=10) - assert 0 <= result["mean_validation_accuracy"] <= 1 - assert 0 <= result["mean_permuted_accuracy"] <= 1 - assert result["std_validation_accuracy"] >= 0 - assert result["std_permuted_accuracy"] >= 0 + assert result["mean_validation_log_loss"] >= 0 + assert result["mean_permuted_log_loss"] >= 0 + assert result["std_validation_log_loss"] >= 0 + assert result["std_permuted_log_loss"] >= 0 - def test_importance_equals_accuracy_difference(self, binary_sample_data): - """Test that mean importance approximately equals the difference between validation and permuted accuracy.""" + def test_permuted_log_loss_generally_higher(self, binary_sample_data): + """Test that permuted log loss is generally higher (worse) than validation log loss for an informative feature.""" X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) result = _get_feature_importance("x", X_reg, y_reg, iterations=50) - # importance = validation_accuracy - permuted_accuracy (per iteration), - # so mean_importance should approximately equal the difference of means - expected_diff = result["mean_validation_accuracy"] - result["mean_permuted_accuracy"] - assert abs(result["mean_importance"] - expected_diff) < 0.05 + # For an informative feature, permuting should increase log loss + # This is a soft check - the relationship holds on average + if result["mean_importance"] > 0: + assert result["mean_permuted_log_loss"] >= result["mean_validation_log_loss"] def test_feature_name_preserved(self, binary_sample_data): """Test that feature name is preserved in output.""" @@ -516,10 +516,12 @@ def test_jackknife_mode(self, binary_sample_data): def test_small_validation_set_warning(self): """Test warning for small validation sets in jackknife mode.""" + # Use balanced classes to ensure both classes appear in validation splits, + # which is required for log_loss scoring small_data = pd.DataFrame( { - "x": np.random.randn(15), - "y": np.random.choice([0, 1], 15), + "x": np.random.randn(50), + "y": np.array([0, 1] * 25), } ) @@ -588,10 +590,10 @@ def test_output_consistency(self, binary_sample_data): "std_importance", "importance_95ci_low", "importance_95ci_high", - "mean_validation_accuracy", - "std_validation_accuracy", - "mean_permuted_accuracy", - "std_permuted_accuracy", + "mean_validation_log_loss", + "std_validation_log_loss", + "mean_permuted_log_loss", + "std_permuted_log_loss", "proportion_positive", "proportion_negative", "p_value", From 0b8a7ac74ac5df7ccb373e9bb414799fafeb57ab Mon Sep 17 00:00:00 2001 From: Sebastian Date: Wed, 25 Feb 2026 09:01:54 +0000 Subject: [PATCH 06/11] expanded docs --- README.md | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 37bdf52..6f60be4 100644 --- a/README.md +++ b/README.md @@ -11,12 +11,34 @@ If you prefer an R implementation of this package, have a look at [loreplotr](ht ## Why use lorepy ? -Lorepy offers distinct advantages over traditional methods like stacked bar plots. By employing a linear model, Lorepy -captures overall trends across the entire feature range. It avoids arbitrary cut-offs and segmentation, enabling the +Lorepy offers distinct advantages over traditional methods like stacked bar plots. By employing a linear model, Lorepy +captures overall trends across the entire feature range. It avoids arbitrary cut-offs and segmentation, enabling the visualization of uncertainty throughout the data range. You can find examples of the Iris data visualized using stacked bar plots [here](https://github.com/raeslab/lorepy/blob/main/docs/lorepy_vs_bar_plots.md) for comparison. +## How lorepy works + +Under the hood, the default model is a multinomial logistic regression (scikit-learn's `LogisticRegression` with L2 +regularization, C=1.0). For *K* classes it estimates coefficient vectors **β**\_k and intercepts β\_{k0}, then computes +class probabilities via the softmax function: *P(Y=k | x) = exp(**β**\_k · x + β\_{k0}) / Σ\_j exp(**β**\_j · x + β\_{j0})*. +Because these probabilities are guaranteed to sum to one for every value of *x*, they can be directly rendered as a +stacked area chart. Lorepy evaluates the fitted model at 200 evenly spaced points across the observed feature range, +producing smooth probability curves that reveal how the expected class composition shifts continuously with the +independent variable. When confounders are specified, they are included as additional features during model fitting but +held constant at user-supplied reference values during prediction, effectively marginalizing their influence. Sample +dots are positioned by drawing a random *y*-coordinate within the predicted probability band of each observation's true +class, giving an intuitive sense of both class membership and local model confidence. + +Concretely, the height of each colored band at a given *x*-value represents the model's estimated proportion of that +class: a band spanning 60% of the y-axis means the model estimates that class accounts for 60% of observations at that +point along the feature. As *x* increases, bands that widen indicate classes with a growing estimated proportion, while +narrowing bands indicate classes becoming rarer. A class that dominates the plot across the full range has a +consistently high estimated proportion regardless of the feature value, whereas a sharp crossover between two bands +pinpoints where one class overtakes another. Because the bands always sum to one, the plot naturally encodes a zero-sum +trade-off: one class can only grow in estimated proportion at the expense of others, making it straightforward to read +both absolute and relative shifts directly from the visualization. + ## Installation Lorepy can be installed using pip using the command below. From d23e93f322471ecadae64fa850c6687f0adc8820 Mon Sep 17 00:00:00 2001 From: autoblack_push <${GITHUB_ACTOR}@users.noreply.github.com> Date: Wed, 25 Feb 2026 09:02:12 +0000 Subject: [PATCH 07/11] fixup! Format Python code with psf/black push --- tests/test_uncertainty.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_uncertainty.py b/tests/test_uncertainty.py index 856a9ee..be1ab85 100644 --- a/tests/test_uncertainty.py +++ b/tests/test_uncertainty.py @@ -224,7 +224,9 @@ def test_permuted_log_loss_generally_higher(self, binary_sample_data): # For an informative feature, permuting should increase log loss # This is a soft check - the relationship holds on average if result["mean_importance"] > 0: - assert result["mean_permuted_log_loss"] >= result["mean_validation_log_loss"] + assert ( + result["mean_permuted_log_loss"] >= result["mean_validation_log_loss"] + ) def test_feature_name_preserved(self, binary_sample_data): """Test that feature name is preserved in output.""" From bd1a98ae6287b0d28bcae8eed15196c3f42a822f Mon Sep 17 00:00:00 2001 From: Sebastian Proost Date: Wed, 25 Feb 2026 10:13:38 +0100 Subject: [PATCH 08/11] Clarify README explanation of lorepy model mechanics (#16) --- README.md | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 6f60be4..beed1ca 100644 --- a/README.md +++ b/README.md @@ -19,16 +19,19 @@ You can find examples of the Iris data visualized using stacked bar plots [here] ## How lorepy works -Under the hood, the default model is a multinomial logistic regression (scikit-learn's `LogisticRegression` with L2 -regularization, C=1.0). For *K* classes it estimates coefficient vectors **β**\_k and intercepts β\_{k0}, then computes -class probabilities via the softmax function: *P(Y=k | x) = exp(**β**\_k · x + β\_{k0}) / Σ\_j exp(**β**\_j · x + β\_{j0})*. -Because these probabilities are guaranteed to sum to one for every value of *x*, they can be directly rendered as a -stacked area chart. Lorepy evaluates the fitted model at 200 evenly spaced points across the observed feature range, -producing smooth probability curves that reveal how the expected class composition shifts continuously with the -independent variable. When confounders are specified, they are included as additional features during model fitting but -held constant at user-supplied reference values during prediction, effectively marginalizing their influence. Sample -dots are positioned by drawing a random *y*-coordinate within the predicted probability band of each observation's true -class, giving an intuitive sense of both class membership and local model confidence. +Under the hood, lorepy fits scikit-learn's default `LogisticRegression()` model (L2 regularization, `C=1.0`) unless you +pass a custom classifier. For a multiclass outcome with *K* classes, the fitted model can be written as coefficient +vectors **β**\_k and intercepts β\_{k0}, with class probabilities + +*P(Y=k | x) = exp(**β**\_k · x + β\_{k0}) / Σ\_j exp(**β**\_j · x + β\_{j0}).* + +These probabilities sum to one at each *x*, so they can be drawn directly as a stacked area chart. Lorepy evaluates +the fitted model at 200 evenly spaced points across the observed range of the x-feature, yielding smooth curves for the +estimated class composition as a function of *x*. If confounders are provided, they are included during fitting and +then fixed to user-specified reference values at prediction time, so the displayed curves are **conditional on those +reference values** (not marginal averages over the confounder distribution). Sample dots are generated by drawing each +point's *y*-coordinate uniformly within the predicted probability interval of its true class, which visualizes class +membership while preserving the stacked-probability interpretation. Concretely, the height of each colored band at a given *x*-value represents the model's estimated proportion of that class: a band spanning 60% of the y-axis means the model estimates that class accounts for 60% of observations at that From 269f1321ff45d60c11b4922c09d65776ec8c9124 Mon Sep 17 00:00:00 2001 From: Sebastian Date: Wed, 25 Feb 2026 09:22:00 +0000 Subject: [PATCH 09/11] refactored docs --- README.md | 91 +------------------------------------- docs/feature_importance.md | 67 ++++++++++++++++++++++++++++ docs/how_lorepy_works.md | 24 ++++++++++ 3 files changed, 93 insertions(+), 89 deletions(-) create mode 100644 docs/feature_importance.md create mode 100644 docs/how_lorepy_works.md diff --git a/README.md b/README.md index beed1ca..aa4acad 100644 --- a/README.md +++ b/README.md @@ -19,28 +19,7 @@ You can find examples of the Iris data visualized using stacked bar plots [here] ## How lorepy works -Under the hood, lorepy fits scikit-learn's default `LogisticRegression()` model (L2 regularization, `C=1.0`) unless you -pass a custom classifier. For a multiclass outcome with *K* classes, the fitted model can be written as coefficient -vectors **β**\_k and intercepts β\_{k0}, with class probabilities - -*P(Y=k | x) = exp(**β**\_k · x + β\_{k0}) / Σ\_j exp(**β**\_j · x + β\_{j0}).* - -These probabilities sum to one at each *x*, so they can be drawn directly as a stacked area chart. Lorepy evaluates -the fitted model at 200 evenly spaced points across the observed range of the x-feature, yielding smooth curves for the -estimated class composition as a function of *x*. If confounders are provided, they are included during fitting and -then fixed to user-specified reference values at prediction time, so the displayed curves are **conditional on those -reference values** (not marginal averages over the confounder distribution). Sample dots are generated by drawing each -point's *y*-coordinate uniformly within the predicted probability interval of its true class, which visualizes class -membership while preserving the stacked-probability interpretation. - -Concretely, the height of each colored band at a given *x*-value represents the model's estimated proportion of that -class: a band spanning 60% of the y-axis means the model estimates that class accounts for 60% of observations at that -point along the feature. As *x* increases, bands that widen indicate classes with a growing estimated proportion, while -narrowing bands indicate classes becoming rarer. A class that dominates the plot across the full range has a -consistently high estimated proportion regardless of the feature value, whereas a sharp crossover between two bands -pinpoints where one class overtakes another. Because the bands always sum to one, the plot naturally encodes a zero-sum -trade-off: one class can only grow in estimated proportion at the expense of others, making it straightforward to read -both absolute and relative shifts directly from the visualization. +For details on the model mechanics and how to interpret loreplots, see [How lorepy works](https://github.com/raeslab/lorepy/blob/main/docs/how_lorepy_works.md). ## Installation @@ -230,73 +209,7 @@ This also supports custom colors, ranges and classifiers. More examples are avai ### Feature Importance Analysis -Lorepy provides statistical assessment of how strongly your x-feature is associated with the class distribution using the `feature_importance` function. This uses **permutation-based feature importance** with **log loss (cross-entropy)** as the scoring metric to test whether the relationship you see in your loreplot is statistically significant. Log loss evaluates the full predicted probability distribution rather than just hard class predictions, making it well-suited for lorepy's probability-based visualizations. - -#### How it Works - -The function uses a robust resampling approach combined with sklearn's optimized permutation importance: - -1. **Bootstrap/Jackknife Sampling**: Creates multiple subsamples of your data (default: 100 iterations) -2. **Permutation Importance**: For each subsample, uses sklearn's `permutation_importance` with proper cross-validation to avoid data leakage -3. **Feature Shuffling**: Randomly permutes the x-feature values while keeping confounders intact -4. **Performance Assessment**: Measures log loss increase using statistically sound train/test splits -5. **Statistical Summary**: Aggregates results across all iterations to provide confidence intervals and significance testing - -This approach works with **any sklearn classifier** (LogisticRegression, SVM, RandomForest, etc.) and properly handles confounders by keeping them constant during shuffling. The implementation uses sklearn's battle-tested permutation importance algorithm for reliable, unbiased results. - -```python -from lorepy import feature_importance - -# Basic usage -stats = feature_importance(data=iris_df, x="sepal width (cm)", y="species", iterations=100) -print(stats['interpretation']) -# Output: "Feature importance: 0.2019 ± 0.0433. Positive in 100.0% of iterations (p=0.0000)" -``` - -#### Understanding the Output - -The function returns a dictionary with the following key statistics: - -- **`mean_importance`**: Average log loss increase when x-feature is shuffled (higher = more important) -- **`std_importance`**: Standard deviation of importance across iterations -- **`importance_95ci_low/high`**: 95% confidence interval for the importance estimate -- **`mean_validation_log_loss`**: Mean log loss on the validation data across iterations (lower = better) -- **`std_validation_log_loss`**: Standard deviation of the validation log loss -- **`mean_permuted_log_loss`**: Mean log loss on the permuted data across iterations (lower = better) -- **`std_permuted_log_loss`**: Standard deviation of the permuted log loss -- **`proportion_positive`**: Fraction of iterations where importance > 0 (feature helps prediction) -- **`proportion_negative`**: Fraction of iterations where importance < 0 (feature hurts prediction) -- **`p_value`**: Empirical p-value for statistical significance (< 0.05 typically considered significant) -- **`interpretation`**: Human-readable summary of the results - -#### Advanced Usage - -```python -from sklearn.svm import SVC - -# With confounders and custom classifier -stats = feature_importance( - data=data, - x="age", - y="disease", - confounders=[("bmi", 25), ("sex", "female")], # Control for these variables - clf=SVC(probability=True), # Use SVM instead of logistic regression - mode="jackknife", # Use jackknife instead of bootstrap - iterations=200 # More iterations for precision -) - -print(f"P-value: {stats['p_value']:.4f}") -print(f"95% CI: [{stats['importance_95ci_low']:.3f}, {stats['importance_95ci_high']:.3f}]") -``` - -#### Interpretation Guidelines - -- **Strong Association**: `p_value < 0.01`, `proportion_positive > 95%` -- **Moderate Association**: `p_value < 0.05`, `proportion_positive > 80%` -- **Weak/No Association**: `p_value > 0.05`, confidence interval includes zero -- **Negative Association**: `proportion_negative > proportion_positive` (unusual but possible) - - +For details on permutation-based feature importance analysis, see [Feature Importance Analysis](https://github.com/raeslab/lorepy/blob/main/docs/feature_importance.md). ## Development diff --git a/docs/feature_importance.md b/docs/feature_importance.md new file mode 100644 index 0000000..db45f75 --- /dev/null +++ b/docs/feature_importance.md @@ -0,0 +1,67 @@ +# Feature Importance Analysis + +Lorepy provides statistical assessment of how strongly your x-feature is associated with the class distribution using the `feature_importance` function. This uses **permutation-based feature importance** with **log loss (cross-entropy)** as the scoring metric to test whether the relationship you see in your loreplot is statistically significant. Log loss evaluates the full predicted probability distribution rather than just hard class predictions, making it well-suited for lorepy's probability-based visualizations. + +## How it Works + +The function uses a robust resampling approach combined with sklearn's optimized permutation importance: + +1. **Bootstrap/Jackknife Sampling**: Creates multiple subsamples of your data (default: 100 iterations) +2. **Permutation Importance**: For each subsample, uses sklearn's `permutation_importance` with proper cross-validation to avoid data leakage +3. **Feature Shuffling**: Randomly permutes the x-feature values while keeping confounders intact +4. **Performance Assessment**: Measures log loss increase using statistically sound train/test splits +5. **Statistical Summary**: Aggregates results across all iterations to provide confidence intervals and significance testing + +This approach works with **any sklearn classifier** (LogisticRegression, SVM, RandomForest, etc.) and properly handles confounders by keeping them constant during shuffling. The implementation uses sklearn's battle-tested permutation importance algorithm for reliable, unbiased results. + +```python +from lorepy import feature_importance + +# Basic usage +stats = feature_importance(data=iris_df, x="sepal width (cm)", y="species", iterations=100) +print(stats['interpretation']) +# Output: "Feature importance: 0.2019 ± 0.0433. Positive in 100.0% of iterations (p=0.0000)" +``` + +## Understanding the Output + +The function returns a dictionary with the following key statistics: + +- **`mean_importance`**: Average log loss increase when x-feature is shuffled (higher = more important) +- **`std_importance`**: Standard deviation of importance across iterations +- **`importance_95ci_low/high`**: 95% confidence interval for the importance estimate +- **`mean_validation_log_loss`**: Mean log loss on the validation data across iterations (lower = better) +- **`std_validation_log_loss`**: Standard deviation of the validation log loss +- **`mean_permuted_log_loss`**: Mean log loss on the permuted data across iterations (lower = better) +- **`std_permuted_log_loss`**: Standard deviation of the permuted log loss +- **`proportion_positive`**: Fraction of iterations where importance > 0 (feature helps prediction) +- **`proportion_negative`**: Fraction of iterations where importance < 0 (feature hurts prediction) +- **`p_value`**: Empirical p-value for statistical significance (< 0.05 typically considered significant) +- **`interpretation`**: Human-readable summary of the results + +## Advanced Usage + +```python +from sklearn.svm import SVC + +# With confounders and custom classifier +stats = feature_importance( + data=data, + x="age", + y="disease", + confounders=[("bmi", 25), ("sex", "female")], # Control for these variables + clf=SVC(probability=True), # Use SVM instead of logistic regression + mode="jackknife", # Use jackknife instead of bootstrap + iterations=200 # More iterations for precision +) + +print(f"P-value: {stats['p_value']:.4f}") +print(f"95% CI: [{stats['importance_95ci_low']:.3f}, {stats['importance_95ci_high']:.3f}]") +``` + +## Interpretation Guidelines + +- **Strong Association**: `p_value < 0.01`, `proportion_positive > 95%` +- **Moderate Association**: `p_value < 0.05`, `proportion_positive > 80%` +- **Weak/No Association**: `p_value > 0.05`, confidence interval includes zero +- **Negative Association**: `proportion_negative > proportion_positive` (unusual but possible) diff --git a/docs/how_lorepy_works.md b/docs/how_lorepy_works.md new file mode 100644 index 0000000..ee8d04e --- /dev/null +++ b/docs/how_lorepy_works.md @@ -0,0 +1,24 @@ +# How lorepy works + +Under the hood, lorepy fits scikit-learn's default `LogisticRegression()` model (L2 regularization, `C=1.0`) unless you +pass a custom classifier. For a multiclass outcome with *K* classes, the fitted model can be written as coefficient +vectors **β**\_k and intercepts β\_{k0}, with class probabilities: + +*P(Y=k | x) = exp(**β**\_k · x + β\_{k0}) / Σ\_j exp(**β**\_j · x + β\_{j0})* + +These probabilities sum to one at each *x*, so they can be drawn directly as a stacked area chart. Lorepy evaluates +the fitted model at 200 evenly spaced points across the observed range of the x-feature, yielding smooth curves for the +estimated class composition as a function of *x*. If confounders are provided, they are included during fitting and +then fixed to user-specified reference values at prediction time, so the displayed curves are **conditional on those +reference values** (not marginal averages over the confounder distribution). Sample dots are generated by drawing each +point's *y*-coordinate uniformly within the predicted probability interval of its true class, which visualizes class +membership while preserving the stacked-probability interpretation. + +Concretely, the height of each colored band at a given *x*-value represents the model's estimated proportion of that +class: a band spanning 60% of the y-axis means the model estimates that class accounts for 60% of observations at that +point along the feature. As *x* increases, bands that widen indicate classes with a growing estimated proportion, while +narrowing bands indicate classes becoming rarer. A class that dominates the plot across the full range has a +consistently high estimated proportion regardless of the feature value, whereas a sharp crossover between two bands +pinpoints where one class overtakes another. Because the bands always sum to one, the plot naturally encodes a zero-sum +trade-off: one class can only grow in estimated proportion at the expense of others, making it straightforward to read +both absolute and relative shifts directly from the visualization. From bf5f7d33c89594891a395047503f062949e45e11 Mon Sep 17 00:00:00 2001 From: Sebastian Date: Wed, 25 Feb 2026 09:36:58 +0000 Subject: [PATCH 10/11] renamed jackknife to random_subsampling (more accurate) --- README.md | 2 +- docs/feature_importance.md | 4 +-- example_uncertainty.py | 8 ++--- src/lorepy/uncertainty.py | 58 +++++++++++++++---------------- tests/test_uncertainty.py | 70 +++++++++++++++++++------------------- 5 files changed, 71 insertions(+), 71 deletions(-) diff --git a/README.md b/README.md index aa4acad..0165f20 100644 --- a/README.md +++ b/README.md @@ -188,7 +188,7 @@ plt.show() From loreplots it isn't possible to assess how certain we are of the prevalence of each group across the range. To provide a view into this there is a function ```uncertainty_plot```, which can be used as shown below. This will use -```resampling``` (or ```jackknifing```) to determine the 50% and 95% interval of predicted values and show these in a +```resampling``` (or ```random subsampling```) to determine the 50% and 95% interval of predicted values and show these in a multi-panel plot with one plot per category. ```python diff --git a/docs/feature_importance.md b/docs/feature_importance.md index db45f75..e27ab78 100644 --- a/docs/feature_importance.md +++ b/docs/feature_importance.md @@ -6,7 +6,7 @@ Lorepy provides statistical assessment of how strongly your x-feature is associa The function uses a robust resampling approach combined with sklearn's optimized permutation importance: -1. **Bootstrap/Jackknife Sampling**: Creates multiple subsamples of your data (default: 100 iterations) +1. **Bootstrap/Random Subsampling**: Creates multiple subsamples of your data (default: 100 iterations) 2. **Permutation Importance**: For each subsample, uses sklearn's `permutation_importance` with proper cross-validation to avoid data leakage 3. **Feature Shuffling**: Randomly permutes the x-feature values while keeping confounders intact 4. **Performance Assessment**: Measures log loss increase using statistically sound train/test splits @@ -51,7 +51,7 @@ stats = feature_importance( y="disease", confounders=[("bmi", 25), ("sex", "female")], # Control for these variables clf=SVC(probability=True), # Use SVM instead of logistic regression - mode="jackknife", # Use jackknife instead of bootstrap + mode="random_subsampling", # Use random subsampling instead of bootstrap iterations=200 # More iterations for precision ) diff --git a/example_uncertainty.py b/example_uncertainty.py index 49cc362..28b8ef9 100644 --- a/example_uncertainty.py +++ b/example_uncertainty.py @@ -23,20 +23,20 @@ print(stats) stats = feature_importance( - data=iris_df, x="sepal width (cm)", y="species", iterations=100, mode="jackknife" + data=iris_df, x="sepal width (cm)", y="species", iterations=100, mode="random_subsampling" ) print(stats) -# Using jackknife instead of resample to assess uncertainty +# Using random subsampling instead of resample to assess uncertainty uncertainty_plot( data=iris_df, x="sepal width (cm)", y="species", iterations=100, - jackknife_fraction=0.8, + subsampling_fraction=0.8, ) -plt.savefig("./docs/img/uncertainty_jackknife.png", dpi=150) +plt.savefig("./docs/img/uncertainty_random_subsampling.png", dpi=150) plt.show() # Uncertainty plot with custom colors diff --git a/src/lorepy/uncertainty.py b/src/lorepy/uncertainty.py index 65f6565..8f95755 100644 --- a/src/lorepy/uncertainty.py +++ b/src/lorepy/uncertainty.py @@ -18,21 +18,21 @@ def _get_uncertainty_data( y_reg, x_range, mode="resample", - jackknife_fraction: float = 0.8, + subsampling_fraction: float = 0.8, iterations: int = 100, confounders=None, clf=None, ): """ - Estimates uncertainty in model predictions using resampling or jackknife methods. + Estimates uncertainty in model predictions using resampling or random subsampling methods. :param x: Name of the feature variable to analyze. :param X_reg: Feature matrix for regression/classification. :param y_reg: Target variable. :param x_range: Tuple (min, max) specifying the range of values for the feature variable `x` to evaluate. - :param mode: Method for uncertainty estimation. Either "resample" (bootstrap) or "jackknife". - :param jackknife_fraction: Fraction of data to keep in each jackknife iteration (only used if mode="jackknife"). - :param iterations: Number of resampling or jackknife iterations. + :param mode: Method for uncertainty estimation. Either "resample" (bootstrap) or "random_subsampling". + :param subsampling_fraction: Fraction of data to keep in each random subsampling iteration (only used if mode="random_subsampling"). + :param iterations: Number of resampling or random subsampling iterations. :param confounders: List of tuples (feature, reference value) pairs representing confounder features and their reference values. :param clf: Classifier to use for fitting. If None, uses LogisticRegression. :return: Tuple containing output DataFrame with aggregated uncertainty statistics and long_df DataFrame with all resampled predictions. @@ -41,15 +41,15 @@ def _get_uncertainty_data( areas = [] for i in range(iterations): - if mode == "jackknife": + if mode == "random_subsampling": X_keep, _, y_keep, _ = train_test_split( - X_reg, y_reg, train_size=jackknife_fraction + X_reg, y_reg, train_size=subsampling_fraction ) elif mode == "resample": X_keep, y_keep = resample(X_reg, y_reg, replace=True) else: raise NotImplementedError( - f"Mode {mode} is unsupported, only jackknife and resample are valid modes" + f"Mode {mode} is unsupported, only random_subsampling and resample are valid modes" ) lg = LogisticRegression() if clf is None else clf @@ -83,25 +83,25 @@ def _get_feature_importance( x: str, X_reg, y_reg, - mode="jackknife", - jackknife_fraction: float = 0.8, + mode="random_subsampling", + subsampling_fraction: float = 0.8, resample_validation_fraction: float = 0.2, iterations: int = 100, clf=None, ): """ Estimates the importance of the x-feature in predicting class labels using permutation-based - feature importance with resampling or jackknife methods. Uses log loss (cross-entropy) as the + feature importance with resampling or random subsampling methods. Uses log loss (cross-entropy) as the performance metric, which evaluates the full predicted probability distribution rather than just hard class predictions. :param x: Name of the feature variable to analyze for importance. :param X_reg: Feature matrix for regression/classification. :param y_reg: Target variable. - :param mode: Method for uncertainty estimation. Either "resample" (bootstrap) or "jackknife". - :param jackknife_fraction: Fraction of data to keep in each jackknife iteration (only used if mode="jackknife"). + :param mode: Method for uncertainty estimation. Either "resample" (bootstrap) or "random_subsampling". + :param subsampling_fraction: Fraction of data to keep in each random subsampling iteration (only used if mode="random_subsampling"). :param resample_validation_fraction: Fraction of data to use for validation in resampling mode (only used if mode="resample"). - :param iterations: Number of resampling or jackknife iterations. + :param iterations: Number of resampling or random subsampling iterations. :param clf: Classifier to use for fitting. If None, uses LogisticRegression. :return: Dictionary containing feature importance statistics including mean importance, confidence intervals, validation/permuted log loss statistics, and significance metrics. """ @@ -111,9 +111,9 @@ def _get_feature_importance( permuted_log_losses = [] for i in range(iterations): - if mode == "jackknife": + if mode == "random_subsampling": X_keep, X_val, y_keep, y_val = train_test_split( - X_reg, y_reg, train_size=jackknife_fraction + X_reg, y_reg, train_size=subsampling_fraction ) elif mode == "resample": X_keep, X_val, y_keep, y_val = train_test_split( @@ -122,7 +122,7 @@ def _get_feature_importance( X_keep, y_keep = resample(X_keep, y_keep, replace=True) else: raise NotImplementedError( - f"Mode {mode} is unsupported, only jackknife and resample are valid modes" + f"Mode {mode} is unsupported, only random_subsampling and resample are valid modes" ) # Check for small validation sets that may affect statistical reliability @@ -130,7 +130,7 @@ def _get_feature_importance( warnings.warn( f"The validation set is small (n={len(y_val)}). " f"Small validation sets may lead to unreliable importance estimates. " - f"Consider decreasing jackknife_fraction (currently {jackknife_fraction}), increasing resample_validation_fraction (currently {resample_validation_fraction}), " + f"Consider decreasing subsampling_fraction (currently {subsampling_fraction}), increasing resample_validation_fraction (currently {resample_validation_fraction}), " f"or using a larger dataset for more stable results.", UserWarning, stacklevel=3, @@ -216,7 +216,7 @@ def uncertainty_plot( y: str, x_range=None, mode="resample", - jackknife_fraction=0.8, + subsampling_fraction=0.8, iterations=100, confounders=None, colormap=None, @@ -231,9 +231,9 @@ def uncertainty_plot( :param x: Needs to be a numerical feature :param y: Categorical feature :param x_range: Either None (range will be selected automatically) or a tuple with min and max value for the x-axis - :param mode: Sampling method, either "resample" (bootstrap) or "jackknife" (default = "resample") - :param jackknife_fraction: Fraction of data to retain for each jackknife sample (default = 0.8) - :param iterations: Number of iterations for resampling or jackknife (default = 100) + :param mode: Sampling method, either "resample" (bootstrap) or "random_subsampling" (default = "resample") + :param subsampling_fraction: Fraction of data to retain for each random subsampling iteration (default = 0.8) + :param iterations: Number of iterations for resampling or random subsampling (default = 100) :param confounders: List of tuples with the feature and reference value e.g., [("BMI", 25)] will use a reference of 25 for plots :param colormap: Colormap to use for the plot, default is None in which case matplotlib's default will be used :param clf: Provide a different scikit-learn classifier for the function. Should implement the predict_proba() and fit(). If None a LogisticRegression will be used. @@ -253,7 +253,7 @@ def uncertainty_plot( y_reg, x_range, mode=mode, - jackknife_fraction=jackknife_fraction, + subsampling_fraction=subsampling_fraction, iterations=iterations, confounders=confounders, clf=clf, @@ -296,7 +296,7 @@ def feature_importance( x: str, y: str, mode="resample", - jackknife_fraction=0.8, + subsampling_fraction=0.8, resample_validation_fraction=0.2, iterations=100, confounders=None, @@ -304,7 +304,7 @@ def feature_importance( ): """ Estimates the importance of a feature in predicting class labels using permutation-based - feature importance with resampling or jackknife methods. Uses log loss (cross-entropy) as the + feature importance with resampling or random subsampling methods. Uses log loss (cross-entropy) as the performance metric, which evaluates the full predicted probability distribution rather than just hard class predictions. @@ -315,10 +315,10 @@ def feature_importance( :param data: The input dataframe containing all features and target variable. :param x: The name of the feature to analyze for importance. :param y: The name of the target variable. - :param mode: Method for uncertainty estimation. Either "resample" (bootstrap) or "jackknife". - :param jackknife_fraction: Fraction of data to keep in each jackknife iteration (only used if mode="jackknife"). + :param mode: Method for uncertainty estimation. Either "resample" (bootstrap) or "random_subsampling". + :param subsampling_fraction: Fraction of data to keep in each random subsampling iteration (only used if mode="random_subsampling"). :param resample_validation_fraction: Fraction of data to use for validation in resampling mode (only used if mode="resample"). - :param iterations: Number of resampling or jackknife iterations. + :param iterations: Number of resampling or random subsampling iterations. :param confounders: List of tuples (feature, reference value) pairs representing confounder features and their reference values. :param clf: Classifier to use for fitting. If None, uses LogisticRegression. :return: Dictionary containing feature importance statistics including mean importance, confidence intervals, validation/permuted log loss statistics, and significance metrics. @@ -347,7 +347,7 @@ def feature_importance( X_reg=X_reg, y_reg=y_reg, mode=mode, - jackknife_fraction=jackknife_fraction, + subsampling_fraction=subsampling_fraction, resample_validation_fraction=resample_validation_fraction, iterations=iterations, clf=clf, diff --git a/tests/test_uncertainty.py b/tests/test_uncertainty.py index be1ab85..e6ec893 100644 --- a/tests/test_uncertainty.py +++ b/tests/test_uncertainty.py @@ -2,7 +2,7 @@ Comprehensive tests for lorepy.uncertainty module. Tests cover: -- _get_uncertainty_data: uncertainty estimation via resampling/jackknife +- _get_uncertainty_data: uncertainty estimation via resampling/random subsampling - _get_feature_importance: feature importance calculation - uncertainty_plot: main uncertainty visualization function - feature_importance: public API for feature importance @@ -50,8 +50,8 @@ def test_basic_output_structure_resample(self, binary_sample_data): # Check long_df structure assert isinstance(long_df, pd.DataFrame) - def test_basic_output_structure_jackknife(self, binary_sample_data): - """Test basic output structure with jackknife mode.""" + def test_basic_output_structure_random_subsampling(self, binary_sample_data): + """Test basic output structure with random_subsampling mode.""" X_reg, y_reg, x_range = _prepare_data(binary_sample_data, "x", "y", []) output, long_df = _get_uncertainty_data( @@ -59,8 +59,8 @@ def test_basic_output_structure_jackknife(self, binary_sample_data): X_reg, y_reg, x_range, - mode="jackknife", - jackknife_fraction=0.8, + mode="random_subsampling", + subsampling_fraction=0.8, iterations=10, ) @@ -253,15 +253,15 @@ def test_mode_resample(self, binary_sample_data): assert result["mode"] == "resample" - def test_mode_jackknife(self, binary_sample_data): - """Test jackknife mode.""" + def test_mode_random_subsampling(self, binary_sample_data): + """Test random_subsampling mode.""" X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) result = _get_feature_importance( - "x", X_reg, y_reg, mode="jackknife", iterations=10 + "x", X_reg, y_reg, mode="random_subsampling", iterations=10 ) - assert result["mode"] == "jackknife" + assert result["mode"] == "random_subsampling" def test_invalid_mode_raises_error(self, binary_sample_data): """Test that invalid mode raises NotImplementedError.""" @@ -316,17 +316,17 @@ def test_custom_classifier(self, binary_sample_data): assert isinstance(result["mean_importance"], float) - def test_jackknife_fraction_parameter(self, binary_sample_data): - """Test that jackknife_fraction parameter is used.""" + def test_subsampling_fraction_parameter(self, binary_sample_data): + """Test that subsampling_fraction parameter is used.""" X_reg, y_reg, _ = _prepare_data(binary_sample_data, "x", "y", []) # Different fractions should produce different results result_80 = _get_feature_importance( - "x", X_reg, y_reg, mode="jackknife", jackknife_fraction=0.8, iterations=10 + "x", X_reg, y_reg, mode="random_subsampling", subsampling_fraction=0.8, iterations=10 ) result_50 = _get_feature_importance( - "x", X_reg, y_reg, mode="jackknife", jackknife_fraction=0.5, iterations=10 + "x", X_reg, y_reg, mode="random_subsampling", subsampling_fraction=0.5, iterations=10 ) # Both should produce valid results @@ -387,10 +387,10 @@ def test_custom_x_range(self, binary_sample_data): assert xlim == custom_range plt.close() - def test_jackknife_mode(self, binary_sample_data): - """Test jackknife mode.""" + def test_random_subsampling_mode(self, binary_sample_data): + """Test random_subsampling mode.""" fig, axs = uncertainty_plot( - binary_sample_data, "x", "y", mode="jackknife", iterations=10 + binary_sample_data, "x", "y", mode="random_subsampling", iterations=10 ) assert len(axs) == 2 @@ -508,16 +508,16 @@ def test_with_custom_classifier(self, binary_sample_data): assert isinstance(result["mean_importance"], float) - def test_jackknife_mode(self, binary_sample_data): - """Test feature_importance with jackknife mode.""" + def test_random_subsampling_mode(self, binary_sample_data): + """Test feature_importance with random_subsampling mode.""" result = feature_importance( - binary_sample_data, x="x", y="y", mode="jackknife", iterations=10 + binary_sample_data, x="x", y="y", mode="random_subsampling", iterations=10 ) - assert result["mode"] == "jackknife" + assert result["mode"] == "random_subsampling" def test_small_validation_set_warning(self): - """Test warning for small validation sets in jackknife mode.""" + """Test warning for small validation sets in random_subsampling mode.""" # Use balanced classes to ensure both classes appear in validation splits, # which is required for log_loss scoring small_data = pd.DataFrame( @@ -532,12 +532,12 @@ def test_small_validation_set_warning(self): small_data, x="x", y="y", - mode="jackknife", - jackknife_fraction=0.8, + mode="random_subsampling", + subsampling_fraction=0.8, iterations=5, ) - assert result["mode"] == "jackknife" + assert result["mode"] == "random_subsampling" with pytest.warns(UserWarning, match="The validation set is small"): result = feature_importance( @@ -569,11 +569,11 @@ def test_no_warning_adequate_validation(self): large_data, x="x", y="y", - mode="jackknife", - jackknife_fraction=0.8, + mode="random_subsampling", + subsampling_fraction=0.8, iterations=5, ) - assert result["mode"] == "jackknife" + assert result["mode"] == "random_subsampling" except UserWarning as e: if "small" in str(e).lower(): pytest.fail("Unexpected small validation warning") @@ -623,28 +623,28 @@ def test_very_few_iterations(self, binary_sample_data): assert len(axs) == 2 plt.close() - def test_high_jackknife_fraction(self, binary_sample_data): - """Test with high jackknife fraction.""" + def test_high_subsampling_fraction(self, binary_sample_data): + """Test with high subsampling fraction.""" fig, axs = uncertainty_plot( binary_sample_data, "x", "y", - mode="jackknife", - jackknife_fraction=0.95, + mode="random_subsampling", + subsampling_fraction=0.95, iterations=10, ) assert len(axs) == 2 plt.close() - def test_low_jackknife_fraction(self, binary_sample_data): - """Test with low jackknife fraction.""" + def test_low_subsampling_fraction(self, binary_sample_data): + """Test with low subsampling fraction.""" fig, axs = uncertainty_plot( binary_sample_data, "x", "y", - mode="jackknife", - jackknife_fraction=0.5, + mode="random_subsampling", + subsampling_fraction=0.5, iterations=10, ) From 1beef310128c9e77d7ed6515f4ba38b5b7f02fcd Mon Sep 17 00:00:00 2001 From: autoblack_push <${GITHUB_ACTOR}@users.noreply.github.com> Date: Wed, 25 Feb 2026 09:37:18 +0000 Subject: [PATCH 11/11] fixup! Format Python code with psf/black push --- example_uncertainty.py | 6 +++++- tests/test_uncertainty.py | 14 ++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/example_uncertainty.py b/example_uncertainty.py index 28b8ef9..b6a4342 100644 --- a/example_uncertainty.py +++ b/example_uncertainty.py @@ -23,7 +23,11 @@ print(stats) stats = feature_importance( - data=iris_df, x="sepal width (cm)", y="species", iterations=100, mode="random_subsampling" + data=iris_df, + x="sepal width (cm)", + y="species", + iterations=100, + mode="random_subsampling", ) print(stats) diff --git a/tests/test_uncertainty.py b/tests/test_uncertainty.py index e6ec893..2139b0f 100644 --- a/tests/test_uncertainty.py +++ b/tests/test_uncertainty.py @@ -322,11 +322,21 @@ def test_subsampling_fraction_parameter(self, binary_sample_data): # Different fractions should produce different results result_80 = _get_feature_importance( - "x", X_reg, y_reg, mode="random_subsampling", subsampling_fraction=0.8, iterations=10 + "x", + X_reg, + y_reg, + mode="random_subsampling", + subsampling_fraction=0.8, + iterations=10, ) result_50 = _get_feature_importance( - "x", X_reg, y_reg, mode="random_subsampling", subsampling_fraction=0.5, iterations=10 + "x", + X_reg, + y_reg, + mode="random_subsampling", + subsampling_fraction=0.5, + iterations=10, ) # Both should produce valid results