WiktoriaSmulska/StrokePredictionModel

Stroke Prediction Neural Network

A binary classification model built with PyTorch that predicts the likelihood of a stroke based on patient health data. The model is a feed-forward neural network trained on the Healthcare Dataset Stroke Data, which includes demographic and clinical features such as age, BMI, smoking status, and average glucose level.


Overview

Stroke is one of the leading causes of death and long-term disability worldwide. Early prediction based on patient risk factors can aid in preventive healthcare and timely medical intervention.

This project builds a neural network classifier that:

  • Loads and preprocesses patient health records from a CSV dataset
  • Encodes categorical features and scales numerical ones
  • Trains a fully connected neural network for binary stroke prediction
  • Evaluates performance using binary accuracy via torchmetrics

Dataset

The dataset (healthcare-dataset-stroke-data.csv) contains patient records with the following features:

| Feature | Type | Description |
|---|---|---|
| gender | Categorical | Male, Female, or Other |
| age | Numerical | Age of the patient |
| hypertension | Binary | 0 = No, 1 = Yes |
| heart_disease | Binary | 0 = No, 1 = Yes |
| ever_married | Categorical | Yes or No |
| work_type | Categorical | Private, Self-employed, Govt_job, children, Never_worked |
| Residence_type | Categorical | Urban or Rural |
| avg_glucose_level | Numerical | Average blood glucose level |
| bmi | Numerical | Body Mass Index |
| smoking_status | Categorical | formerly smoked, never smoked, smokes, Unknown |
| stroke | Binary (target) | 0 = No stroke, 1 = Stroke |

Data Preprocessing

The following preprocessing steps are applied in src/data_processing.py:

  1. Missing value imputation: missing bmi values are filled with the column mean
  2. Categorical encoding: label encoding is applied to:
    • gender, ever_married, work_type, smoking_status, Residence_type
  3. Feature scaling: numerical features are manually scaled by fixed constants:
    • age / 100
    • avg_glucose_level / 1000
    • bmi / 100
  4. Feature/target separation: the id and stroke columns are dropped from the input features, and stroke is used as the target variable
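The four steps can be illustrated without pandas on a list of dict rows. This is a minimal, dependency-free sketch, not the code in src/data_processing.py (which uses pandas and scikit-learn): `preprocess` and the two `toy` records are hypothetical, missing bmi values are represented as None, and categories are coded in sorted order, which is how scikit-learn's LabelEncoder assigns its integer codes.

```python
from statistics import mean

CATEGORICAL = ["gender", "ever_married", "work_type", "smoking_status", "Residence_type"]
SCALE = {"age": 100, "avg_glucose_level": 1000, "bmi": 100}  # column -> fixed divisor


def preprocess(records):
    """Apply the four preprocessing steps; returns (feature_cols, X, y)."""
    rows = [dict(r) for r in records]  # shallow copies, so the input is left untouched

    # 1. Fill missing bmi values (None here) with the mean of the observed ones.
    bmi_mean = mean(r["bmi"] for r in rows if r["bmi"] is not None)
    for r in rows:
        if r["bmi"] is None:
            r["bmi"] = bmi_mean

    # 2. Label encoding: each column's sorted unique values map to 0, 1, 2, ...
    for col in CATEGORICAL:
        codes = {value: i for i, value in enumerate(sorted({r[col] for r in rows}))}
        for r in rows:
            r[col] = codes[r[col]]

    # 3. Manual scaling by fixed constants.
    for col, divisor in SCALE.items():
        for r in rows:
            r[col] = r[col] / divisor

    # 4. Drop id and stroke from the inputs; stroke becomes the target.
    feature_cols = [c for c in rows[0] if c not in ("id", "stroke")]
    X = [[r[c] for c in feature_cols] for r in rows]
    y = [r["stroke"] for r in rows]
    return feature_cols, X, y


# Two hypothetical rows (made-up values) using the dataset's column names.
toy = [
    {"id": 1, "gender": "Male", "age": 50.0, "hypertension": 0, "heart_disease": 0,
     "ever_married": "Yes", "work_type": "Private", "Residence_type": "Urban",
     "avg_glucose_level": 100.0, "bmi": 30.0, "smoking_status": "never smoked", "stroke": 0},
    {"id": 2, "gender": "Female", "age": 70.0, "hypertension": 1, "heart_disease": 0,
     "ever_married": "No", "work_type": "Self-employed", "Residence_type": "Rural",
     "avg_glucose_level": 200.0, "bmi": None, "smoking_status": "smokes", "stroke": 1},
]

feature_cols, X, y = preprocess(toy)
print(len(feature_cols), "features:", feature_cols)
print(X)
print(y)
```

The 10 columns left after step 4 are exactly the model's 10 input features.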

Model Architecture

A feed-forward neural network defined in src/model.py:

Input (10 features)
     |
Linear(10 → 32) → ReLU
     |
Linear(32 → 32) → ReLU
     |
Linear(32 → 1)
     |
Output (stroke logit; a sigmoid turns it into the stroke probability)

| Component | Details |
|---|---|
| Input features | 10 |
| Hidden layers | 2 (32 neurons each) |
| Activation | ReLU |
| Output | 1 neuron (binary classification) |
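ReLU has no learnable weights, so the model's size follows from the Linear layers alone. Assuming every Linear layer keeps its bias term (the default for PyTorch's nn.Linear), the parameter count works out as below; `linear_params` and `count_params` are illustrative helpers, not part of the repository.

```python
def linear_params(n_in, n_out):
    """Weights (n_in * n_out) plus biases (n_out) of one fully connected layer."""
    return n_in * n_out + n_out


def count_params(layer_sizes):
    """Total trainable parameters of consecutive Linear layers; ReLU adds none."""
    return sum(linear_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:]))


sizes = [10, 32, 32, 1]  # input -> hidden -> hidden -> output widths from the diagram
for n_in, n_out in zip(sizes, sizes[1:]):
    print(f"Linear({n_in} -> {n_out}): {linear_params(n_in, n_out)} parameters")
print("total:", count_params(sizes))  # total: 1441
```

That is 352 + 1,056 + 33 = 1,441 trainable parameters.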

Training

| Parameter | Value |
|---|---|
| Loss function | BCEWithLogitsLoss |
| Optimizer | SGD |
| Learning rate | 0.1 |
| Epochs | 10,000 |
| Train/test split | 80/20 (random_state=42) |
| Accuracy metric | torchmetrics.Accuracy (binary) |
| Device | CUDA (GPU) if available, otherwise CPU |

Training progress is printed every 1,000 epochs, showing train loss, train accuracy, test loss, and test accuracy.
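Because BCEWithLogitsLoss takes raw logits and applies the sigmoid internally, the final Linear(32 → 1) layer needs no activation of its own. The stdlib-only sketch below shows what the loss and a thresholded binary accuracy compute for one batch. It assumes PyTorch's default reduction="mean" and a 0.5 probability threshold (torchmetrics' default for binary accuracy); the function names are hypothetical and this is not the project's training loop.

```python
import math


def sigmoid(x):
    """Logistic function, branching on sign so math.exp never overflows."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy computed from raw logits.

    max(x, 0) - x*y + log(1 + exp(-|x|)) is algebraically equal to
    -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))] but stays finite
    for large |x|.
    """
    losses = [max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
              for x, y in zip(logits, targets)]
    return sum(losses) / len(losses)


def binary_accuracy(logits, targets, threshold=0.5):
    """Share of samples where (sigmoid(logit) >= threshold) equals the label."""
    preds = [1 if sigmoid(x) >= threshold else 0 for x in logits]
    return sum(p == y for p, y in zip(preds, targets)) / len(targets)


# Made-up logits and labels, only to exercise the functions.
logits = [2.0, -1.0, 0.5, -3.0]
targets = [1, 0, 0, 0]
print("loss:", bce_with_logits(logits, targets))
print("accuracy:", binary_accuracy(logits, targets))  # accuracy: 0.75
```

With a 0.5 threshold, a logit of 0 or above is predicted as stroke, so thresholding the logit at 0 gives the same predictions as thresholding the probability at 0.5.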


Project Structure

StrokePredictionModel/
├── data/
│   ├── __init__.py
│   └── raw/
│       └── healthcare-dataset-stroke-data.csv
├── src/
│   ├── __init__.py
│   ├── data_processing.py       # Data loading, encoding, scaling
│   └── model.py                 # Neural network definition & training loop
├── main.py                      # Entry point (empty)
├── Pipfile                      # Pipenv dependency file
├── Pipfile.lock
├── .gitignore
└── README.md

Getting Started

Prerequisites

  • Python 3.13+
  • Pipenv (recommended) or pip

Installation

  1. Clone the repository:

    git clone https://github.com/WiktoriaSmulska/StrokePredictionModel.git
    cd StrokePredictionModel
  2. Install dependencies with Pipenv:

    pipenv install
    pipenv shell

    Or with pip:

    pip install torch torchmetrics pandas numpy scikit-learn imbalanced-learn
  3. Run the model:

    python -m src.model

Tech Stack

| Technology | Purpose |
|---|---|
| Python 3.13 | Programming language |
| PyTorch | Neural network framework |
| torchmetrics | Accuracy metric computation |
| pandas | Data loading and manipulation |
| NumPy | Numerical operations |
| scikit-learn | Label encoding, train/test split |
| imbalanced-learn | Handling class imbalance (SMOTE) |
| Pipenv | Dependency management |

License

This project is open source and available for academic and educational purposes.


Author

Wiktoria Smulska (GitHub)
