-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
179 lines (145 loc) · 6.85 KB
/
train.py
File metadata and controls
179 lines (145 loc) · 6.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""
CEDD Training Script / Script d'entraînement CEDD
==================================================
Loads synthetic conversations, extracts features, trains the classifier.
Charge les conversations synthétiques, extrait les features, entraîne le classifieur.
"""
import argparse
import json
import os
import sys
import numpy as np
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
# Add root directory to path / Ajouter le répertoire racine au path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from cedd.feature_extractor import extract_features, extract_trajectory_features
from cedd.classifier import CEDDClassifier
# Default training-data file (relative to the current working directory);
# can be overridden with --data on the command line.
DATA_PATH = "data/synthetic_conversations.json"
# Where the trained model is saved, and reloaded from for the smoke test.
MODEL_PATH = "models/cedd_model.joblib"
# Class names indexed by integer label: label i is displayed as LABEL_NAMES[i].
LABEL_NAMES = ["green", "yellow", "orange", "red"]
def load_and_extract(data_path: str):
    """
    Load conversations and extract one trajectory feature vector per conversation.

    The JSON file at ``data_path`` must contain a list of conversations, each
    with ``"messages"`` (a list of ``{"role", "content"}`` dicts), ``"label"``,
    ``"label_name"`` and ``"id"``. Labels are expected to be integer indices
    into LABEL_NAMES (that is how the distribution below is counted).

    Args:
        data_path: Path to the training-data JSON file.

    Returns:
        Tuple ``(X, y)``: ``X`` stacks the trajectory feature vectors (one row
        per conversation) and ``y`` holds the matching labels.

    Raises:
        ValueError: If the file contains no conversations.
    """
    with open(data_path, "r", encoding="utf-8") as f:
        conversations = json.load(f)
    print(f"Conversations loaded / chargées : {len(conversations)}")
    if not conversations:
        # Fail here with a clear message instead of deep inside cross-validation.
        raise ValueError(f"No conversations found in {data_path}")

    X_list = []
    y_list = []
    for conv in conversations:
        messages = conv["messages"]
        label = conv["label"]
        label_name = conv["label_name"]

        # Extract per-message features first, then aggregate them into a
        # single trajectory vector for the whole conversation.
        msg_features = extract_features(messages)
        user_texts = [m["content"] for m in messages if m["role"] == "user"]
        traj_features = extract_trajectory_features(msg_features, user_texts=user_texts,
                                                    messages=messages)
        X_list.append(traj_features)
        y_list.append(label)

        # user_texts already has exactly one entry per user message.
        n_user = len(user_texts)
        print(f" [{label_name:6s}] {conv['id']:20s} — {n_user} user msgs, "
              f"{len(traj_features)} features")

    X = np.array(X_list)
    y = np.array(y_list)
    print(f"\nX shape : {X.shape}")
    print(f"y shape : {y.shape}")
    # Derive the class count from LABEL_NAMES rather than hard-coding it.
    distribution = {name: int((y == i).sum()) for i, name in enumerate(LABEL_NAMES)}
    print(f"Label distribution / Distribution des labels : "
          f"{distribution}")
    return X, y
def print_separator(char="=", length=60):
    """Print a horizontal rule: ``char`` repeated ``length`` times."""
    rule = char * length
    print(rule)
def _print_confusion_matrix(cm):
    """
    Print a confusion matrix with LABEL_NAMES as row and column headings.

    Rows are true classes and columns are predicted classes (scikit-learn's
    ``confusion_matrix`` convention). ``cm`` is expected to be square with one
    row/column per entry of LABEL_NAMES.
    """
    label_width = 8
    # Each row starts with " " plus an 8-wide label (9 chars). The header must
    # be padded by the same amount, or the column names sit one column to the
    # right of their counts.
    header = " " * (label_width + 1) + " ".join(f"{n:>8s}" for n in LABEL_NAMES)
    print(header)
    for name, row in zip(LABEL_NAMES, cm):
        print(f" {name:{label_width}s}" + " ".join(f"{v:8d}" for v in row))


def _print_top_features(clf, top_k=10):
    """Print the ``top_k`` largest feature importances of the pipeline's "clf" step."""
    rf = clf.pipeline.named_steps["clf"]
    importances = rf.feature_importances_
    names = clf.feature_names
    top_idx = np.argsort(importances)[::-1][:top_k]
    for rank, idx in enumerate(top_idx, 1):
        # Fall back to a positional name instead of silently dropping a rank
        # when feature_names is shorter than the importance vector.
        fname = names[idx] if idx < len(names) else f"feature_{idx}"
        print(f" {rank:2d}. {fname:35s} {importances[idx]:.4f}")


def _reload_smoke_test(model_path):
    """
    Reload the saved model and print its output on two crisis conversations.

    One conversation is in French (``lang="fr"``) and one in English
    (``lang="en"``). For each, the label, confidence and dominant features
    returned by ``get_alert_level`` are printed.
    """
    print("\n Reload test / Test de rechargement...")
    clf2 = CEDDClassifier.load(model_path)

    test_conv = [
        {"role": "user", "content": "je sais plus pourquoi je continue. plus rien a de sens."},
        {"role": "assistant", "content": "Je suis là. Qu'est-ce qui se passe ?"},
        {"role": "user", "content": "jai pensé à en finir. jai un plan ce soir."},
    ]
    result = clf2.get_alert_level(test_conv, lang="fr")
    print(f" Crisis test / Test alerte critique : {result['label']} (confidence: {result['confidence']:.2f})")
    print(f" Dominant features : {result['dominant_features']}")

    # English crisis test
    test_conv_en = [
        {"role": "user", "content": "I don't know why I keep going. nothing makes sense anymore."},
        {"role": "assistant", "content": "I'm here. What's going on?"},
        {"role": "user", "content": "I've thought about ending it. I have a plan for tonight."},
    ]
    result_en = clf2.get_alert_level(test_conv_en, lang="en")
    print(f" EN crisis test : {result_en['label']} (confidence: {result_en['confidence']:.2f})")
    print(f" Dominant features: {result_en['dominant_features']}")


def main():
    """
    Command-line entry point: load data, cross-validate, train, report, save.

    Steps:
      1. Load the conversations given by ``--data`` and extract features.
      2. Run stratified k-fold cross-validation of the classifier pipeline.
      3. Fit on the full dataset, then print the train-set classification
         report, the confusion matrix and the top feature importances.
      4. Save the model to MODEL_PATH, reload it and run two crisis smoke
         tests (French and English).
    """
    parser = argparse.ArgumentParser(
        description="Train the CEDD classifier. / Entraîner le classifieur CEDD."
    )
    parser.add_argument(
        "--data", "-d",
        type=str,
        default=DATA_PATH,
        help="Path to training data JSON. / Chemin vers les données d'entraînement JSON. "
             f"Default: {DATA_PATH}",
    )
    args = parser.parse_args()
    data_path = args.data

    print_separator()
    print(" CEDD — Classifier Training / Entraînement du classifieur")
    print_separator()

    # 1. Load data and extract features
    print("\n[1/4] Loading data & extracting features / Chargement et extraction...")
    print(f" Data source / Source : {data_path}")
    X, y = load_and_extract(data_path)

    # 2. Stratified cross-validation. cross_val_score fits its own copies of
    # the pipeline per fold; only the scores are kept.
    n_folds = 4
    print(f"\n[2/4] Stratified cross-validation (k={n_folds})...")
    clf = CEDDClassifier(n_estimators=200, random_state=42)
    cv = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
    cv_scores = cross_val_score(clf.pipeline, X, y, cv=cv, scoring="accuracy")
    print(f" Accuracy per fold / par fold : {[f'{s:.3f}' for s in cv_scores]}")
    print(f" Mean accuracy / Accuracy moy : {cv_scores.mean():.3f} ± {cv_scores.std():.3f}")

    # 3. Train on the full dataset and report train-set metrics
    print("\n[3/4] Training on full dataset / Entraînement sur l'ensemble complet...")
    clf.fit(X, y)
    y_pred = clf.predict(X)
    train_accuracy = accuracy_score(y, y_pred)
    print(f"\n Train accuracy : {train_accuracy:.3f}")

    # Pin the label set to every index of LABEL_NAMES. Otherwise, a class
    # absent from both y and y_pred makes classification_report reject the
    # target_names length and shrinks the confusion matrix, mislabelling rows.
    label_ids = list(range(len(LABEL_NAMES)))

    print()
    print_separator("-")
    print(" Classification report / Rapport de classification (train)")
    print_separator("-")
    print(classification_report(y, y_pred, labels=label_ids,
                                target_names=LABEL_NAMES, digits=3))
    print(" Confusion matrix / Matrice de confusion (train)")
    print_separator("-")
    _print_confusion_matrix(confusion_matrix(y, y_pred, labels=label_ids))

    # Feature importances
    print()
    print_separator("-")
    print(" Top 10 most important features / Top 10 features les plus importantes")
    print_separator("-")
    _print_top_features(clf, top_k=10)

    # 4. Save the model, creating its directory when the path has one
    print(f"\n[4/4] Saving model / Sauvegarde du modèle → {MODEL_PATH}...")
    model_dir = os.path.dirname(MODEL_PATH)
    if model_dir:  # os.makedirs("") raises, so skip it for a bare filename
        os.makedirs(model_dir, exist_ok=True)
    clf.save(MODEL_PATH)

    # Check the saved artifact round-trips and still flags obvious crises
    _reload_smoke_test(MODEL_PATH)

    print()
    print_separator()
    print(" Training complete! / Entraînement terminé avec succès !")
    print(f" CV Accuracy : {cv_scores.mean():.3f} ± {cv_scores.std():.3f}")
    print(f" Train Accuracy : {train_accuracy:.3f}")
    print(f" Model / Modèle : {MODEL_PATH}")
    print_separator()
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()