---
title: "Stroke Prediction"
author: "Adnan Hakim, Arsh Shaikh, Hussain Sadriwalla"
date: "17/04/2021"
output: html_document
knit: (function(input_file, encoding) {
out_dir <- 'docs';
rmarkdown::render(input_file, encoding = encoding, output_file = file.path(dirname(input_file), out_dir, 'index.html'))})
---
```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE)
```
### Importing packages
```{r}
library(tidyverse) # Collection of R packages for data science
library(naniar) # Data structures and functions for plotting of missing values
library(caTools) # Several basic utility functions
library(ggplot2) # Data visualisations Using the Grammar of Graphics
library(superheat) # Generating customizable heatmaps
library(scatterplot3d) # Plots a three dimensional point cloud
library(ROCR) # Creating cutoff-parameterized 2D performance curves
```
## Dataset
According to the World Health Organization (WHO), stroke is the second leading cause of death globally, responsible for approximately 11% of total deaths. This dataset is used to predict whether a patient is likely to have a stroke based on input parameters such as gender, age, various diseases, and smoking status. Each row provides relevant information about one patient.
The dataset can be found in the [repository](https://github.com/adnanhakim/stroke-prediction) or downloaded from [Kaggle](https://www.kaggle.com/fedesoriano/stroke-prediction-dataset).
```{r}
data = read.csv("stroke_data.csv")
str(data)
glimpse(data)
```
## Data Preprocessing
### Checking dataset values
#### Attribute Information
1. id: unique identifier
1. gender: "Male", "Female" or "Other"
1. age: age of the patient
1. hypertension: 0 if the patient doesn't have hypertension, 1 if the patient has hypertension
1. heart_disease: 0 if the patient doesn't have any heart diseases, 1 if the patient has a heart disease
1. ever_married: "No" or "Yes"
1. Residence_type: "Rural" or "Urban"
1. avg_glucose_level: average glucose level in blood
1. bmi: body mass index
1. smoking_status: "formerly smoked", "never smoked", "smokes" or "Unknown"*
1. stroke: 1 if the patient had a stroke or 0 if not
*Note: "Unknown" in smoking_status means that the information is unavailable for this patient*
```{r}
unique(data $ gender)
unique(data $ ever_married)
unique(data $ Residence_type)
unique(data $ smoking_status)
```
### Converting character values to numeric values
As the values above show, each character column takes only a few distinct values, so it can be converted to numeric codes.
```{r}
clean_data <- data %>%
  mutate(
    # Female = 0, Male = 1, Other = 2
    gender = if_else(gender == "Female", 0, if_else(gender == "Male", 1, 2)),
    # No = 0, Yes = 1
    ever_married = if_else(ever_married == "Yes", 1, 0),
    # Rural = 0, Urban = 1
    Residence_type = if_else(Residence_type == "Rural", 0, 1),
    # never smoked = 0, formerly smoked = 1, smokes = 2, Unknown = 3
    smoking_status = if_else(smoking_status == "never smoked", 0,
                     if_else(smoking_status == "formerly smoked", 1,
                     if_else(smoking_status == "smokes", 2, 3)))
  )
summary(clean_data)
```
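The numeric codes above impose an order on categories such as smoking_status (0 < 1 < 2 < 3), which a regression model then treats as a single numeric slope. One possible alternative, sketched here on a hypothetical toy vector rather than the real data, is to store such a column as a factor so that a model gets one indicator column per level. The names `toy_smoking`, `toy_levels`, `toy_codes` and `toy_factor` are illustrative only and are not used elsewhere in this analysis.

```{r}
# Toy values using the same labels as the smoking_status column
toy_smoking <- c("never smoked", "smokes", "Unknown", "formerly smoked", "smokes")
toy_levels  <- c("never smoked", "formerly smoked", "smokes", "Unknown")

# Numeric coding, equivalent to the nested if_else() above
toy_codes <- match(toy_smoking, toy_levels) - 1
toy_codes

# Factor coding: every level except the first becomes its own 0/1 indicator
toy_factor <- factor(toy_smoking, levels = toy_levels)
model.matrix(~ toy_factor)
```

With a factor, the model estimates a separate effect for each level instead of assuming that "smokes" lies exactly twice as far from "never smoked" as "formerly smoked" does.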
### Handling missing values
```{r}
miss_scan_count(data = data, search = list("N/A", "Unknown"))
```
There are 201 "N/A" values in the bmi column, which likely caused the column to be parsed as character although it should be numeric. We replace those strings with actual NAs. In addition, smoking_status has 1544 "Unknown" values, so a lot of information is missing from a potentially informative predictor. We replace those values with NAs as well.
```{r}
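# "Unknown" was coded as 3 in the conversion step above
clean_data <- replace_with_na(
  data = clean_data,
  replace = list(bmi = c("N/A"), smoking_status = c(3))
) %>%
  mutate(bmi = as.numeric(bmi)) # bmi was read as character because of "N/A"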
summary(clean_data)
```
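An alternative to replacing the placeholder strings after loading is to declare them as missing while the file is read. The sketch below uses the `na.strings` argument of `read.csv()` on a hypothetical three-row extract (`toy_csv`, `toy_raw`), not on the real file, so the column values shown are made up.

```{r}
# Hypothetical three-row extract with the same placeholder strings
toy_csv <- "id,bmi,smoking_status
1,36.6,never smoked
2,N/A,Unknown
3,28.1,smokes"

# na.strings turns the placeholders into NA while the text is parsed
toy_raw <- read.csv(text = toy_csv, na.strings = c("N/A", "Unknown"))

str(toy_raw)             # bmi is parsed as numeric, no as.numeric() needed
colSums(is.na(toy_raw))  # number of missing values per column
```

Applied to the real file, this would make bmi numeric from the start, at the cost of hiding how many placeholders of each kind were present.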
### Visualizing the input
#### Heatmap
```{r}
superheat(subset(clean_data, select = -c(stroke)), scale = TRUE, bottom.label.size = 0.5, bottom.label.text.angle = 90, bottom.label.text.size = 3)
```
#### BMI Distribution
The heatmap shows no clear association between the missing values and the other variables, so we assume that the data are missing completely at random (MCAR). BMI is the only numerical variable with missing data, and its distribution is right-skewed (long tail to the right).
```{r}
ggplot(clean_data, aes(x = bmi)) + geom_density(color="black", fill="lightblue") + labs(title = "Distribution of BMI")
```
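One informal way to probe the MCAR assumption is to compare other variables between rows with and without a recorded BMI. The sketch below does this on hypothetical toy data (`toy_patients` and its values are invented for illustration); it is a quick descriptive check, not a formal test.

```{r}
# Hypothetical toy data: bmi is missing for two of six patients
toy_patients <- data.frame(
  age    = c(34, 61, 45, 78, 52, 29),
  stroke = c(0, 1, 0, 1, 0, 0),
  bmi    = c(24.1, NA, 31.0, NA, 27.5, 22.3)
)

# Flag the rows where bmi is missing
toy_patients$bmi_missing <- is.na(toy_patients$bmi)

# Mean age and stroke rate per group; large gaps would argue against MCAR
toy_summary <- aggregate(cbind(age, stroke) ~ bmi_missing,
                         data = toy_patients, FUN = mean)
toy_summary
```

In this made-up example the patients with missing BMI are older and all had a stroke, which is the kind of pattern that would contradict MCAR if it appeared in the real data.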
#### Gender Distribution
```{r}
ggplot(data, aes(x = factor(gender), fill = factor(gender))) + geom_bar() + theme_classic()
```
#### Age and BMI with respect to stroke
```{r}
ggplot(clean_data, aes(x = age, y = bmi, color = stroke)) + geom_point() + scale_color_gradient(low = "lightblue", high = "red")
```
#### Average glucose level by stroke
```{r}
ggplot(clean_data, aes(x = stroke, y = avg_glucose_level, group = stroke, fill = stroke)) + geom_boxplot()
```
## Logistic Regression
```{r}
set.seed(99) # Set a seed for reproducible results
split = sample.split(clean_data $ stroke, SplitRatio = 0.7)
train = subset(clean_data, split == TRUE)
test = subset(clean_data, split == FALSE)
logistic_regression_1 = glm(stroke~., data = train, family = 'binomial')
summary(logistic_regression_1)
```
Many variables are not significant, so we remove variables one at a time based on their significance level. The least significant variable is Residence_type, with a p-value of 0.78326, so we remove it first.
```{r}
logistic_regression_2 = glm(stroke ~ gender + age + hypertension + heart_disease + ever_married + avg_glucose_level + bmi + smoking_status, data = train, family = 'binomial')
summary(logistic_regression_2)
```
The least significant variable is now gender, with a p-value of 0.33212, so we remove it.
```{r}
logistic_regression_2 = glm(stroke ~ age + hypertension + heart_disease + ever_married + avg_glucose_level + bmi + smoking_status, data = train, family = 'binomial')
summary(logistic_regression_2)
```
The least significant variable is now bmi, with a p-value of 0.29928, so we remove it.
```{r}
logistic_regression_2 = glm(stroke ~ age + hypertension + heart_disease + ever_married + avg_glucose_level + smoking_status, data = train, family = 'binomial')
summary(logistic_regression_2)
```
The least significant variable is now heart_disease, with a p-value of 0.44624, so we remove it.
```{r}
logistic_regression_2 = glm(stroke ~ age + hypertension + ever_married + avg_glucose_level + smoking_status, data = train, family = 'binomial')
summary(logistic_regression_2)
```
The least significant variable is now ever_married, with a p-value of 0.14813, so we remove it.
```{r}
logistic_regression_2 = glm(stroke ~ age + hypertension + avg_glucose_level + smoking_status, data = train, family = 'binomial')
summary(logistic_regression_2)
```
The least significant variable is now smoking_status, with a p-value of 0.12771, so we remove it.
```{r}
logistic_regression_2 = glm(stroke ~ age + hypertension + avg_glucose_level, data = train, family = 'binomial')
summary(logistic_regression_2)
```
This leaves three variables (age, hypertension and avg_glucose_level), all with p-values below 0.05.
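The manual elimination steps above can also be written as a loop. The following is a minimal sketch, not the procedure used to produce the results in this document: it runs on the built-in `infert` dataset as a stand-in for the training data, uses the Wald p-values reported by `summary()`, and assumes every predictor is numeric so that each term has exactly one coefficient. The names `toy_model`, `toy_coefs`, `toy_p` and `toy_worst` are hypothetical.

```{r}
# infert ships with R; it stands in for the stroke training data here
toy_model <- glm(case ~ age + parity + spontaneous + induced,
                 data = infert, family = "binomial")

repeat {
  # Coefficient table without the intercept row
  toy_coefs <- coef(summary(toy_model))[-1, , drop = FALSE]
  if (nrow(toy_coefs) == 0) break          # no predictors left

  toy_p <- toy_coefs[, "Pr(>|z|)"]
  if (max(toy_p) < 0.05) break             # everything left is significant

  # Drop the least significant predictor and refit
  toy_worst <- rownames(toy_coefs)[which.max(toy_p)]
  toy_model <- update(toy_model, as.formula(paste(". ~ . -", toy_worst)))
}

summary(toy_model)
```

The loop stops as soon as every remaining predictor has a p-value below 0.05, which is the same stopping rule as in the manual steps.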
### Predictions on training set and confusion matrix
```{r}
predict_train = predict(logistic_regression_2, type = 'response')
table(train $ stroke, predict_train>0.2)
```
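With `type = 'response'`, `predict()` returns probabilities rather than log-odds, which is why a cutoff such as 0.2 can be applied directly. The sketch below illustrates the relationship on a small model fitted to the built-in `infert` dataset, used here only as a stand-in; `toy_fit`, `toy_link` and `toy_prob` are hypothetical names.

```{r}
# Small logistic model on a built-in dataset, standing in for the stroke model
toy_fit <- glm(case ~ spontaneous + induced, data = infert, family = "binomial")

toy_link <- predict(toy_fit, type = "link")      # linear predictor (log-odds)
toy_prob <- predict(toy_fit, type = "response")  # fitted probabilities

# The inverse logit of the log-odds reproduces the probabilities
all.equal(toy_prob, plogis(toy_link))
```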
### Accuracy on training set
```{r}
# (true negatives + true positives) / rows, counts read from the table above
(3274 + 38) / nrow(train)
```
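The accuracy above is typed in from the printed confusion matrix. As a sketch of how the same quantity could be computed from the table itself, the example below uses hypothetical labels and predicted probabilities (`toy_actual`, `toy_scores`) with the same 0.2 cutoff; it assumes both classes appear among the actual and the predicted values, otherwise the table is not 2 x 2.

```{r}
# Hypothetical labels and predicted probabilities
toy_actual <- c(0, 0, 0, 0, 1, 1)
toy_scores <- c(0.05, 0.10, 0.30, 0.15, 0.25, 0.10)

# Confusion matrix with the same 0.2 cutoff as above
toy_cm <- table(actual = toy_actual, predicted = toy_scores > 0.2)
toy_cm

# Accuracy: correct predictions (diagonal) over all predictions
toy_accuracy <- sum(diag(toy_cm)) / sum(toy_cm)
toy_accuracy

# Sensitivity: share of actual positives that were flagged
toy_sensitivity <- toy_cm["1", "TRUE"] / sum(toy_cm["1", ])
toy_sensitivity
```

Because strokes are rare in this dataset, accuracy alone can look high even when few strokes are detected, so it is worth reading it together with the sensitivity.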
### Predictions on test set
```{r}
predict_test = predict(logistic_regression_2, newdata = test, type = 'response')
table(test $ stroke, predict_test>0.2)
```
### Accuracy on test set
```{r}
# (true negatives + true positives) / rows, counts read from the table above
(1414 + 18) / nrow(test)
```
## Plotting results
### ROC curve and area under the curve
```{r}
rocr_prediction = prediction(predict_test, test $ stroke)
auc = as.numeric(performance(rocr_prediction, 'auc') @ y.values)
auc
```
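The AUC can be read as the share of (stroke, no stroke) pairs in which the stroke case receives the higher predicted probability. The sketch below computes it from ranks in base R on four hypothetical observations (`toy_y`, `toy_s`); it is meant to illustrate the definition, not to replace the ROCR call above.

```{r}
# Hypothetical true labels and predicted scores
toy_y <- c(0, 0, 1, 1)
toy_s <- c(0.10, 0.40, 0.35, 0.80)

toy_n_pos <- sum(toy_y == 1)
toy_n_neg <- sum(toy_y == 0)

# Rank-sum (Mann-Whitney) form of the AUC
toy_auc <- (sum(rank(toy_s)[toy_y == 1]) - toy_n_pos * (toy_n_pos + 1) / 2) /
  (toy_n_pos * toy_n_neg)
toy_auc  # 0.75: the positive scores higher in 3 of the 4 pairs
```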
```{r}
rocr_performance = performance(rocr_prediction, 'tpr','fpr')
plot(rocr_performance, colorize = TRUE, main = 'ROC Curve')
```
### 3D Scatterplot
```{r}
with(clean_data, {scatterplot3d(x = age, y = hypertension, z = avg_glucose_level, main = "Stroke Prediction Scatterplot", xlab = "Age", ylab = "Hypertension", zlab = "Average Glucose Level")})
```