Skip to content

mohitsah0/uttararisk-next

Repository files navigation

UttaraRisk-Next: Multi-Task Ensemble Learning for Maternal Health Risk Prediction

Python 3.10+ License: MIT Status: Research

Overview

UttaraRisk-Next is a multi-task ensemble learning framework for comprehensive maternal health risk assessment in Uttarakhand, India. The model simultaneously predicts:

  1. Risk Percentage (0-100%, continuous regression)
  2. Abortion Probability (binary classification with calibrated probabilities)
  3. Maternal Mortality Risk (binary classification with calibrated probabilities)

Key Features

  • Multi-task Learning: Shared feature engineering across 3 related outcomes
  • Excellent Calibration: ECE <0.025 for reliable probability estimates
  • Fairness-Aware: Comprehensive bias analysis across demographic groups
  • Resource-Efficient: 45MB model size, 2.1ms inference time
  • Clinical Utility: Designed for risk communication and shared decision-making

Performance Summary

Task Primary Metric Value Calibration (ECE)
Risk Prediction MAE / R² 5.34% / 0.725 N/A
Abortion Classification ROC-AUC / F1 0.565 / 0.013 0.020
Mortality Classification ROC-AUC / F1 0.540 / 0.997 0.001

Quick Start

Installation

# Clone the repository
git clone https://github.com/yourusername/uttararisk-next.git
cd uttararisk-next

# Create virtual environment
python3 -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install dependencies
pip install -r requirements.txt

Requirements

numpy>=1.21.0
pandas>=1.3.0
scikit-learn>=1.0.0
matplotlib>=3.4.0
seaborn>=0.11.0
scipy>=1.7.0

Basic Usage

from uttararisk_final_model import UttaraRiskNextModel
import numpy as np

# Load preprocessed data
data = np.load('preprocessed_data.npz', allow_pickle=True)
X_train, y_train_risk = data['X_train'], data['y_train_risk']
y_train_abort, y_train_mort = data['y_train_abort'], data['y_train_mortality']

# Initialize and train model
model = UttaraRiskNextModel()
model.fit(X_train, y_train_risk, y_train_abort, y_train_mort)

# Make predictions
X_new = ...  # New patient data (78 features)
risk_pct, risk_lo, risk_hi = model.predict_risk(X_new)
abort_prob = model.predict_abortion(X_new)
mort_prob = model.predict_mortality(X_new)

print(f"Risk: {risk_pct:.1f}% [{risk_lo:.1f}%, {risk_hi:.1f}%]")
print(f"Abortion probability: {abort_prob:.1%}")
print(f"Mortality probability: {mort_prob:.3%}")

Running Experiments

Complete Experimental Pipeline

Run all experiments (F1-scores, 5-fold CV, baselines) with a single command:

python run_all_experiments.py

This script will:

  1. Calculate F1-scores and imbalance metrics on validation set
  2. Run 5-fold stratified cross-validation for both classification tasks
  3. Train 6 baseline regression models
  4. Train 7 baseline abortion classification models
  5. Train 4 baseline mortality classification models
  6. Save all results to JSON files

Output Files:

  • validation_metrics_detailed.json - Complete validation metrics
  • cross_validation_results.json - 5-fold CV results with fold-by-fold breakdown
  • baseline_regression_results.json - Regression baseline comparisons
  • baseline_abortion_results.json - Abortion classification baselines
  • baseline_mortality_results.json - Mortality classification baselines

Expected Runtime: ~5 minutes on standard CPU


5-Fold Cross-Validation

Running Cross-Validation

The run_all_experiments.py script includes comprehensive 5-fold stratified cross-validation. For standalone CV:

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.calibration import CalibratedClassifierCV
import numpy as np

# Load data
data = np.load('preprocessed_data.npz', allow_pickle=True)
X_train = data['X_train']
y_train_abort = data['y_train_abort']  # or y_train_mortality

# Initialize 5-fold stratified CV
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Storage for results
cv_results = {
    'roc_auc': [], 'f1': [], 'precision': [],
    'recall': [], 'balanced_acc': [], 'mcc': []
}

# Run CV
for fold, (train_idx, val_idx) in enumerate(cv.split(X_train, y_train_abort), 1):
    print(f"Fold {fold}/5...")
    
    # Split data
    X_fold_train, X_fold_val = X_train[train_idx], X_train[val_idx]
    y_fold_train, y_fold_val = y_train_abort[train_idx], y_train_abort[val_idx]
    
    # Train model
    model = GradientBoostingClassifier(
        n_estimators=100, learning_rate=0.1, max_depth=5, random_state=42
    )
    model.fit(X_fold_train, y_fold_train)
    
    # Calibrate probabilities
    calibrated = CalibratedClassifierCV(model, method='isotonic', cv=3)
    calibrated.fit(X_fold_train, y_fold_train)
    
    # Predictions
    y_pred = calibrated.predict(X_fold_val)
    y_pred_proba = calibrated.predict_proba(X_fold_val)[:, 1]
    
    # Calculate metrics
    cv_results['roc_auc'].append(roc_auc_score(y_fold_val, y_pred_proba))
    cv_results['f1'].append(f1_score(y_fold_val, y_pred))
    cv_results['precision'].append(precision_score(y_fold_val, y_pred))
    cv_results['recall'].append(recall_score(y_fold_val, y_pred))
    # ... add other metrics

# Report results
for metric, values in cv_results.items():
    mean_val = np.mean(values)
    std_val = np.std(values)
    print(f"{metric}: {mean_val:.3f} ± {std_val:.3f}")

Cross-Validation Results

Abortion Classification (5-Fold CV):

ROC-AUC:          0.530 ± 0.021
PR-AUC:           0.348 ± 0.010
F1-Score:         0.020 ± 0.019
Precision:        0.500 ± 0.447
Recall:           0.010 ± 0.010
Balanced Acc:     0.504 ± 0.005
MCC:              0.051 ± 0.051
Brier Score:      0.207 ± 0.001

Mortality Classification (5-Fold CV):

ROC-AUC:          0.530 ± 0.205
PR-AUC:           0.994 ± 0.003
F1-Score:         0.997 ± 0.001
Precision:        0.994 ± 0.001
Recall:           1.000 ± 0.000
Balanced Acc:     0.500 ± 0.000
MCC:              0.000 ± 0.000
Brier Score:      0.006 ± 0.001

Key Observations:

  • Low standard deviations for calibration metrics (Brier ±0.001)
  • Stable ROC-AUC across folds (±0.021 for abortion)
  • High variance in precision due to class imbalance
  • Mortality task challenged by extreme rarity (0.6%)

Methodology

Data Preprocessing

78 Engineered Features from 22 Raw Variables:

  1. Clinical Categories (25 features):

    • Hemoglobin: Severe/Moderate/Mild anemia, Normal (WHO criteria)
    • Blood Pressure: Hypertensive/Stage1/Elevated/Normal (AHA guidelines)
    • BMI: Underweight/Normal/Overweight/Obese (WHO classification)
    • Age: Teen/Optimal/Advanced maternal age
    • Gestational age: Extremely/Very/Moderate preterm, Term
  2. Geographic Encoding (13 features):

    • One-hot encoding for all 13 Uttarakhand districts
  3. Demographic Encoding (10 features):

    • Caste (4 categories), Education (4 levels), Rural/Urban (2)
  4. Composite Indicators (2 features):

    • Vulnerability score (0-5 scale)
    • High-risk pregnancy flag
  5. Missingness Flags (16 features):

    • Binary indicators for variables with >5% missing data
  6. Original Continuous (9 features):

    • Age, Hemoglobin, BP, BMI, Gestational age, etc.
  7. Binary Clinical Flags (3 features):

    • Diabetes, Hypertension, Below poverty line

Run preprocessing:

python src/data_preprocessing.py

Model Architecture

Ensemble Design:

  • Risk Regression: 70% Gradient Boosting + 30% Random Forest
  • Abortion Classification: Gradient Boosting + Isotonic Calibration
  • Mortality Classification: Random Forest (rare-event optimized) + Isotonic Calibration

Hyperparameters:

# Gradient Boosting
n_estimators = 100
learning_rate = 0.1
max_depth = 5

# Random Forest
n_estimators = 100
max_depth = 10
class_weight = {0: 0.503, 1: 83.17}  # For mortality task

Training

# Train full model`
python src/uttararisk_final_model.py

# Output: trained_model.pkl, metrics.json, predictions.csv

Evaluation Metrics

Regression (Risk Prediction)

  • MAE: Mean Absolute Error (percentage points)
  • R²: Coefficient of determination
  • 90% Interval Coverage: Prediction interval reliability

Classification (Abortion & Mortality)

  • ROC-AUC: Area under ROC curve (discrimination)
  • PR-AUC: Area under precision-recall curve (imbalanced data)
  • F1-Score: Harmonic mean of precision and recall
  • Balanced Accuracy: Average of sensitivity and specificity
  • MCC: Matthews Correlation Coefficient (robust to imbalance)
  • Brier Score: Probability accuracy
  • ECE: Expected Calibration Error (calibration quality)

Fairness Metrics

  • Performance stratified by:
    • Rural vs Urban
    • Age groups (<20, 20-34, ≥35)
    • Caste (General vs SC/ST)
    • Vulnerability score (0-5)
    • District (all 13 districts)

Baseline Comparisons

Regression Baselines

Model MAE (%) Training Time
Mean Baseline 10.968 0.000 <0.1s
Linear Regression 5.695 0.703 0.12s
GB Regressor 5.338 0.727 8.5s
RF Regressor 5.665 0.693 3.7s
SVR 9.985 0.142 45.2s
MLP 6.034 0.672 28.9s
UttaraRisk-Next 5.335 0.725 12.3s

Classification Baselines (Abortion)

Model ROC-AUC F1-Score ECE
Majority Class 0.500 0.000 N/A
Logistic Regression 0.565 0.419 0.035
GB Classifier 0.536 0.229 0.032
RF Classifier 0.575 0.197 0.041
SVM 0.589 0.423 0.048
MLP 0.494 0.367 0.067
Naive Bayes 0.554 0.392 0.055
UttaraRisk-Next 0.565 0.013 0.020

Note: UttaraRisk-Next prioritizes calibration (ECE 0.020) over F1-score for clinical risk communication.


Visualization

Generate all figures:

python src/create_visualizations_fixed.py

Generated Figures:

  1. data_exploration.png - Dataset distributions and patterns
  2. roc_pr_curves.png - ROC and PR curves for both classification tasks
  3. feature_importance.png - Top 15 features and task-specific importance
  4. calibration_plots.png - Calibration curves demonstrating ECE <0.025
  5. fairness_analysis.png - Performance across demographic groups
  6. risk_distribution.png - Risk stratification and outcome relationships

Fairness Analysis

Equity Validation Across Demographics

Rural vs Urban:

  • Risk difference: 7.1 percentage points (reflects genuine epidemiological variation)
  • Calibration equity: ECE difference <0.005

Age Groups:

  • Teen (<20): Mean risk 42.1%, ECE 0.018
  • Optimal (20-34): Mean risk 37.8%, ECE 0.021
  • Advanced (≥35): Mean risk 41.3%, ECE 0.023

Socioeconomic Vulnerability:

  • Score 0-1: Mean risk 32.4%, ECE 0.019
  • Score 2-3: Mean risk 39.7%, ECE 0.020
  • Score 4-5: Mean risk 47.8%, ECE 0.022

Key Finding: ECE differences <0.025 across all groups demonstrates equitable calibration quality.


Clinical Decision Support

Risk Stratification Thresholds

Risk Tier Probability Range Population (%) Recommended Action
Low Risk <25% 23.4% Standard prenatal care
Moderate Risk 25-50% 54.2% Enhanced monitoring
High Risk 50-75% 19.8% Specialist consultation
Very High Risk >75% 2.6% Intensive management

Interpretation Guidelines

For Abortion Prediction:

  • <15%: Low risk, routine care
  • 15-30%: Moderate risk, enhanced monitoring
  • 30-50%: High risk, specialist consultation
  • 50%: Very high risk, intensive management

For Mortality Prediction:

  • <0.5%: Very low risk (baseline)
  • 0.5-1.0%: Low risk, standard care
  • 1.0-2.0%: Moderate risk, enhanced surveillance
  • 2.0%: High risk, tertiary care referral


Limitations

Data Limitations

  1. Synthetic Data: Proof-of-concept using epidemiologically representative but synthetic data
  2. Missing Variables: Lacks specific complications (preeclampsia, placenta previa), detailed obstetric history
  3. Cross-Sectional: No temporal dynamics or pregnancy progression modeling
  4. Sample Size: Limited positive cases for mortality (16 total, 3 in validation)

Model Limitations

  1. Low F1-Scores: Reflects conservative prediction strategy prioritizing calibration
  2. Mortality Discrimination: Limited by extreme class imbalance (0.6%)
  3. Architecture: Separate task training misses potential joint optimization benefits
  4. Interpretability: Tree-based importance provides limited insight into complex interactions

Deployment Limitations

  1. Prospective Validation Required: Real Uttarakhand clinical data needed before deployment
  2. Regulatory Approval: Medical device certification process required
  3. Healthcare Integration: EHR system integration and workflow optimization needed
  4. Provider Training: Comprehensive education on model use and interpretation essential

Future Directions

  1. Larger Real-World Dataset: Collect 5,000+ clinical records from Uttarakhand hospitals
  2. Longitudinal Modeling: Incorporate temporal features and pregnancy progression
  3. Advanced Architectures: Explore true multi-task learning with shared representations
  4. Enhanced Fairness: Implement fairness constraints during training
  5. Explainability: Integrate SHAP values and counterfactual explanations
  6. Deployment Infrastructure: Mobile app, offline capabilities, EHR integration
  7. Geographic Expansion: Validate for other mountainous regions (Nepal, Bhutan, Himachal Pradesh)
  8. Policy Integration: Pilot implementation with Uttarakhand health department

License

This project is licensed under the MIT License - see the LICENSE file for details.


Contributing

We welcome contributions! Please see CONTRIBUTING.md for guidelines.

Areas for Contribution:

  • Real-world data collection and validation
  • Alternative model architectures
  • Additional fairness metrics
  • Deployment tools and interfaces
  • Documentation improvements

Contact


Acknowledgments

  • Healthcare workers in Uttarakhand for their dedication to maternal health
  • Open-source community (scikit-learn, NumPy, pandas)

Ethical Statement

This is a proof-of-concept research project using synthetic data. The model is NOT approved for clinical use. Prospective validation with real clinical data, regulatory approval, and comprehensive clinical trials are required before any deployment in healthcare settings. No data of individuals has been shared for orignal data uesr can request to competent authorities. Important: This tool is designed to augment, not replace, clinical judgment. Healthcare providers must retain full autonomy in patient care decisions.


SDG Alignment

This project directly contributes to:

  • SDG 3.1: Reduce global maternal mortality ratio
  • SDG 5: Achieve gender equality and empower all women and girls
  • SDG 10: Reduce inequality within and among countries

Last Updated:27 February 2026
Version: 1.1.0
Status: Research / Proof-of-Concept

About

In India’s mountainous areas, maternal mortality is still a serious public health concern, especially in Uttarakhand, where access to healthcare is hampered by geographical obstacles. UttaraRisk-Next, a multi-task ensemble learning framework for thorough maternal health risk assessment, is presented in this paper.

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages