-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
202 lines (166 loc) · 7.79 KB
/
app.py
File metadata and controls
202 lines (166 loc) · 7.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
from flask import Flask, request, jsonify, render_template
from groq import Groq
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import os
app = Flask(__name__)
# Clé API Groq (remplacez par votre propre clé)
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "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")
client = Groq(api_key=GROQ_API_KEY)
# Chargement du dataset local
dataset_path = "dataset.csv"
if not os.path.exists(dataset_path):
print("❌ Erreur : le fichier dataset.csv est introuvable !")
exit(1)
try:
dataset = pd.read_csv(dataset_path)
print(f"✅ Dataset chargé avec {len(dataset)} lignes et {len(dataset.columns)} colonnes.")
except Exception as e:
print(f"❌ Erreur lors du chargement du dataset : {str(e)}")
exit(1)
# Vérification de la colonne "Disease"
if "Disease" not in dataset.columns:
print("❌ Erreur : la colonne 'Disease' est introuvable dans le dataset.")
exit(1)
# Liste des maladies disponibles
disease_names = dataset["Disease"].astype(str).unique().tolist()
# Vectorisation des maladies pour la recherche
disease_vectorizer = TfidfVectorizer()
disease_vectors = disease_vectorizer.fit_transform(disease_names)
# Extraire tous les symptômes uniques du dataset
symptom_columns = [col for col in dataset.columns if col.startswith("Symptom_")]
all_symptoms = set()
for col in symptom_columns:
symptoms = dataset[col].dropna().unique()
all_symptoms.update(symptoms)
all_symptoms = list(all_symptoms)
# Vectorisation des symptômes pour la recherche
symptom_vectorizer = TfidfVectorizer()
symptom_vectors = symptom_vectorizer.fit_transform(all_symptoms)
def find_disease_by_name(disease_query, min_similarity=0.3):
"""Recherche la maladie la plus proche du nom donné."""
query_vector = disease_vectorizer.transform([disease_query])
similarities = cosine_similarity(query_vector, disease_vectors).flatten()
most_similar_idx = np.argmax(similarities)
max_similarity = similarities[most_similar_idx]
if max_similarity >= min_similarity:
return disease_names[most_similar_idx], max_similarity
else:
return None, max_similarity
def find_symptom(symptom_query, min_similarity=0.3):
"""Recherche le symptôme le plus proche du nom donné."""
query_vector = symptom_vectorizer.transform([symptom_query])
similarities = cosine_similarity(query_vector, symptom_vectors).flatten()
most_similar_idx = np.argmax(similarities)
max_similarity = similarities[most_similar_idx]
if max_similarity >= min_similarity:
return all_symptoms[most_similar_idx], max_similarity
else:
return None, max_similarity
def find_diseases_by_symptom(symptom):
"""Trouve toutes les maladies associées à un symptôme donné."""
related_diseases = set()
for _, row in dataset.iterrows():
for col in symptom_columns:
if pd.notna(row[col]) and row[col] == symptom:
related_diseases.add(row["Disease"])
break
return list(related_diseases)
def get_symptoms_for_disease(disease_name):
"""Récupère les symptômes associés à une maladie."""
disease_match, similarity = find_disease_by_name(disease_name)
if disease_match is None:
# Essayons de voir si c'est un symptôme
symptom_match, symptom_similarity = find_symptom(disease_name)
if symptom_match and symptom_similarity >= 0.3:
related_diseases = find_diseases_by_symptom(symptom_match)
if related_diseases:
diseases_list = ", ".join(related_diseases)
message = f"Le symptôme '{symptom_match}' est associé aux maladies suivantes :\n\n{diseases_list}"
try:
completion = client.chat.completions.create(
model="llama-3.3-70b-versatile",
messages=[
{"role": "system", "content": "Tu es un assistant médical. Explique quelles maladies sont associées à ce symptôme et comment elles se manifestent."},
{"role": "user", "content": message},
],
temperature=0.2,
max_tokens=512,
)
return completion.choices[0].message.content.strip()
except Exception as e:
return f"❌ Erreur API Groq : {str(e)}"
else:
return f"Aucune maladie associée au symptôme '{symptom_match}' n'a été trouvée."
else:
return f"Aucune correspondance trouvée pour '{disease_name}'."
# Filtrer le dataset pour la maladie trouvée
disease_data = dataset[dataset["Disease"] == disease_match]
if len(disease_data) == 0:
return f"Données introuvables pour {disease_match}."
# Extraire tous les symptômes pour cette maladie
all_symptoms_for_disease = []
for _, row in disease_data.iterrows():
for col in symptom_columns:
if pd.notna(row[col]): # Vérifier que la valeur n'est pas NaN
symptom = row[col]
if symptom not in all_symptoms_for_disease: # Éviter les doublons
all_symptoms_for_disease.append(symptom)
if not all_symptoms_for_disease:
return f"Aucun symptôme trouvé pour {disease_match}."
# Génération du message de réponse
symptom_list = ", ".join(all_symptoms_for_disease)
message = f"Maladie : {disease_match}\n\nSymptômes : {symptom_list}"
try:
completion = client.chat.completions.create(
model="llama-3.3-70b-versatile",
messages=[
{"role": "system", "content": "Tu es un assistant médical. Liste uniquement les symptômes de la maladie mentionnée."},
{"role": "user", "content": message},
],
temperature=0.2,
max_tokens=512,
)
return completion.choices[0].message.content.strip()
except Exception as e:
return f"❌ Erreur API Groq : {str(e)}"
@app.route('/')
def index():
return render_template("index.html")
@app.route('/chat', methods=['POST'])
def chat():
try:
data = request.get_json()
if not data:
return jsonify({"error": "Données JSON invalides"}), 400
user_input = data.get("message", "").strip()
if not user_input:
return jsonify({"error": "Aucun message envoyé"}), 400
response = get_symptoms_for_disease(user_input)
return jsonify({"response": response})
except Exception as e:
return jsonify({"error": f"Erreur serveur: {str(e)}"}), 500
@app.route('/maladies', methods=['GET'])
def list_maladies():
"""Endpoint pour récupérer la liste des maladies disponibles"""
try:
return jsonify({"maladies": disease_names})
except Exception as e:
return jsonify({"error": f"Erreur serveur: {str(e)}"}), 500
@app.route('/symptomes', methods=['GET'])
def list_symptomes():
"""Endpoint pour récupérer la liste des symptômes disponibles"""
try:
return jsonify({"symptomes": list(all_symptoms)})
except Exception as e:
return jsonify({"error": f"Erreur serveur: {str(e)}"}), 500
if __name__ == '__main__':
# Affichage d'informations utiles au démarrage
print("🔍 Nombre de maladies uniques :", len(disease_names))
print("🔍 Exemples de maladies :", disease_names[:5])
print("🔍 Nombre de symptômes uniques :", len(all_symptoms))
print("🔍 Exemples de symptômes :", list(all_symptoms)[:5])
print("\n✅ Application démarrée ! Accédez à http://127.0.0.1:5000/ pour commencer.")
app.run(debug=True)