diff --git a/src/main/java/io/github/yasmramos/mindforge/interpret/LIME.java b/src/main/java/io/github/yasmramos/mindforge/interpret/LIME.java
new file mode 100644
index 0000000..797fbb2
--- /dev/null
+++ b/src/main/java/io/github/yasmramos/mindforge/interpret/LIME.java
@@ -0,0 +1,245 @@
+package io.github.yasmramos.mindforge.interpret;
+
+import java.io.Serializable;
+import java.util.*;
+import java.util.function.Function;
+
+/**
+ * LIME (Local Interpretable Model-agnostic Explanations).
+ * Explains individual predictions by approximating the model locally with an interpretable model.
+ */
+public class LIME implements Serializable {
+    private static final long serialVersionUID = 1L;
+
+    private final int nSamples;
+    private final double kernelWidth;
+    private final long seed;
+    private double[] featureStd;
+
+    private LIME(Builder builder) {
+        this.nSamples = builder.nSamples;
+        this.kernelWidth = builder.kernelWidth;
+        this.seed = builder.seed;
+    }
+
+    /**
+     * Set feature standard deviations for perturbation scaling.
+     * @param std Standard deviation for each feature
+     */
+    public void setFeatureStd(double[] std) {
+        this.featureStd = std;
+    }
+
+    /**
+     * Calculate feature standard deviations from data.
+     * @param data Training data
+     */
+    public void fitFeatureStd(double[][] data) {
+        int nFeatures = data[0].length;
+        featureStd = new double[nFeatures];
+
+        for (int j = 0; j < nFeatures; j++) {
+            double mean = 0;
+            for (double[] row : data) {
+                mean += row[j];
+            }
+            mean /= data.length;
+
+            double variance = 0;
+            for (double[] row : data) {
+                variance += (row[j] - mean) * (row[j] - mean);
+            }
+            featureStd[j] = Math.sqrt(variance / data.length);
+            // Guard against zero-variance features: fall back to unit scale.
+            if (featureStd[j] < 1e-10) featureStd[j] = 1.0;
+        }
+    }
+
+    /**
+     * Explain a prediction using LIME.
+     * @param predict Prediction function
+     * @param instance Instance to explain
+     * @return Explanation containing feature weights
+     */
+    public Explanation explain(Function<double[], Double> predict, double[] instance) {
+        if (featureStd == null) {
+            throw new IllegalStateException("Feature standard deviations must be set. Call fitFeatureStd() first.");
+        }
+
+        int nFeatures = instance.length;
+        Random random = new Random(seed);
+
+        double[][] perturbedData = new double[nSamples][nFeatures];
+        double[] predictions = new double[nSamples];
+        double[] weights = new double[nSamples];
+
+        // Generate perturbed samples
+        for (int i = 0; i < nSamples; i++) {
+            double distance = 0;
+            for (int j = 0; j < nFeatures; j++) {
+                double perturbation = random.nextGaussian() * featureStd[j];
+                perturbedData[i][j] = instance[j] + perturbation;
+                distance += (perturbation / featureStd[j]) * (perturbation / featureStd[j]);
+            }
+
+            predictions[i] = predict.apply(perturbedData[i]);
+
+            // Exponential kernel weight
+            weights[i] = Math.exp(-distance / (2 * kernelWidth * kernelWidth));
+        }
+
+        // Fit weighted linear regression
+        double[] coefficients = fitWeightedLinearRegression(perturbedData, predictions, weights, instance);
+
+        // Calculate intercept
+        double intercept = predict.apply(instance);
+        for (int j = 0; j < nFeatures; j++) {
+            intercept -= coefficients[j] * instance[j];
+        }
+
+        return new Explanation(coefficients, intercept, predict.apply(instance));
+    }
+
+    private double[] fitWeightedLinearRegression(double[][] X, double[] y, double[] weights, double[] center) {
+        int n = X.length;
+        int p = X[0].length;
+
+        // Center the features
+        double[][] Xc = new double[n][p];
+        for (int i = 0; i < n; i++) {
+            for (int j = 0; j < p; j++) {
+                Xc[i][j] = X[i][j] - center[j];
+            }
+        }
+
+        // Weighted mean of y
+        double yMean = 0, wSum = 0;
+        for (int i = 0; i < n; i++) {
+            yMean += weights[i] * y[i];
+            wSum += weights[i];
+        }
+        yMean /= wSum;
+
+        // Weighted least squares
+        double[][] XtWX = new double[p][p];
+        double[] XtWy = new double[p];
+
+        for (int i = 0; i < p; i++) {
+            for (int j = 0; j < p; j++) {
+                double sum = 0;
+                for (int k = 0; k < n; k++) {
+                    sum += Xc[k][i] * weights[k] * Xc[k][j];
+                }
+                XtWX[i][j] = sum;
+            }
+
+            double sum = 0;
+            for (int k = 0; k < n; k++) {
+                sum += Xc[k][i] * weights[k] * (y[k] - yMean);
+            }
+            XtWy[i] = sum;
+        }
+
+        // Regularization
+        for (int i = 0; i < p; i++) {
+            XtWX[i][i] += 1e-6;
+        }
+
+        return solveLinearSystem(XtWX, XtWy);
+    }
+
+    private double[] solveLinearSystem(double[][] A, double[] b) {
+        int n = A.length;
+        double[][] augmented = new double[n][n + 1];
+
+        for (int i = 0; i < n; i++) {
+            System.arraycopy(A[i], 0, augmented[i], 0, n);
+            augmented[i][n] = b[i];
+        }
+
+        for (int i = 0; i < n; i++) {
+            int maxRow = i;
+            for (int k = i + 1; k < n; k++) {
+                if (Math.abs(augmented[k][i]) > Math.abs(augmented[maxRow][i])) {
+                    maxRow = k;
+                }
+            }
+            double[] temp = augmented[i];
+            augmented[i] = augmented[maxRow];
+            augmented[maxRow] = temp;
+
+            if (Math.abs(augmented[i][i]) < 1e-10) continue;
+
+            for (int k = i + 1; k < n; k++) {
+                double factor = augmented[k][i] / augmented[i][i];
+                for (int j = i; j <= n; j++) {
+                    augmented[k][j] -= factor * augmented[i][j];
+                }
+            }
+        }
+
+        double[] x = new double[n];
+        for (int i = n - 1; i >= 0; i--) {
+            x[i] = augmented[i][n];
+            for (int j = i + 1; j < n; j++) {
+                x[i] -= augmented[i][j] * x[j];
+            }
+            if (Math.abs(augmented[i][i]) > 1e-10) {
+                x[i] /= augmented[i][i];
+            }
+        }
+
+        return x;
+    }
+
+    /**
+     * Container for LIME explanation results.
+     */
+    public static class Explanation implements Serializable {
+        private static final long serialVersionUID = 1L;
+
+        private final double[] featureWeights;
+        private final double intercept;
+        private final double prediction;
+
+        public Explanation(double[] featureWeights, double intercept, double prediction) {
+            this.featureWeights = featureWeights;
+            this.intercept = intercept;
+            this.prediction = prediction;
+        }
+
+        public double[] getFeatureWeights() { return featureWeights.clone(); }
+        public double getIntercept() { return intercept; }
+        public double getPrediction() { return prediction; }
+
+        /**
+         * Get top contributing features.
+         * @param n Number of top features
+         * @return Indices of top features sorted by absolute weight
+         */
+        public int[] getTopFeatures(int n) {
+            Integer[] indices = new Integer[featureWeights.length];
+            for (int i = 0; i < indices.length; i++) indices[i] = i;
+
+            Arrays.sort(indices, (a, b) ->
+                Double.compare(Math.abs(featureWeights[b]), Math.abs(featureWeights[a])));
+
+            int[] result = new int[Math.min(n, indices.length)];
+            for (int i = 0; i < result.length; i++) {
+                result[i] = indices[i];
+            }
+            return result;
+        }
+    }
+
+    public static class Builder {
+        private int nSamples = 1000;
+        private double kernelWidth = 0.75;
+        private long seed = 42;
+
+        public Builder nSamples(int nSamples) { this.nSamples = nSamples; return this; }
+        public Builder kernelWidth(double width) { this.kernelWidth = width; return this; }
+        public Builder seed(long seed) { this.seed = seed; return this; }
+
+        public LIME build() { return new LIME(this); }
+    }
+}
diff --git a/src/main/java/io/github/yasmramos/mindforge/interpret/SHAP.java b/src/main/java/io/github/yasmramos/mindforge/interpret/SHAP.java
new file mode 100644
index 0000000..e9b79f6
--- /dev/null
+++ b/src/main/java/io/github/yasmramos/mindforge/interpret/SHAP.java
@@ -0,0 +1,265 @@
+package io.github.yasmramos.mindforge.interpret;
+
+import java.io.Serializable;
+import java.util.*;
+import
java.util.function.Function;
+
+/**
+ * SHAP (SHapley Additive exPlanations) for model interpretability.
+ * Computes feature importance based on Shapley values from cooperative game theory.
+ *
+ * Implements Kernel SHAP, a model-agnostic approximation method.
+ */
+public class SHAP implements Serializable {
+    private static final long serialVersionUID = 1L;
+
+    private final int nSamples;
+    private final long seed;
+    private double[][] backgroundData;
+    private double expectedValue;
+
+    private SHAP(Builder builder) {
+        this.nSamples = builder.nSamples;
+        this.seed = builder.seed;
+    }
+
+    /**
+     * Set background data for computing expected values.
+     * @param background Representative samples from training data
+     */
+    public void setBackground(double[][] background) {
+        this.backgroundData = background;
+    }
+
+    /**
+     * Compute SHAP values for a single instance.
+     * @param predict Prediction function that takes features and returns prediction
+     * @param instance The instance to explain
+     * @return SHAP values for each feature
+     */
+    public double[] explain(Function<double[], Double> predict, double[] instance) {
+        if (backgroundData == null || backgroundData.length == 0) {
+            throw new IllegalStateException("Background data must be set before explaining");
+        }
+
+        int nFeatures = instance.length;
+        Random random = new Random(seed);
+
+        // Calculate expected value using background data
+        expectedValue = 0;
+        for (double[] bg : backgroundData) {
+            expectedValue += predict.apply(bg);
+        }
+        expectedValue /= backgroundData.length;
+
+        // Kernel SHAP approximation using sampling
+        double[] shapValues = new double[nFeatures];
+        double[] weights = new double[nSamples];
+        double[][] coalitions = new double[nSamples][nFeatures];
+        double[] predictions = new double[nSamples];
+
+        for (int s = 0; s < nSamples; s++) {
+            // Sample a coalition (subset of features)
+            boolean[] coalition = sampleCoalition(nFeatures, random);
+            int coalitionSize = countTrue(coalition);
+
+            // Calculate Shapley kernel weight
+            weights[s] = shapleyKernelWeight(nFeatures, coalitionSize);
+
+            // Create masked instance
+            double[] maskedInstance = createMaskedInstance(instance, coalition, random);
+
+            // Store coalition as binary vector
+            for (int j = 0; j < nFeatures; j++) {
+                coalitions[s][j] = coalition[j] ? 1.0 : 0.0;
+            }
+
+            // Get prediction for masked instance
+            predictions[s] = predict.apply(maskedInstance);
+        }
+
+        // Solve weighted linear regression: predictions = coalitions * shapValues + expectedValue
+        shapValues = solveWeightedRegression(coalitions, predictions, weights, expectedValue);
+
+        return shapValues;
+    }
+
+    /**
+     * Compute SHAP values for multiple instances.
+     * @param predict Prediction function
+     * @param instances Instances to explain
+     * @return SHAP values matrix [n_instances, n_features]
+     */
+    public double[][] explainBatch(Function<double[], Double> predict, double[][] instances) {
+        double[][] shapValues = new double[instances.length][];
+        for (int i = 0; i < instances.length; i++) {
+            shapValues[i] = explain(predict, instances[i]);
+        }
+        return shapValues;
+    }
+
+    /**
+     * Compute mean absolute SHAP values (global feature importance).
+     * @param predict Prediction function
+     * @param instances Instances to analyze
+     * @return Mean absolute SHAP value for each feature
+     */
+    public double[] meanAbsoluteShap(Function<double[], Double> predict, double[][] instances) {
+        double[][] allShap = explainBatch(predict, instances);
+        int nFeatures = allShap[0].length;
+        double[] meanAbs = new double[nFeatures];
+
+        for (int j = 0; j < nFeatures; j++) {
+            double sum = 0;
+            for (double[] shap : allShap) {
+                sum += Math.abs(shap[j]);
+            }
+            meanAbs[j] = sum / allShap.length;
+        }
+
+        return meanAbs;
+    }
+
+    private boolean[] sampleCoalition(int nFeatures, Random random) {
+        boolean[] coalition = new boolean[nFeatures];
+        for (int i = 0; i < nFeatures; i++) {
+            coalition[i] = random.nextBoolean();
+        }
+        return coalition;
+    }
+
+    private int countTrue(boolean[] arr) {
+        int count = 0;
+        for (boolean b : arr) if (b) count++;
+        return count;
+    }
+
+    private double shapleyKernelWeight(int M, int s) {
+        if (s == 0 || s == M) {
+            return 1e6; // Large weight for empty and full coalitions
+        }
+        // Shapley kernel: (M-1) / (C(M,s) * s * (M-s))
+        double binomial = binomialCoefficient(M, s);
+        return (M - 1.0) / (binomial * s * (M - s));
+    }
+
+    private double binomialCoefficient(int n, int k) {
+        if (k > n - k) k = n - k;
+        double result = 1;
+        for (int i = 0; i < k; i++) {
+            result = result * (n - i) / (i + 1);
+        }
+        return result;
+    }
+
+    private double[] createMaskedInstance(double[] instance, boolean[] coalition, Random random) {
+        double[] masked = new double[instance.length];
+        int bgIdx = random.nextInt(backgroundData.length);
+
+        for (int i = 0; i < instance.length; i++) {
+            if (coalition[i]) {
+                masked[i] = instance[i];
+            } else {
+                masked[i] = backgroundData[bgIdx][i];
+            }
+        }
+        return masked;
+    }
+
+    private double[] solveWeightedRegression(double[][] X, double[] y, double[] weights, double intercept) {
+        int n = X.length;
+        int p = X[0].length;
+
+        // Adjust y for intercept
+        double[] yAdj = new double[n];
+        for (int i = 0; i < n; i++) {
+            yAdj[i] = y[i] - intercept;
+        }
+
+        // Weighted least squares: (X'WX)^-1 * X'Wy
+        double[][] XtWX = new double[p][p];
+        double[] XtWy = new double[p];
+
+        for (int i = 0; i < p; i++) {
+            for (int j = 0; j < p; j++) {
+                double sum = 0;
+                for (int k = 0; k < n; k++) {
+                    sum += X[k][i] * weights[k] * X[k][j];
+                }
+                XtWX[i][j] = sum;
+            }
+
+            double sum = 0;
+            for (int k = 0; k < n; k++) {
+                sum += X[k][i] * weights[k] * yAdj[k];
+            }
+            XtWy[i] = sum;
+        }
+
+        // Add regularization for numerical stability
+        for (int i = 0; i < p; i++) {
+            XtWX[i][i] += 1e-6;
+        }
+
+        // Solve using Gaussian elimination
+        return solveLinearSystem(XtWX, XtWy);
+    }
+
+    private double[] solveLinearSystem(double[][] A, double[] b) {
+        int n = A.length;
+        double[][] augmented = new double[n][n + 1];
+
+        for (int i = 0; i < n; i++) {
+            System.arraycopy(A[i], 0, augmented[i], 0, n);
+            augmented[i][n] = b[i];
+        }
+
+        // Forward elimination
+        for (int i = 0; i < n; i++) {
+            int maxRow = i;
+            for (int k = i + 1; k < n; k++) {
+                if (Math.abs(augmented[k][i]) > Math.abs(augmented[maxRow][i])) {
+                    maxRow = k;
+                }
+            }
+            double[] temp = augmented[i];
+            augmented[i] = augmented[maxRow];
+            augmented[maxRow] = temp;
+
+            if (Math.abs(augmented[i][i]) < 1e-10) continue;
+
+            for (int k = i + 1; k < n; k++) {
+                double factor = augmented[k][i] / augmented[i][i];
+                for (int j = i; j <= n; j++) {
+                    augmented[k][j] -= factor * augmented[i][j];
+                }
+            }
+        }
+
+        // Back substitution
+        double[] x = new double[n];
+        for (int i = n - 1; i >= 0; i--) {
+            x[i] = augmented[i][n];
+            for (int j = i + 1; j < n; j++) {
+                x[i] -= augmented[i][j] * x[j];
+            }
+            if (Math.abs(augmented[i][i]) > 1e-10) {
+                x[i] /= augmented[i][i];
+            }
+        }
+
+        return x;
+    }
+
+    public double getExpectedValue() { return expectedValue; }
+
+    public static class Builder {
+        private int nSamples = 100;
+        private long seed = 42;
+
+        public Builder nSamples(int nSamples) { this.nSamples = nSamples; return this; }
+        public Builder seed(long seed)
{ this.seed = seed; return this; }
+
+        public SHAP build() { return new SHAP(this); }
+    }
+}
diff --git a/src/main/java/io/github/yasmramos/mindforge/model_selection/BayesianOptimization.java b/src/main/java/io/github/yasmramos/mindforge/model_selection/BayesianOptimization.java
new file mode 100644
index 0000000..41b6ef4
--- /dev/null
+++ b/src/main/java/io/github/yasmramos/mindforge/model_selection/BayesianOptimization.java
@@ -0,0 +1,303 @@
+package io.github.yasmramos.mindforge.model_selection;
+
+import java.io.Serializable;
+import java.util.*;
+import java.util.function.Function;
+
+/**
+ * Bayesian Optimization for hyperparameter tuning.
+ * Uses Gaussian Process surrogate model with Expected Improvement acquisition function.
+ */
+public class BayesianOptimization implements Serializable {
+    private static final long serialVersionUID = 1L;
+
+    private final int nIterations;
+    private final int nInitialPoints;
+    private final double explorationWeight;
+    private final long seed;
+
+    private Map<String, double[]> parameterBounds;
+    private List<double[]> observedPoints;
+    private List<Double> observedValues;
+    private double[] bestParams;
+    private double bestValue;
+    private List<String> paramNames;
+
+    private BayesianOptimization(Builder builder) {
+        this.nIterations = builder.nIterations;
+        this.nInitialPoints = builder.nInitialPoints;
+        this.explorationWeight = builder.explorationWeight;
+        this.seed = builder.seed;
+    }
+
+    /**
+     * Set parameter search space.
+     * @param bounds Map of parameter name to [min, max] bounds
+     */
+    public void setParameterBounds(Map<String, double[]> bounds) {
+        this.parameterBounds = new LinkedHashMap<>(bounds);
+        this.paramNames = new ArrayList<>(bounds.keySet());
+    }
+
+    /**
+     * Run optimization.
+     * @param objectiveFunction Function that takes parameters and returns score (higher is better)
+     * @return Best parameters found
+     */
+    public Map<String, Double> optimize(Function<Map<String, Double>, Double> objectiveFunction) {
+        if (parameterBounds == null || parameterBounds.isEmpty()) {
+            throw new IllegalStateException("Parameter bounds must be set before optimization");
+        }
+
+        Random random = new Random(seed);
+        int nDims = parameterBounds.size();
+
+        observedPoints = new ArrayList<>();
+        observedValues = new ArrayList<>();
+        bestValue = Double.NEGATIVE_INFINITY;
+
+        // Initial random sampling
+        for (int i = 0; i < nInitialPoints; i++) {
+            double[] point = sampleRandomPoint(random);
+            double value = evaluatePoint(point, objectiveFunction);
+
+            observedPoints.add(point);
+            observedValues.add(value);
+
+            if (value > bestValue) {
+                bestValue = value;
+                bestParams = point.clone();
+            }
+        }
+
+        // Bayesian optimization loop
+        for (int iter = 0; iter < nIterations; iter++) {
+            double[] nextPoint = findNextPoint(random);
+            double value = evaluatePoint(nextPoint, objectiveFunction);
+
+            observedPoints.add(nextPoint);
+            observedValues.add(value);
+
+            if (value > bestValue) {
+                bestValue = value;
+                bestParams = nextPoint.clone();
+            }
+        }
+
+        return arrayToMap(bestParams);
+    }
+
+    private double[] sampleRandomPoint(Random random) {
+        double[] point = new double[paramNames.size()];
+        for (int i = 0; i < paramNames.size(); i++) {
+            double[] bounds = parameterBounds.get(paramNames.get(i));
+            point[i] = bounds[0] + random.nextDouble() * (bounds[1] - bounds[0]);
+        }
+        return point;
+    }
+
+    private double evaluatePoint(double[] point, Function<Map<String, Double>, Double> objective) {
+        return objective.apply(arrayToMap(point));
+    }
+
+    private Map<String, Double> arrayToMap(double[] point) {
+        Map<String, Double> params = new LinkedHashMap<>();
+        for (int i = 0; i < paramNames.size(); i++) {
+            params.put(paramNames.get(i), point[i]);
+        }
+        return params;
+    }
+
+    private double[] findNextPoint(Random random) {
+        int nCandidates = 1000;
+        double[] bestCandidate = null;
+        double bestAcquisition = Double.NEGATIVE_INFINITY;
+
+        for (int i = 0; i < nCandidates; i++) {
+            double[] candidate = sampleRandomPoint(random);
+            double acquisition = computeExpectedImprovement(candidate);
+
+            if (acquisition > bestAcquisition) {
+                bestAcquisition = acquisition;
+                bestCandidate = candidate;
+            }
+        }
+
+        return bestCandidate;
+    }
+
+    private double computeExpectedImprovement(double[] point) {
+        double[] prediction = predictGP(point);
+        double mean = prediction[0];
+        double std = prediction[1];
+
+        if (std < 1e-10) return 0;
+
+        double improvement = mean - bestValue - explorationWeight;
+        double z = improvement / std;
+
+        // Expected improvement: std * (z * Phi(z) + phi(z))
+        double phi = normalPDF(z);
+        double bigPhi = normalCDF(z);
+
+        return std * (z * bigPhi + phi);
+    }
+
+    private double[] predictGP(double[] point) {
+        int n = observedPoints.size();
+        if (n == 0) return new double[]{0, 1};
+
+        // Compute kernel matrix K and kernel vector k*
+        double[][] K = new double[n][n];
+        double[] kStar = new double[n];
+
+        double lengthScale = estimateLengthScale();
+
+        for (int i = 0; i < n; i++) {
+            for (int j = 0; j < n; j++) {
+                K[i][j] = rbfKernel(observedPoints.get(i), observedPoints.get(j), lengthScale);
+            }
+            K[i][i] += 1e-6; // Noise term for numerical stability
+            kStar[i] = rbfKernel(observedPoints.get(i), point, lengthScale);
+        }
+
+        // Solve K * alpha = y
+        double[] y = new double[n];
+        for (int i = 0; i < n; i++) {
+            y[i] = observedValues.get(i);
+        }
+
+        double[] alpha = solveLinearSystem(K, y);
+
+        // Mean prediction: k* . alpha
+        double mean = 0;
+        for (int i = 0; i < n; i++) {
+            mean += kStar[i] * alpha[i];
+        }
+
+        // Variance: k** - k* . K^-1 . k*
+        double kStarStar = rbfKernel(point, point, lengthScale);
+        double[] KInvKStar = solveLinearSystem(K, kStar);
+
+        double variance = kStarStar;
+        for (int i = 0; i < n; i++) {
+            variance -= kStar[i] * KInvKStar[i];
+        }
+        variance = Math.max(variance, 1e-10);
+
+        return new double[]{mean, Math.sqrt(variance)};
+    }
+
+    private double rbfKernel(double[] x1, double[] x2, double lengthScale) {
+        double dist = 0;
+        for (int i = 0; i < x1.length; i++) {
+            double diff = (x1[i] - x2[i]) / lengthScale;
+            dist += diff * diff;
+        }
+        return Math.exp(-0.5 * dist);
+    }
+
+    private double estimateLengthScale() {
+        if (observedPoints.size() < 2) return 1.0;
+
+        double sumDist = 0;
+        int count = 0;
+        for (int i = 0; i < observedPoints.size(); i++) {
+            for (int j = i + 1; j < observedPoints.size(); j++) {
+                double dist = 0;
+                for (int k = 0; k < paramNames.size(); k++) {
+                    double[] bounds = parameterBounds.get(paramNames.get(k));
+                    double range = bounds[1] - bounds[0];
+                    double diff = (observedPoints.get(i)[k] - observedPoints.get(j)[k]) / range;
+                    dist += diff * diff;
+                }
+                sumDist += Math.sqrt(dist);
+                count++;
+            }
+        }
+        return sumDist / count;
+    }
+
+    private double[] solveLinearSystem(double[][] A, double[] b) {
+        int n = A.length;
+        double[][] augmented = new double[n][n + 1];
+
+        for (int i = 0; i < n; i++) {
+            for (int j = 0; j < n; j++) {
+                augmented[i][j] = A[i][j];
+            }
+            augmented[i][n] = b[i];
+        }
+
+        for (int i = 0; i < n; i++) {
+            int maxRow = i;
+            for (int k = i + 1; k < n; k++) {
+                if (Math.abs(augmented[k][i]) > Math.abs(augmented[maxRow][i])) {
+                    maxRow = k;
+                }
+            }
+            double[] temp = augmented[i];
+            augmented[i] = augmented[maxRow];
+            augmented[maxRow] = temp;
+
+            if (Math.abs(augmented[i][i]) < 1e-10) continue;
+
+            for (int k = i + 1; k < n; k++) {
+                double factor = augmented[k][i] / augmented[i][i];
+                for (int j = i; j <= n; j++) {
+                    augmented[k][j] -= factor * augmented[i][j];
+                }
+            }
+        }
+
+        double[] x = new double[n];
+        for (int i = n - 1; i >= 0; i--) {
x[i] = augmented[i][n];
+            for (int j = i + 1; j < n; j++) {
+                x[i] -= augmented[i][j] * x[j];
+            }
+            if (Math.abs(augmented[i][i]) > 1e-10) {
+                x[i] /= augmented[i][i];
+            }
+        }
+
+        return x;
+    }
+
+    private double normalPDF(double x) {
+        return Math.exp(-0.5 * x * x) / Math.sqrt(2 * Math.PI);
+    }
+
+    private double normalCDF(double x) {
+        return 0.5 * (1 + erf(x / Math.sqrt(2)));
+    }
+
+    private double erf(double x) {
+        double t = 1.0 / (1.0 + 0.5 * Math.abs(x));
+        double tau = t * Math.exp(-x * x - 1.26551223
+            + t * (1.00002368 + t * (0.37409196 + t * (0.09678418
+            + t * (-0.18628806 + t * (0.27886807 + t * (-1.13520398
+            + t * (1.48851587 + t * (-0.82215223 + t * 0.17087277)))))))));
+        return x >= 0 ? 1 - tau : tau - 1;
+    }
+
+    // Getters
+    public double[] getBestParams() { return bestParams != null ? bestParams.clone() : null; }
+    public double getBestValue() { return bestValue; }
+    public Map<String, Double> getBestParamsAsMap() { return bestParams != null ? arrayToMap(bestParams) : null; }
+    public List<Double> getObservedValues() { return new ArrayList<>(observedValues); }
+
+    public static class Builder {
+        private int nIterations = 50;
+        private int nInitialPoints = 5;
+        private double explorationWeight = 0.01;
+        private long seed = 42;
+
+        public Builder nIterations(int n) { this.nIterations = n; return this; }
+        public Builder nInitialPoints(int n) { this.nInitialPoints = n; return this; }
+        public Builder explorationWeight(double w) { this.explorationWeight = w; return this; }
+        public Builder seed(long seed) { this.seed = seed; return this; }
+
+        public BayesianOptimization build() { return new BayesianOptimization(this); }
+    }
+}
diff --git a/src/test/java/io/github/yasmramos/mindforge/interpret/SHAPLIMETest.java b/src/test/java/io/github/yasmramos/mindforge/interpret/SHAPLIMETest.java
new file mode 100644
index 0000000..0d17c1e
--- /dev/null
+++ b/src/test/java/io/github/yasmramos/mindforge/interpret/SHAPLIMETest.java
@@ -0,0 +1,212 @@
+package io.github.yasmramos.mindforge.interpret;
+
+import org.junit.jupiter.api.Test;
+import static org.junit.jupiter.api.Assertions.*;
+
+import java.util.function.Function;
+
+/**
+ * Tests for SHAP and LIME interpretability classes.
+ */
+public class SHAPLIMETest {
+
+    // Simple linear model for testing: y = 2*x0 + 3*x1 + 1
+    private Function<double[], Double> linearModel = x -> 2.0 * x[0] + 3.0 * x[1] + 1.0;
+
+    @Test
+    public void testSHAPExplain() {
+        double[][] background = {
+            {0, 0}, {1, 0}, {0, 1}, {1, 1},
+            {0.5, 0.5}, {0.2, 0.8}, {0.8, 0.2}
+        };
+
+        SHAP shap = new SHAP.Builder()
+            .nSamples(200)
+            .seed(42)
+            .build();
+
+        shap.setBackground(background);
+
+        double[] instance = {1.0, 1.0};
+        double[] shapValues = shap.explain(linearModel, instance);
+
+        assertNotNull(shapValues);
+        assertEquals(2, shapValues.length);
+
+        // SHAP values should roughly correspond to feature contributions
+        // For linear model, x0 contributes ~2 and x1 contributes ~3
+        assertTrue(shapValues[1] > shapValues[0]); // x1 should have higher contribution
+    }
+
+    @Test
+    public void testSHAPExplainBatch() {
+        double[][] background = {{0, 0}, {1, 1}, {0.5, 0.5}};
+
+        SHAP shap = new SHAP.Builder()
+            .nSamples(100)
+            .build();
+
+        shap.setBackground(background);
+
+        double[][] instances = {{1.0, 0.0}, {0.0, 1.0}};
+        double[][] shapValues = shap.explainBatch(linearModel, instances);
+
+        assertNotNull(shapValues);
+        assertEquals(2, shapValues.length);
+        assertEquals(2, shapValues[0].length);
+    }
+
+    @Test
+    public void testSHAPMeanAbsolute() {
+        double[][] background = {{0, 0}, {1, 1}};
+
+        SHAP shap = new SHAP.Builder()
+            .nSamples(100)
+            .build();
+
+        shap.setBackground(background);
+
+        double[][] instances = {
+            {1.0, 0.5}, {0.5, 1.0}, {0.8, 0.2}
+        };
+
+        double[] meanAbs = shap.meanAbsoluteShap(linearModel, instances);
+
+        assertNotNull(meanAbs);
+        assertEquals(2, meanAbs.length);
+        assertTrue(meanAbs[0] >= 0);
+        assertTrue(meanAbs[1] >= 0);
+    }
+
+    @Test
+    public void testSHAPExpectedValue() {
+        double[][] background = {{0, 0}, {1, 1}, {2, 2}};
+
+        SHAP shap = new SHAP.Builder().build();
+        shap.setBackground(background);
+
+        shap.explain(linearModel, new double[]{1, 1});
+
+        // Expected value should be the average prediction on background
+        // (1 + 6 + 11) / 3 = 6
+        assertEquals(6.0, shap.getExpectedValue(), 0.001);
+    }
+
+    @Test
+    public void testSHAPWithoutBackground() {
+        SHAP shap = new SHAP.Builder().build();
+
+        assertThrows(IllegalStateException.class, () ->
+            shap.explain(linearModel, new double[]{1, 1}));
+    }
+
+    @Test
+    public void testLIMEExplain() {
+        double[][] trainingData = {
+            {0, 0}, {1, 0}, {0, 1}, {1, 1},
+            {0.5, 0.5}, {0.2, 0.8}, {0.8, 0.2}
+        };
+
+        LIME lime = new LIME.Builder()
+            .nSamples(500)
+            .kernelWidth(0.5)
+            .seed(42)
+            .build();
+
+        lime.fitFeatureStd(trainingData);
+
+        double[] instance = {0.5, 0.5};
+        LIME.Explanation explanation = lime.explain(linearModel, instance);
+
+        assertNotNull(explanation);
+
+        double[] weights = explanation.getFeatureWeights();
+        assertEquals(2, weights.length);
+
+        // For linear model, LIME should recover approximately the true coefficients
+        // x0 coefficient ~2, x1 coefficient ~3
+        assertTrue(weights[1] > weights[0]);
+    }
+
+    @Test
+    public void testLIMEExplanationPrediction() {
+        double[][] trainingData = {{0, 0}, {1, 1}};
+
+        LIME lime = new LIME.Builder()
+            .nSamples(100)
+            .build();
+
+        lime.fitFeatureStd(trainingData);
+
+        double[] instance = {1.0, 1.0};
+        LIME.Explanation explanation = lime.explain(linearModel, instance);
+
+        // Prediction should match the model output
+        assertEquals(linearModel.apply(instance), explanation.getPrediction(), 0.001);
+    }
+
+    @Test
+    public void testLIMETopFeatures() {
+        double[][] trainingData = {{0, 0, 0}, {1, 1, 1}};
+
+        // Model where x2 has highest contribution
+        Function<double[], Double> model = x -> x[0] + 2*x[1] + 5*x[2];
+
+        LIME lime = new LIME.Builder()
+            .nSamples(500)
+            .build();
+
+        lime.fitFeatureStd(trainingData);
+
+        double[] instance = {1.0, 1.0, 1.0};
+        LIME.Explanation explanation = lime.explain(model, instance);
+
+        int[] topFeatures = explanation.getTopFeatures(2);
+
+        assertEquals(2, topFeatures.length);
+        // x2 (index 2) should be the most important
+        assertEquals(2, topFeatures[0]);
+    }
+
+    @Test
+    public void testLIMEWithoutFeatureStd() {
+        LIME lime = new LIME.Builder().build();
+
+        assertThrows(IllegalStateException.class, () ->
+            lime.explain(linearModel, new double[]{1, 1}));
+    }
+
+    @Test
+    public void testLIMESetFeatureStd() {
+        LIME lime = new LIME.Builder()
+            .nSamples(100)
+            .build();
+
+        lime.setFeatureStd(new double[]{1.0, 1.0});
+
+        // Should not throw
+        LIME.Explanation explanation = lime.explain(linearModel, new double[]{0.5, 0.5});
+        assertNotNull(explanation);
+    }
+
+    @Test
+    public void testLIMEIntercept() {
+        double[][] trainingData = {{0, 0}, {1, 1}};
+
+        LIME lime = new LIME.Builder()
+            .nSamples(200)
+            .build();
+
+        lime.fitFeatureStd(trainingData);
+
+        double[] instance = {0.5, 0.5};
+        LIME.Explanation explanation = lime.explain(linearModel, instance);
+
+        // Check that intercept + weights*instance ≈ prediction
+        double[] weights = explanation.getFeatureWeights();
+        double reconstructed = explanation.getIntercept()
+            + weights[0] * instance[0] + weights[1] * instance[1];
+
+        assertEquals(explanation.getPrediction(), reconstructed, 0.5);
+    }
+}
diff --git a/src/test/java/io/github/yasmramos/mindforge/model_selection/BayesianOptimizationTest.java b/src/test/java/io/github/yasmramos/mindforge/model_selection/BayesianOptimizationTest.java
new file mode 100644
index 0000000..ed97b58
--- /dev/null
+++ b/src/test/java/io/github/yasmramos/mindforge/model_selection/BayesianOptimizationTest.java
@@ -0,0 +1,200 @@
+package io.github.yasmramos.mindforge.model_selection;
+
+import org.junit.jupiter.api.Test;
+import static org.junit.jupiter.api.Assertions.*;
+
+import java.util.*;
+import java.util.function.Function;
+
+/**
+ * Tests for Bayesian Optimization.
*/
+public class BayesianOptimizationTest {
+
+    @Test
+    public void testOptimizeSimpleFunction() {
+        // Optimize a simple quadratic: -(x-3)^2 + 10
+        // Maximum at x=3 with value 10
+        Function<Map<String, Double>, Double> objective = params -> {
+            double x = params.get("x");
+            return -(x - 3) * (x - 3) + 10;
+        };
+
+        BayesianOptimization bo = new BayesianOptimization.Builder()
+            .nIterations(20)
+            .nInitialPoints(5)
+            .seed(42)
+            .build();
+
+        Map<String, double[]> bounds = new HashMap<>();
+        bounds.put("x", new double[]{0, 6});
+        bo.setParameterBounds(bounds);
+
+        Map<String, Double> bestParams = bo.optimize(objective);
+
+        assertNotNull(bestParams);
+        assertTrue(bestParams.containsKey("x"));
+
+        // Should find x close to 3
+        assertEquals(3.0, bestParams.get("x"), 0.5);
+
+        // Best value should be close to 10
+        assertTrue(bo.getBestValue() > 9);
+    }
+
+    @Test
+    public void testOptimizeTwoDimensional() {
+        // Optimize: -(x-2)^2 - (y-3)^2 + 20
+        // Maximum at (2, 3) with value 20
+        Function<Map<String, Double>, Double> objective = params -> {
+            double x = params.get("x");
+            double y = params.get("y");
+            return -(x - 2) * (x - 2) - (y - 3) * (y - 3) + 20;
+        };
+
+        BayesianOptimization bo = new BayesianOptimization.Builder()
+            .nIterations(30)
+            .nInitialPoints(10)
+            .seed(123)
+            .build();
+
+        Map<String, double[]> bounds = new LinkedHashMap<>();
+        bounds.put("x", new double[]{0, 5});
+        bounds.put("y", new double[]{0, 5});
+        bo.setParameterBounds(bounds);
+
+        Map<String, Double> bestParams = bo.optimize(objective);
+
+        assertNotNull(bestParams);
+        assertEquals(2, bestParams.size());
+
+        // Should be reasonably close to optimal
+        assertTrue(bo.getBestValue() > 18);
+    }
+
+    @Test
+    public void testGetBestParamsAsMap() {
+        Function<Map<String, Double>, Double> objective = params -> -params.get("x") * params.get("x");
+
+        BayesianOptimization bo = new BayesianOptimization.Builder()
+            .nIterations(10)
+            .nInitialPoints(3)
+            .build();
+
+        Map<String, double[]> bounds = new HashMap<>();
+        bounds.put("x", new double[]{-5, 5});
+        bo.setParameterBounds(bounds);
+
+        bo.optimize(objective);
+
+        Map<String, Double> bestMap = bo.getBestParamsAsMap();
+        assertNotNull(bestMap);
+        assertTrue(bestMap.containsKey("x"));
+    }
+
+    @Test
+    public void testGetObservedValues() {
+        Function<Map<String, Double>, Double> objective = params -> params.get("x");
+
+        BayesianOptimization bo = new BayesianOptimization.Builder()
+            .nIterations(5)
+            .nInitialPoints(3)
+            .build();
+
+        Map<String, double[]> bounds = new HashMap<>();
+        bounds.put("x", new double[]{0, 10});
+        bo.setParameterBounds(bounds);
+
+        bo.optimize(objective);
+
+        List<Double> observed = bo.getObservedValues();
+        assertNotNull(observed);
+        assertEquals(8, observed.size()); // 3 initial + 5 iterations
+    }
+
+    @Test
+    public void testWithoutBounds() {
+        BayesianOptimization bo = new BayesianOptimization.Builder().build();
+
+        assertThrows(IllegalStateException.class, () ->
+            bo.optimize(params -> 0.0));
+    }
+
+    @Test
+    public void testExplorationWeight() {
+        Function<Map<String, Double>, Double> objective = params -> {
+            double x = params.get("x");
+            return -x * x; // Maximum at x=0
+        };
+
+        // High exploration weight
+        BayesianOptimization bo = new BayesianOptimization.Builder()
+            .nIterations(10)
+            .nInitialPoints(3)
+            .explorationWeight(0.5)
+            .seed(42)
+            .build();
+
+        Map<String, double[]> bounds = new HashMap<>();
+        bounds.put("x", new double[]{-5, 5});
+        bo.setParameterBounds(bounds);
+
+        bo.optimize(objective);
+
+        // Should still find the optimum
+        assertTrue(bo.getBestValue() > -1);
+    }
+
+    @Test
+    public void testMultipleParameters() {
+        Function<Map<String, Double>, Double> objective = params -> {
+            double a = params.get("a");
+            double b = params.get("b");
+            double c = params.get("c");
+            return -(a - 1) * (a - 1) - (b - 2) * (b - 2) - (c - 3) * (c - 3);
+        };
+
+        BayesianOptimization bo = new BayesianOptimization.Builder()
+            .nIterations(40)
+            .nInitialPoints(10)
+            .seed(42)
+            .build();
+
+        Map<String, double[]> bounds = new LinkedHashMap<>();
+        bounds.put("a", new double[]{0, 5});
+        bounds.put("b", new double[]{0, 5});
+        bounds.put("c", new double[]{0, 5});
+        bo.setParameterBounds(bounds);
+
+        Map<String, Double> best = bo.optimize(objective);
+
+        assertEquals(3, best.size());
+        assertTrue(bo.getBestValue() > -3); // Reasonably close to 0
+    }
+
+    @Test
+    public void testNoisyObjective() {
+        Random random = new Random(42);
+
+        Function<Map<String, Double>, Double> objective = params -> {
+            double x = params.get("x");
+            return -(x - 5) * (x - 5) + random.nextGaussian() * 0.1;
+        };
+
+        BayesianOptimization bo = new BayesianOptimization.Builder()
+            .nIterations(20)
+            .nInitialPoints(5)
+            .seed(42)
+            .build();
+
+        Map<String, double[]> bounds = new HashMap<>();
+        bounds.put("x", new double[]{0, 10});
+        bo.setParameterBounds(bounds);
+
+        bo.optimize(objective);
+
+        // Should find approximately x=5 despite noise
+        double bestX = bo.getBestParamsAsMap().get("x");
+        assertEquals(5.0, bestX, 2.0);
+    }
+}