A binary classification model built with PyTorch that predicts the likelihood of a stroke based on patient health data. The model is a feed-forward neural network trained on the Healthcare Dataset Stroke Data, which includes demographic and clinical features such as age, BMI, smoking status, and average glucose level.
- Overview
- Dataset
- Data Preprocessing
- Model Architecture
- Training
- Project Structure
- Getting Started
- Tech Stack
- License
Stroke is one of the leading causes of death and long-term disability worldwide. Early prediction based on patient risk factors can aid in preventive healthcare and timely medical intervention.
This project builds a neural network classifier that:
- Loads and preprocesses patient health records from a CSV dataset
- Encodes categorical features and scales numerical ones
- Trains a fully connected neural network for binary stroke prediction
- Evaluates performance using binary accuracy via `torchmetrics`
The dataset (healthcare-dataset-stroke-data.csv) contains patient records with the following features:
| Feature | Type | Description |
|---|---|---|
| `gender` | Categorical | Male, Female, or Other |
| `age` | Numerical | Age of the patient |
| `hypertension` | Binary | 0 = No, 1 = Yes |
| `heart_disease` | Binary | 0 = No, 1 = Yes |
| `ever_married` | Categorical | Yes or No |
| `work_type` | Categorical | Private, Self-employed, Govt_job, children, Never_worked |
| `Residence_type` | Categorical | Urban or Rural |
| `avg_glucose_level` | Numerical | Average blood glucose level |
| `bmi` | Numerical | Body Mass Index |
| `smoking_status` | Categorical | formerly smoked, never smoked, smokes, Unknown |
| `stroke` | Binary (Target) | 0 = No stroke, 1 = Stroke |
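The `id` column is not listed in the table, but it is present in the CSV and is dropped during preprocessing. Removing it and the `stroke` target from the 12 CSV columns leaves exactly the 10 input features the model expects. A quick check, assuming the CSV column order follows the table with `id` first:

```python
# Assumed CSV column order: `id` first, then the columns in the table above.
CSV_COLUMNS = [
    "id", "gender", "age", "hypertension", "heart_disease", "ever_married",
    "work_type", "Residence_type", "avg_glucose_level", "bmi",
    "smoking_status", "stroke",
]

TARGET = "stroke"
DROPPED = {"id", TARGET}  # identifier and target are not model inputs

# Every remaining column becomes one input feature.
input_features = [c for c in CSV_COLUMNS if c not in DROPPED]

print(len(input_features))  # 10
```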
The following preprocessing steps are applied in src/data_processing.py:
- Missing value imputation: missing `bmi` values are filled with the column mean
- Categorical encoding: Label Encoding is applied to `gender`, `ever_married`, `work_type`, `smoking_status`, and `Residence_type`
- Feature scaling: numerical features are manually scaled as `age / 100`, `avg_glucose_level / 1000`, and `bmi / 100`
- Feature/target separation: the `id` and `stroke` columns are dropped from the input features; `stroke` is used as the target variable
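A minimal pandas/scikit-learn sketch of these four steps follows. The function name `preprocess`, the choice of a fresh `LabelEncoder` per column, and the three sample rows are illustrative assumptions rather than code taken from `src/data_processing.py`; the sample values are made up.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

CATEGORICAL = ["gender", "ever_married", "work_type", "smoking_status", "Residence_type"]


def preprocess(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.Series]:
    """Hypothetical helper mirroring the listed preprocessing steps."""
    df = df.copy()
    # 1. Impute missing BMI values with the column mean.
    df["bmi"] = df["bmi"].fillna(df["bmi"].mean())
    # 2. Label-encode each categorical column (one encoder per column, assumed).
    for col in CATEGORICAL:
        df[col] = LabelEncoder().fit_transform(df[col])
    # 3. Manual scaling with fixed divisors.
    df["age"] = df["age"] / 100
    df["avg_glucose_level"] = df["avg_glucose_level"] / 1000
    df["bmi"] = df["bmi"] / 100
    # 4. Separate the input features from the target.
    X = df.drop(columns=["id", "stroke"])
    y = df["stroke"]
    return X, y


# Made-up rows for illustration only; row 1 has a missing BMI.
sample = pd.DataFrame({
    "id": [1, 2, 3],
    "gender": ["Male", "Female", "Female"],
    "age": [70.0, 55.0, 40.0],
    "hypertension": [1, 0, 0],
    "heart_disease": [0, 1, 0],
    "ever_married": ["Yes", "Yes", "No"],
    "work_type": ["Private", "Self-employed", "Govt_job"],
    "Residence_type": ["Urban", "Rural", "Urban"],
    "avg_glucose_level": [200.0, 150.0, 90.0],
    "bmi": [30.0, float("nan"), 25.0],
    "smoking_status": ["formerly smoked", "never smoked", "smokes"],
    "stroke": [1, 1, 0],
})

X, y = preprocess(sample)
print(X.shape)  # (3, 10)
```

The missing BMI in row 1 becomes the mean of 30.0 and 25.0 (27.5), which the scaling step turns into 0.275.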
A feed-forward neural network defined in src/model.py:
```
Input (10 features)
        |
Linear(10 → 32) → ReLU
        |
Linear(32 → 32) → ReLU
        |
Linear(32 → 1)
        |
Output (raw logit; sigmoid(logit) = stroke probability)
```
| Component | Details |
|---|---|
| Input features | 10 |
| Hidden layers | 2 (32 neurons each) |
| Activation | ReLU |
| Output | 1 neuron (binary classification) |
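The architecture could be written as the `nn.Module` below. The class name `StrokeNet` is hypothetical. The last layer returns a raw logit rather than a probability because `BCEWithLogitsLoss` applies the sigmoid internally, so `torch.sigmoid` has to be applied explicitly at inference time. The parameter count works out to (10·32 + 32) + (32·32 + 32) + (32·1 + 1) = 1,441.

```python
import torch
from torch import nn


class StrokeNet(nn.Module):
    """Hypothetical name; layer sizes follow the diagram above."""

    def __init__(self, in_features: int = 10, hidden: int = 32) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),  # raw logit; no sigmoid, BCEWithLogitsLoss adds it
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


torch.manual_seed(0)
model = StrokeNet()
x = torch.rand(4, 10)           # a batch of 4 random, already-preprocessed rows
logits = model(x)               # shape (4, 1)
probs = torch.sigmoid(logits)   # stroke probabilities in (0, 1)
preds = (probs >= 0.5).float()  # hard 0/1 predictions at a 0.5 threshold
n_params = sum(p.numel() for p in model.parameters())
print(logits.shape, n_params)   # torch.Size([4, 1]) 1441
```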
| Parameter | Value |
|---|---|
| Loss function | BCEWithLogitsLoss |
| Optimizer | SGD |
| Learning rate | 0.1 |
| Epochs | 10,000 |
| Train/test split | 80/20 (random_state=42) |
| Accuracy metric | torchmetrics.Accuracy (binary) |
| Device | CUDA (GPU) if available, otherwise CPU |
Training progress is printed every 1,000 epochs, showing train loss, train accuracy, test loss, and test accuracy.
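The training configuration can be sketched as the loop below. This is an illustration under stated assumptions rather than the code in `src/model.py`: it assumes full-batch gradient descent, trains on synthetic random data, runs 2,000 epochs with a log every 500 (instead of 10,000 and 1,000) so it finishes quickly, and computes accuracy by thresholding the sigmoid at 0.5 instead of calling `torchmetrics.Accuracy(task="binary")`.

```python
import numpy as np
import torch
from torch import nn
from sklearn.model_selection import train_test_split

torch.manual_seed(42)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical stand-in for the preprocessed data: the label depends only on
# feature 0, so the network has something easy to learn.
rng = np.random.default_rng(42)
X = rng.random((500, 10), dtype=np.float32)
y = (X[:, 0] > 0.5).astype(np.float32)

# 80/20 split with the same random_state as the table.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_test = torch.from_numpy(X_train).to(device), torch.from_numpy(X_test).to(device)
y_train, y_test = torch.from_numpy(y_train).to(device), torch.from_numpy(y_test).to(device)

model = nn.Sequential(
    nn.Linear(10, 32), nn.ReLU(),
    nn.Linear(32, 32), nn.ReLU(),
    nn.Linear(32, 1),
).to(device)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float:
    # Binary accuracy at a 0.5 probability threshold.
    preds = (torch.sigmoid(logits) >= 0.5).float()
    return (preds == targets).float().mean().item()


EPOCHS, LOG_EVERY = 2000, 500
history = []  # (train loss, test loss, test accuracy) at each logged epoch
for epoch in range(EPOCHS):
    model.train()
    logits = model(X_train).squeeze(1)  # (N, 1) -> (N,) to match the targets
    loss = loss_fn(logits, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % LOG_EVERY == 0 or epoch == EPOCHS - 1:
        model.eval()
        with torch.inference_mode():
            test_logits = model(X_test).squeeze(1)
            test_loss = loss_fn(test_logits, y_test).item()
        test_acc = accuracy(test_logits, y_test)
        history.append((loss.item(), test_loss, test_acc))
        print(f"epoch {epoch}: train loss {loss.item():.4f} | "
              f"train acc {accuracy(logits, y_train):.3f} | "
              f"test loss {test_loss:.4f} | test acc {test_acc:.3f}")
```

The `squeeze(1)` matters because `BCEWithLogitsLoss` expects the logits and targets to have the same shape.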
```
StrokePredictionModel/
├── data/
│   ├── __init__.py
│   └── raw/
│       └── healthcare-dataset-stroke-data.csv
├── src/
│   ├── __init__.py
│   ├── data_processing.py   # Data loading, encoding, scaling
│   └── model.py             # Neural network definition & training loop
├── main.py                  # Entry point (empty)
├── Pipfile                  # Pipenv dependency file
├── Pipfile.lock
├── .gitignore
└── README.md
```
- Python 3.13+
- Pipenv (recommended) or pip
- Clone the repository:

  ```bash
  git clone https://github.com/WiktoriaSmulska/StrokePredictionModel.git
  cd StrokePredictionModel
  ```

- Install dependencies with Pipenv:

  ```bash
  pipenv install
  pipenv shell
  ```

  Or with pip:

  ```bash
  pip install torch torchmetrics pandas numpy scikit-learn imbalanced-learn
  ```

- Run the model:

  ```bash
  python -m src.model
  ```
| Technology | Purpose |
|---|---|
| Python 3.13 | Programming language |
| PyTorch | Neural network framework |
| torchmetrics | Accuracy metric computation |
| pandas | Data loading and manipulation |
| NumPy | Numerical operations |
| scikit-learn | Label encoding, train/test split |
| imbalanced-learn | Handling class imbalance (SMOTE) |
| Pipenv | Dependency management |
This project is open source and available for academic and educational purposes.
Wiktoria Smulska — GitHub