|
5 | 5 | from pydantic import BaseModel |
6 | 6 | import os |
7 | 7 | import sys |
| 8 | +import uuid |
| 9 | +from typing import Dict, List, Tuple |
8 | 10 |
|
9 | 11 | # Ensure foundational_brain is importable |
10 | 12 | MODEL_ROOT = os.path.dirname(os.path.dirname(__file__)) |
|
16 | 18 | from medical_diagnosis_model.versions.v2.medical_neural_network_v2 import ClinicalReasoningNetwork |
17 | 19 | from medical_diagnosis_model.pdf_exporter import PDFExporter |
18 | 20 | from medical_diagnosis_model.backend.security.jwt_dep import verify_bearer |
| 21 | +from medical_diagnosis_model.versions.v2.medical_disease_schema_v2 import DISEASES_V2 |
| 22 | +from medical_diagnosis_model.medical_symptom_schema import SYMPTOMS |
19 | 23 |
|
20 | 24 |
|
21 | 25 | app = FastAPI(title="Medical Diagnosis API", version="0.1.0") |
|
33 | 37 | MODEL_PATH = os.path.join(MODEL_ROOT, "models", "enhanced_medical_model.json") |
34 | 38 | exporter = PDFExporter(export_dir=os.path.join(MODEL_ROOT, "exports")) |
35 | 39 | _RATE_LIMIT_STORE: dict[str, dict[str, float | int]] = {} |
| 40 | +_ADAPTIVE_SESSIONS: Dict[str, Dict] = {} |
36 | 41 |
|
37 | 42 |
|
38 | 43 | def _ensure_model_loaded(): |
@@ -127,3 +132,219 @@ def export_report(req: ExportRequest, x_api_key: str | None = Header(default=Non |
127 | 132 | return {"path": path} |
128 | 133 |
|
129 | 134 |
|
| 135 | +# ===================== Adaptive (alpha) ===================== |
| 136 | + |
| 137 | +class AdaptiveStartRequest(BaseModel): |
| 138 | + prior_answers: dict | None = None # {"Fever": "yes"|"no"|"unknown"|number} |
| 139 | + threshold: float | None = None # stop threshold for top-1 prob |
| 140 | + max_questions: int | None = None |
| 141 | + |
| 142 | + |
| 143 | +class AdaptiveStartResponse(BaseModel): |
| 144 | + session_id: str |
| 145 | + next_question: dict | None = None # {symptom_id, name} |
| 146 | + |
| 147 | + |
| 148 | +class AdaptiveAnswerRequest(BaseModel): |
| 149 | + session_id: str |
| 150 | + question: str | int |
| 151 | + answer: str # yes|no|unknown |
| 152 | + severity: float | None = None # 0-10 scale (optional) |
| 153 | + |
| 154 | + |
| 155 | +class AdaptiveAnswerResponse(BaseModel): |
| 156 | + session_id: str |
| 157 | + finished: bool |
| 158 | + next_question: dict | None = None |
| 159 | + results: dict | None = None |
| 160 | + |
| 161 | + |
| 162 | +def _symptom_id_from_key(key: str | int) -> int | None: |
| 163 | + if isinstance(key, int): |
| 164 | + return key if key in SYMPTOMS else None |
| 165 | + # Try exact name match |
| 166 | + for sid, meta in SYMPTOMS.items(): |
| 167 | + if meta.get("name", "").lower() == str(key).lower(): |
| 168 | + return sid |
| 169 | + return None |
| 170 | + |
| 171 | + |
| 172 | +def _answers_to_vectors(answers: Dict[int, dict]) -> tuple[list[int], list[float], list[int]]: |
| 173 | + symptom_vector = [0] * 30 |
| 174 | + severity_vector = [0.0] * 30 |
| 175 | + present_ids: list[int] = [] |
| 176 | + for sid, info in answers.items(): |
| 177 | + ans = info.get("answer") |
| 178 | + sev_raw = info.get("severity") |
| 179 | + if ans == "yes": |
| 180 | + symptom_vector[sid] = 1 |
| 181 | + present_ids.append(sid) |
| 182 | + if sev_raw is None: |
| 183 | + severity_vector[sid] = 0.6 |
| 184 | + else: |
| 185 | + try: |
| 186 | + severity_vector[sid] = max(0.0, min(float(sev_raw) / 10.0, 1.0)) |
| 187 | + except Exception: |
| 188 | + severity_vector[sid] = 0.6 |
| 189 | + elif ans == "no": |
| 190 | + # explicitly absent → keep present=0, severity=0 |
| 191 | + continue |
| 192 | + else: |
| 193 | + # unknown → ignore |
| 194 | + continue |
| 195 | + return symptom_vector, severity_vector, present_ids |
| 196 | + |
| 197 | + |
| 198 | +def _compute_adjusted_probs(symptom_vector: list[int], severity_vector: list[float], present_ids: list[int]) -> list[float]: |
| 199 | + if model.network is None: |
| 200 | + _ensure_model_loaded() |
| 201 | + features = symptom_vector + severity_vector |
| 202 | + base = model._predict_proba(features) |
| 203 | + adjusted = model._apply_clinical_rules(base, present_ids, severity_vector, has_test_results=None) |
| 204 | + total = sum(adjusted) |
| 205 | + return [p / total for p in adjusted] if total else adjusted |
| 206 | + |
| 207 | + |
| 208 | +def _entropy(probs: list[float]) -> float: |
| 209 | + import math |
| 210 | + eps = 1e-12 |
| 211 | + return -sum(p * math.log(max(p, eps)) for p in probs) |
| 212 | + |
| 213 | + |
| 214 | +def _select_next_symptom(disease_probs: list[float], asked: set[int]) -> int | None: |
| 215 | + # Build mapping disease_id -> prob |
| 216 | + d_ids = list(DISEASES_V2.keys()) |
| 217 | + p_map = {did: disease_probs[did] for did in d_ids} |
| 218 | + h_before = _entropy(list(p_map.values())) |
| 219 | + best_symptom = None |
| 220 | + best_eig = -1.0 |
| 221 | + # Precompute per-disease symptom frequencies |
| 222 | + for sid in range(30): |
| 223 | + if sid in asked: |
| 224 | + continue |
| 225 | + # P(yes|d) |
| 226 | + py_d = {did: DISEASES_V2[did].get("symptom_patterns", {}).get(sid, {}).get("frequency", 0.0) for did in d_ids} |
| 227 | + # Priors for yes/no |
| 228 | + p_yes = sum(p_map[did] * py_d[did] for did in d_ids) |
| 229 | + p_no = 1.0 - p_yes |
| 230 | + if p_yes <= 1e-9 or p_no <= 1e-9: |
| 231 | + eig = 0.0 |
| 232 | + else: |
| 233 | + # Posteriors |
| 234 | + post_yes = [] |
| 235 | + post_no = [] |
| 236 | + for did in d_ids: |
| 237 | + post_yes.append((p_map[did] * py_d[did]) / p_yes) |
| 238 | + post_no.append((p_map[did] * (1.0 - py_d[did])) / p_no) |
| 239 | + h_yes = _entropy(post_yes) |
| 240 | + h_no = _entropy(post_no) |
| 241 | + eig = h_before - (p_yes * h_yes + p_no * h_no) |
| 242 | + if eig > best_eig: |
| 243 | + best_eig = eig |
| 244 | + best_symptom = sid |
| 245 | + return best_symptom |
| 246 | + |
| 247 | + |
| 248 | +def _session_should_stop(probs: list[float], num_questions: int, threshold: float, max_q: int) -> bool: |
| 249 | + return (max(probs) >= threshold) or (num_questions >= max_q) |
| 250 | + |
| 251 | + |
| 252 | +def _build_next_question(sid: int | None) -> dict | None: |
| 253 | + if sid is None: |
| 254 | + return None |
| 255 | + meta = SYMPTOMS.get(sid, {}) |
| 256 | + return { |
| 257 | + "symptom_id": sid, |
| 258 | + "name": meta.get("name"), |
| 259 | + "medical_term": meta.get("medical_term"), |
| 260 | + "icd_10": meta.get("icd_10"), |
| 261 | + } |
| 262 | + |
| 263 | + |
| 264 | +@app.post("/api/v2/adaptive/start") |
| 265 | +def adaptive_start(req: AdaptiveStartRequest, x_api_key: str | None = Header(default=None), claims: dict = Depends(verify_bearer)): |
| 266 | + if os.environ.get("MDM_AUTH_MODE", "api_key").lower() != "oidc": |
| 267 | + _auth_check(x_api_key) |
| 268 | + # Create session |
| 269 | + session_id = str(uuid.uuid4()) |
| 270 | + threshold = req.threshold if req.threshold is not None else float(os.environ.get("MDM_ADAPTIVE_CONFIDENCE", "0.85")) |
| 271 | + max_q = req.max_questions if req.max_questions is not None else int(os.environ.get("MDM_ADAPTIVE_MAX_Q", "10")) |
| 272 | + answers: Dict[int, dict] = {} |
| 273 | + # Seed prior answers |
| 274 | + if req.prior_answers: |
| 275 | + for key, val in req.prior_answers.items(): |
| 276 | + sid = _symptom_id_from_key(key) |
| 277 | + if sid is None: |
| 278 | + continue |
| 279 | + if isinstance(val, (int, float)): |
| 280 | + answers[sid] = {"answer": "yes", "severity": float(val)} |
| 281 | + elif isinstance(val, str): |
| 282 | + answers[sid] = {"answer": val.lower(), "severity": None} |
| 283 | + _ADAPTIVE_SESSIONS[session_id] = { |
| 284 | + "answers": answers, |
| 285 | + "threshold": threshold, |
| 286 | + "max_q": max_q, |
| 287 | + "num_q": 0, |
| 288 | + } |
| 289 | + # Compute next question |
| 290 | + sv, sev, present = _answers_to_vectors(answers) |
| 291 | + probs = _compute_adjusted_probs(sv, sev, present) |
| 292 | + sid_next = _select_next_symptom(probs, set(answers.keys())) |
| 293 | + return AdaptiveStartResponse(session_id=session_id, next_question=_build_next_question(sid_next)) |
| 294 | + |
| 295 | + |
| 296 | +@app.post("/api/v2/adaptive/answer") |
| 297 | +def adaptive_answer(req: AdaptiveAnswerRequest, x_api_key: str | None = Header(default=None), claims: dict = Depends(verify_bearer)): |
| 298 | + if os.environ.get("MDM_AUTH_MODE", "api_key").lower() != "oidc": |
| 299 | + _auth_check(x_api_key) |
| 300 | + sess = _ADAPTIVE_SESSIONS.get(req.session_id) |
| 301 | + if not sess: |
| 302 | + raise HTTPException(status_code=404, detail="Session not found") |
| 303 | + sid = _symptom_id_from_key(req.question) |
| 304 | + if sid is None: |
| 305 | + raise HTTPException(status_code=400, detail="Invalid question") |
| 306 | + ans = req.answer.lower() |
| 307 | + if ans not in {"yes", "no", "unknown"}: |
| 308 | + raise HTTPException(status_code=400, detail="Invalid answer") |
| 309 | + sess["answers"][sid] = {"answer": ans, "severity": req.severity} |
| 310 | + sess["num_q"] = int(sess.get("num_q", 0)) + 1 |
| 311 | + # Recompute |
| 312 | + sv, sev, present = _answers_to_vectors(sess["answers"]) |
| 313 | + probs = _compute_adjusted_probs(sv, sev, present) |
| 314 | + if _session_should_stop(probs, sess["num_q"], sess["threshold"], sess["max_q"]): |
| 315 | + # Build diagnosis using current answers (convert to name: severity 0-10) |
| 316 | + symptom_dict = {} |
| 317 | + for sid_k, info in sess["answers"].items(): |
| 318 | + if info.get("answer") == "yes": |
| 319 | + name = SYMPTOMS.get(sid_k, {}).get("name") |
| 320 | + if name: |
| 321 | + val = info.get("severity") |
| 322 | + symptom_dict[name] = float(val) if val is not None else 6.0 |
| 323 | + results = model.diagnose_with_reasoning(symptom_dict) |
| 324 | + return AdaptiveAnswerResponse(session_id=req.session_id, finished=True, next_question=None, results=results) |
| 325 | + # Else ask next |
| 326 | + sid_next = _select_next_symptom(probs, set(sess["answers"].keys())) |
| 327 | + return AdaptiveAnswerResponse(session_id=req.session_id, finished=False, next_question=_build_next_question(sid_next), results=None) |
| 328 | + |
| 329 | + |
| 330 | +class AdaptiveFinishRequest(BaseModel): |
| 331 | + session_id: str |
| 332 | + |
| 333 | + |
| 334 | +@app.post("/api/v2/adaptive/finish") |
| 335 | +def adaptive_finish(req: AdaptiveFinishRequest, x_api_key: str | None = Header(default=None), claims: dict = Depends(verify_bearer)): |
| 336 | + if os.environ.get("MDM_AUTH_MODE", "api_key").lower() != "oidc": |
| 337 | + _auth_check(x_api_key) |
| 338 | + sess = _ADAPTIVE_SESSIONS.pop(req.session_id, None) |
| 339 | + if not sess: |
| 340 | + raise HTTPException(status_code=404, detail="Session not found") |
| 341 | + symptom_dict = {} |
| 342 | + for sid_k, info in sess["answers"].items(): |
| 343 | + if info.get("answer") == "yes": |
| 344 | + name = SYMPTOMS.get(sid_k, {}).get("name") |
| 345 | + if name: |
| 346 | + val = info.get("severity") |
| 347 | + symptom_dict[name] = float(val) if val is not None else 6.0 |
| 348 | + results = model.diagnose_with_reasoning(symptom_dict) |
| 349 | + return {"session_id": req.session_id, "results": results} |
| 350 | + |
0 commit comments