The WESAD (Wearable Stress and Affect Detection) Pipeline is a comprehensive machine learning framework designed to process and analyze physiological signals for stress detection. This project leverages multi-modal sensor data collected from wearable devices to extract meaningful features and build robust machine learning models for stress classification.
- Advanced Signal Processing: Extract time-domain, frequency-domain, and entropy features from physiological signals
- Multi-Modal Integration: Process data from both chest and wrist-worn devices, including ECG, EMG, EDA, respiration, temperature, and accelerometer data
- Parallel Processing: Optimize feature extraction with multi-threading for improved performance
- Flexible Model Selection: Choose from a variety of machine learning models through a simple registry system
- Cross-Validation: Evaluate models with rigorous k-fold cross-validation protocols
- Comprehensive Metrics: Generate and save detailed performance reports and visualizations
- Python 3.7+
- NumPy
- Pandas
- SciPy
- Scikit-learn
- Matplotlib
- Seaborn
- AntroPy (for entropy-based features)
- imbalanced-learn
- XGBoost
- SKTime/Aeon (for time series classifiers)
- Set up a virtual environment (recommended):

  ```bash
  python -m venv wesad-env
  source wesad-env/bin/activate  # On Windows: wesad-env\Scripts\activate
  ```

- Install the required packages:

  ```bash
  pip install numpy pandas scipy scikit-learn matplotlib seaborn antropy imbalanced-learn xgboost sktime aeon
  ```

```
wesad-pipeline/
├── features/
│   └── feature_extractor.py   # Signal feature extraction utilities
├── models/
│   └── model_registry.py      # Model registry for ease of use
├── utils/
│   ├── arguments.py           # Command-line argument parsing
│   ├── cross_validation.py    # Cross-validation implementation
│   ├── data_loader.py         # Data loading utilities
│   ├── evaluation.py          # Model evaluation and metrics
│   ├── feature_selector.py    # Feature selection utilities
│   └── logger.py              # Logging configuration
└── main.py                    # Main application entry point
```
To run the pipeline with default settings:

```bash
python main.py --data_dir /path/to/wesad/data --model_name RandomForest
```

The pipeline supports various command-line arguments to customize processing:
```bash
python main.py \
    --data_dir /path/to/wesad/data \
    --results_dir results_folder \
    --model_name RandomForest \
    --n_splits 5 \
    --feature_selection \
    --binary_classification
```

| Argument | Description | Default |
|---|---|---|
| `--data_dir` | Directory containing WESAD `.pkl` files | Required |
| `--results_dir` | Directory to save results | `results_YYYYMMDD_HHMMSS` |
| `--use_cache` | Path to cache file for preloaded data | None |
| `--drop_non_study` | Drop labels 0, 5, 6, 7 | False |
| `--shorten_non_study` | Shorten non-study labels to 0 | False |
| `--n_splits` | Number of CV folds | 5 |
| `--save_datasets` | Save train/test datasets per fold | False |
| `--model_name` | Model name in registry | `RandomForest` |
| `--feature_selection` | Enable feature selection | False |
| `--binary_classification` | Use binary classification | False |
| `--imputer` | Imputation strategy | `mean` |
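The actual flag definitions live in `utils/arguments.py`. As an illustration only (not the project's real parser), a few of the flags from the table could be declared with `argparse` like this; the `choices` list for `--imputer` is a hypothetical restriction:

```python
import argparse

def build_parser():
    p = argparse.ArgumentParser(description="WESAD pipeline (illustrative sketch)")
    p.add_argument("--data_dir", required=True,
                   help="Directory containing WESAD .pkl files")
    p.add_argument("--model_name", default="RandomForest",
                   help="Model name in registry")
    p.add_argument("--n_splits", type=int, default=5,
                   help="Number of CV folds")
    p.add_argument("--feature_selection", action="store_true",
                   help="Enable feature selection")
    p.add_argument("--imputer", default="mean",
                   choices=["mean", "median", "most_frequent"],
                   help="Imputation strategy")
    return p

args = build_parser().parse_args(
    ["--data_dir", "/data/wesad", "--n_splits", "3", "--feature_selection"]
)
```

Note that `action="store_true"` flags default to `False`, matching the defaults column above.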
The pipeline extracts a comprehensive set of features from physiological signals:
- Basic statistics (mean, std, max, min, skewness, kurtosis)
- Peak analysis (number of peaks, average peak distance)
- For ECG signals: HRV metrics (RMSSD, SDNN, pNN50)
- Band powers (low, mid, high frequency)
- FFT mean and entropy
- Sample entropy
- Higuchi fractal dimension
- Detrended fluctuation analysis
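To make the HRV metrics listed above concrete, here is a small self-contained sketch (not the pipeline's actual implementation) that computes RMSSD, SDNN, and pNN50 from a series of RR intervals given in milliseconds:

```python
import numpy as np

def hrv_metrics(rr_ms):
    """Basic HRV metrics from successive RR (inter-beat) intervals in ms."""
    rr = np.asarray(rr_ms, dtype=float)
    diffs = np.diff(rr)
    # RMSSD: root mean square of successive differences
    rmssd = np.sqrt(np.mean(diffs ** 2))
    # SDNN: sample standard deviation of the RR intervals
    sdnn = np.std(rr, ddof=1)
    # pNN50: percentage of successive differences exceeding 50 ms
    pnn50 = 100.0 * np.mean(np.abs(diffs) > 50.0)
    return {"RMSSD": rmssd, "SDNN": sdnn, "pNN50": pnn50}

rr = [800, 810, 790, 850, 795, 805]  # hypothetical RR series
metrics = hrv_metrics(rr)
```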
Example of feature extraction from a signal:

```python
# Extract features from an ECG signal sampled at 700 Hz
features = extract_features(ecg_signal, "chest_ECG", fs=700)

# Example output:
# {
#     'chest_ECG_mean': 0.153,
#     'chest_ECG_std': 0.423,
#     'chest_ECG_max': 1.287,
#     'chest_ECG_min': -0.857,
#     'chest_ECG_skew': 0.374,
#     'chest_ECG_kurtosis': 3.142,
#     'chest_ECG_power_low': 0.035,
#     'chest_ECG_power_mid': 0.021,
#     'chest_ECG_power_high': 0.012,
#     'chest_ECG_fft_mean': 0.023,
#     'chest_ECG_fft_entropy': 0.872,
#     'chest_ECG_num_peaks': 83,
#     'chest_ECG_avg_peak_dist': 8.45
# }
```

The pipeline follows these main steps:
- Data Loading: Load subject data from pickle files
- Window Generation: Create time windows for feature extraction
- Feature Extraction: Extract features from each signal and dimension
- Preprocessing: Handle missing values, scale features
- Feature Selection (optional): Select most important features
- Model Training: Train the selected classifier
- Evaluation: Compute and save performance metrics
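The window-generation step can be sketched as a simple sliding window over the raw samples. The function below is illustrative only; in the pipeline, the window and step sizes come from the per-signal configuration:

```python
def sliding_windows(signal, fs, window_s, step_s):
    """Yield fixed-length windows from a 1-D sample sequence.

    fs: sampling rate in Hz; window_s/step_s: window length and hop in seconds.
    """
    win_len = int(window_s * fs)
    step = int(step_s * fs)
    for start in range(0, len(signal) - win_len + 1, step):
        yield signal[start:start + win_len]

# Hypothetical example: 10 s of samples at 700 Hz, 1 s windows with 50% overlap
sig = list(range(700 * 10))
wins = list(sliding_windows(sig, fs=700, window_s=1.0, step_s=0.5))
```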
The following models are available through the model registry:
- RandomForest: Random Forest Classifier
- SVM: Support Vector Machine with RBF kernel
- KNN: k-Nearest Neighbors
- XGBoost: XGBoost Classifier
- CanonicalIntervalForest: Canonical Interval Forest (time series)
- TimeSeriesForest: Time Series Forest Classifier
- QUANT: QUANT Classifier (time series)
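A registry of zero-argument factories keeps model construction lazy: nothing is instantiated until a name is looked up. A minimal sketch of that pattern (the entry and its parameters are illustrative; the real registry lives in `models/model_registry.py`):

```python
from sklearn.ensemble import RandomForestClassifier

# Stand-in registry mapping names to zero-argument factories
MODEL_REGISTRY = {
    "RandomForest": lambda: RandomForestClassifier(n_estimators=100, random_state=42),
}

# A fresh, unfitted model is created only at lookup time
model = MODEL_REGISTRY["RandomForest"]()
```

Because each lookup returns a new instance, every cross-validation fold trains a fresh model with no state carried over.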
To add a new model, simply update the `MODEL_REGISTRY` dictionary in `models/model_registry.py`:

```python
MODEL_REGISTRY = {
    # ... existing models ...
    "MyNewModel": lambda: MyModelClass(param1=value1, param2=value2),
}
```

The pipeline uses k-fold cross-validation at the subject level, ensuring that data from the same subject doesn't appear in both training and testing sets:
```python
from sklearn.model_selection import KFold

# Example cross-validation logic: folds are built over subject files,
# so no subject contributes samples to both train and test sets
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
for fold, (train_idx, test_idx) in enumerate(kf.split(file_keys), start=1):
    train_files = [file_keys[i] for i in train_idx]
    test_files = [file_keys[i] for i in test_idx]
    # Process and train on this fold
    # ...
```

Here's an example workflow for a complete analysis:
- Prepare WESAD Dataset: Ensure you have the WESAD dataset as pickle files
- Run Full Cross-Validation:

  ```bash
  python main.py \
      --data_dir /path/to/wesad/data \
      --results_dir stress_analysis_results \
      --model_name RandomForest \
      --n_splits 5 \
      --feature_selection \
      --save_datasets
  ```

- Analyze Results: Examine confusion matrices and classification reports in the results directory
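The confusion matrices and classification reports written per fold are standard scikit-learn artifacts, so they can be reproduced or post-processed directly. A minimal sketch with hypothetical binary labels:

```python
from sklearn.metrics import confusion_matrix, classification_report

# Hypothetical fold predictions (0 = non-stress, 1 = stress)
y_true = [0, 1, 1, 0, 1, 0]
y_pred = [0, 1, 0, 0, 1, 1]

cm = confusion_matrix(y_true, y_pred)   # rows: true class, cols: predicted class
report = classification_report(y_true, y_pred)
print(cm)
print(report)
```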
To add new features, extend the `extract_signal_features` function in `features/feature_extractor.py`:

```python
def extract_my_new_feature(signal):
    # Implement your feature extraction logic
    return {"my_new_feature": calculated_value}

def extract_signal_features(win, signal_name, fs):
    features = {}
    # ... existing code ...
    # Add your new feature (the current window is named `win` here)
    my_features = extract_my_new_feature(win)
    features.update({f"{signal_name}_{k}": v for k, v in my_features.items()})
    return features
```

To support new signal types, update the `SIGNAL_CONFIG` dictionary in `utils/data_loader.py`:
```python
SIGNAL_CONFIG = {
    # ... existing signals ...
    ('new_device', 'NEW_SIGNAL'): {'fs': 256, 'window_s': 1.0},
}
```

If you encounter issues with missing values, try using a different imputation strategy:

```bash
python main.py --data_dir /path/to/data --imputer median
```

For large datasets, consider caching the preloaded data so it doesn't have to be reprocessed on every run:
```bash
python main.py --data_dir /path/to/data --use_cache cache_file.pkl
```

If you encounter issues with library dependencies, ensure you have compatible versions:

```bash
pip install -r requirements.txt  # Create this file with specific version numbers
```

Contributions to improve the WESAD pipeline are welcome! Please follow these steps:
- Fork the repository
- Create a feature branch
- Add your changes
- Submit a pull request
This project is licensed under the GNU GPLv3 License - see the LICENSE file for details.
- The WESAD dataset creators for providing multimodal physiological signals
- Contributors to the scientific libraries used in this project
For questions or support, please open an issue in the repository.
