-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathanalysis.py
More file actions
161 lines (141 loc) · 6.56 KB
/
analysis.py
File metadata and controls
161 lines (141 loc) · 6.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import os
import warnings
from sklearn.metrics import confusion_matrix
# 导入你的 XGBoost 类(需要确保 SepsisXGBoost 可用)
# 假设 XGBoost.py 在同一目录下
from XGBoost import SepsisXGBoost
# ========== Path configuration ==========
# Every input and output lives under one project directory; each path below is
# that root followed by a fixed Windows-style relative suffix.
_project_root = r'C:\Users\jsntg\OneDrive\Desktop\3YP\code'
model_path = _project_root + r'\result\XGB06.pkl'  # trained model
val_csv_path = _project_root + r'\data\Validation.csv'
graph_path = _project_root + r'\graph'
result_path = _project_root + r'\result'
# ========== 1. Load the model ==========
# Create the project wrapper and attach the persisted estimator to its `.model`
# attribute; the prediction section calls predict/predict_proba via `model.model`.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code
# — only load model files that come from a trusted source.
model = SepsisXGBoost()
model.model = joblib.load(model_path)
print("模型加载完成")
# ========== 2. Load the validation set ==========
# Read the CSV and convert the `timestamp` column to datetimes in one chained
# expression, then print a few size figures as a sanity check.
df_val = pd.read_csv(val_csv_path).assign(
    timestamp=lambda frame: pd.to_datetime(frame['timestamp'])
)
print(f"验证集原始数据 shape: {df_val.shape}")

n_val_admissions = df_val['hospital_admission_id'].nunique()
print(f"验证集中 hospital_admission_id 数量: {n_val_admissions}")

n_val_onset_rows = df_val['is_onset'].sum()
print(f"验证集中 is_onset=1 的数量: {n_val_onset_rows}")
# ========== 3. Generate future labels ==========
# The original author notes these parameters are meant to match the ones used
# at training time — keep them in sync with the training script.
df_val = SepsisXGBoost.add_future_sepsis_label(
    df_val,
    future_hours=3,
    onset_col='is_onset',
)
print("未来标签生成完成")

# ========== 4. Build window statistic features ==========
# The helper returns seven values; keep the tuple in one name first, then
# unpack it into the variables the rest of the script uses.
window_outputs = model._create_window_stat_features(
    df_val,
    window_hours=3,
    min_records=2,
)
(
    X_val,
    y_val,
    y_val_orig,
    val_pids,
    val_adm_ids,
    val_times,
    feature_names,
) = window_outputs
print(f"验证集特征矩阵 shape: {X_val.shape}")
# ========== 5. Predict ==========
y_pred = model.model.predict(X_val)
# Second column of predict_proba — presumably the positive-class probability
# (standard for scikit-learn-style binary classifiers; verify against the model).
y_prob = model.model.predict_proba(X_val)[:, 1]

# Assemble one row per feature window with identifiers, labels and predictions.
results = pd.DataFrame({
    'patient_id': val_pids,
    'hospital_admission_id': val_adm_ids,
    'timestamp': val_times,
    'true_label_future': y_val,          # future label (the model's training target)
    'true_label_original': y_val_orig,   # original sepsis column (admission level)
    'pred_label': y_pred,
    'pred_prob': y_prob,
})

# Attach is_onset from the validation frame.
# The lookup table is collapsed to ONE row per (admission, timestamp) first:
# if df_val contains repeated key pairs, a plain left merge would multiply the
# prediction rows and inflate every downstream count. `max` keeps is_onset == 1
# whenever any of the duplicate rows marks an onset.
onset_keys = ['hospital_admission_id', 'timestamp']
onset_info = df_val[onset_keys + ['is_onset']].copy()
onset_info['timestamp'] = pd.to_datetime(onset_info['timestamp'])
onset_info = onset_info.groupby(onset_keys, as_index=False)['is_onset'].max()

results['timestamp'] = pd.to_datetime(results['timestamp'])
# validate='many_to_one' makes pandas raise if the right side is still not unique.
results = results.merge(
    onset_info,
    on=onset_keys,
    how='left',
    validate='many_to_one',
)

# Save the full prediction table.
results.to_csv(os.path.join(result_path, 'validation_predictions.csv'), index=False)
print("预测结果已保存至 validation_predictions.csv")
# ========== 6. Admission-level confusion matrix (based on the original sepsis label) ==========
# Collapse window-level rows to one row per admission: an admission counts as
# positive if any of its rows is positive (max over the group).
per_admission = results.groupby('hospital_admission_id')[
    ['true_label_original', 'pred_label']
].max()
admission_true = per_admission['true_label_original']
admission_pred = per_admission['pred_label']

# labels=[0, 1] forces a 2x2 matrix even when only one class is present in the
# data; without it the 4-way unpacking of cm.ravel() below raises ValueError.
cm = confusion_matrix(admission_true, admission_pred, labels=[0, 1])

plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.title('Validation Set: Confusion Matrix per Admission (Original Sepsis)')
plt.xlabel('Predicted (model ever warned)')
plt.ylabel('True (sepsis occurred)')
plt.tight_layout()

# Make sure the output directory exists so savefig cannot fail after all the
# computation has already been done.
os.makedirs(graph_path, exist_ok=True)
save_path_cm = os.path.join(graph_path, 'wow.png')
plt.savefig(save_path_cm, dpi=300)
plt.show()
print(f"混淆矩阵保存至:{save_path_cm}")
print("混淆矩阵数值:")
print(cm)

# With labels=[0, 1] the flattened order is tn, fp, fn, tp.
tn, fp, fn, tp = cm.ravel()
# Guard each ratio against an empty denominator.
sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
print(f"\n入院级别性能:")
print(f"Sensitivity (Recall): {sensitivity:.3f}")
print(f"Specificity: {specificity:.3f}")
print(f"Precision: {precision:.3f}")
# ========== 7. Warning-time analysis (based on is_onset) ==========
# Only admissions whose admission-level original label equals 1 are analysed.
sepsis_admissions = admission_true[admission_true == 1].index.tolist()

# Earliest onset timestamp and earliest positive-prediction timestamp per
# admission, each computed in a single grouped pass. This replaces filtering and
# sorting the whole `results` frame once per admission; the per-group minimum is
# the same value the first row of a timestamp-sorted slice would give.
sepsis_rows = results[results['hospital_admission_id'].isin(sepsis_admissions)]
first_onset = (
    sepsis_rows[sepsis_rows['is_onset'] == 1]
    .groupby('hospital_admission_id')['timestamp']
    .min()
    .to_dict()
)
first_warning = (
    sepsis_rows[sepsis_rows['pred_label'] == 1]
    .groupby('hospital_admission_id')['timestamp']
    .min()
    .to_dict()
)

deltas = []         # hours between first warning and onset (negative = early)
details = []        # one record per analysed admission, written to CSV later
never_warned = 0    # admissions with an onset row but no positive prediction
missing_onset = 0   # admissions with no is_onset == 1 row among the predictions

for adm in sepsis_admissions:
    if adm not in first_onset:
        # No onset row survived into `results` (possibly a data issue); such
        # admissions cannot be timed, so skip them — but count them.
        missing_onset += 1
        continue
    true_onset = first_onset[adm]

    if adm not in first_warning:
        never_warned += 1
        details.append({
            'admission': adm,
            'true_onset': true_onset,
            'pred_warning': None,
            'delta_hours': None,
        })
        continue

    pred_warning = first_warning[adm]
    delta = (pred_warning - true_onset).total_seconds() / 3600.0
    deltas.append(delta)
    details.append({
        'admission': adm,
        'true_onset': true_onset,
        'pred_warning': pred_warning,
        'delta_hours': delta,
    })

if missing_onset > 0:
    # Surface what used to be a silent skip so the summary numbers can be
    # reconciled with the number of sepsis admissions.
    print(f"警告:{missing_onset} 个真实脓毒症入院在预测结果中没有 is_onset=1 的记录,已跳过")

deltas = np.array(deltas)
# Print summary statistics of the warning-time deltas.
# The printed header states that negative values mean an early warning and
# positive values a late one.
banner = "=" * 60
print("\n" + banner)
print("预警时间差分析(仅限模型曾预警的真实脓毒症入院)")
print("负值表示提前预警,正值表示延迟预警")
print(banner)

n_warned = len(deltas)
if n_warned == 0:
    print("没有模型预警的真实脓毒症入院")
else:
    print(f"有效预警入院数:{n_warned}")
    # Every statistic shares the same "<label>: <value> hours" layout, so the
    # values are collected first and printed in one loop.
    delta_stats = (
        ("Mean delta", np.mean(deltas)),
        ("Median delta", np.median(deltas)),
        ("Std", np.std(deltas)),
        ("Min (earliest)", np.min(deltas)),
        ("Max (latest)", np.max(deltas)),
        ("25th percentile", np.percentile(deltas, 25)),
        ("75th percentile", np.percentile(deltas, 75)),
    )
    for stat_label, stat_value in delta_stats:
        print(f"{stat_label}: {stat_value:.2f} hours")
    early_share = np.sum(deltas < 0) / n_warned * 100
    print(f"提前预警比例 (delta < 0): {early_share:.1f}%")

# Reported regardless of whether any admission was warned.
print(f"从未预警的真实脓毒症入院数:{never_warned}")
# Save the per-admission warning-time details.
# The column list is given explicitly so the CSV always has the same header,
# even when `details` is empty (an empty list would otherwise produce a file
# with no columns at all). The order matches the keys of each detail record.
detail_columns = ['admission', 'true_onset', 'pred_warning', 'delta_hours']
detail_df = pd.DataFrame(details, columns=detail_columns)
detail_df.to_csv(os.path.join(result_path, 'warning_time_details.csv'), index=False)
print("详细预警时间结果已保存至 warning_time_details.csv")
# Plot a histogram of the warning-time deltas (only when there is data).
if len(deltas) > 0:
    plt.figure(figsize=(8, 5))
    plt.hist(deltas, bins=30, edgecolor='black', alpha=0.7, color='steelblue')
    # Vertical reference line at zero, i.e. a warning exactly at onset.
    plt.axvline(x=0, color='red', linestyle='--', linewidth=2, label='Zero (exact onset)')
    plt.xlabel('Warning Time Delta (hours)')
    plt.ylabel('Number of Admissions')
    plt.title('Validation Set: Warning Timing Distribution (Negative = Early)')
    plt.legend()
    plt.grid(alpha=0.3)
    # Same layout pass as the confusion-matrix figure, so labels are not clipped.
    plt.tight_layout()

    # Create the output directory if needed so savefig cannot fail on a
    # missing folder.
    os.makedirs(graph_path, exist_ok=True)
    save_path_hist = os.path.join(graph_path, 'woow.png')
    plt.savefig(save_path_hist, dpi=300)
    plt.show()
    print(f"直方图保存至:{save_path_hist}")
else:
    print("没有足够的预警数据绘制直方图")