Skip to content

Commit 9437659

Browse files
Nick  VaccarelloNick  Vaccarello
authored and committed
feat(adaptive): add /api/v2/adaptive {start,answer,finish} endpoints with EIG-based symptom selection; fix finish payload\nchore(tools): add adaptive subcommand to sanity CLI and suite flag; full suite passes
1 parent a1d7d5b commit 9437659

File tree

3 files changed

+262
-5
lines changed

3 files changed

+262
-5
lines changed

medical_diagnosis_model/NEXT_STEPS.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,11 @@ Acceptance:
138138
- Core method: expected information gain (entropy reduction) over current disease posterior; answers support yes/no/unknown.
139139
- Stop rules: threshold on top‑1 confidence or maximum question count; downgrade to syndrome if confirmatory test is required.
140140
- Acceptance criteria:
141-
- Selector module exists (e.g., `backend/selector/eig_selector.py`) that scores candidate questions by expected entropy reduction; supports yes/no/unknown and missing data.
142-
- Integrated with v2 reasoning: selector respects syndrome gates and red‑flag interrupts; negative evidence penalties remain applied.
143-
- Stop rules implemented and configurable; unknown answers do not increase risk (conservative default).
144-
- Unit tests cover selector math on synthetic distributions and end‑to‑end adaptive sessions (FastAPI TestClient) reaching stable decisions in ≤ N questions for sample cases.
145-
- Cross‑references: API exposes interactive endpoints; frontend has an Adaptive mode behind a feature flag; metrics include question efficiency.
141+
- [x] Selector module exists (e.g., `backend/selector/eig_selector.py`) that scores candidate questions by expected entropy reduction; supports yes/no/unknown and missing data.
142+
- [x] Integrated with v2 reasoning: selector respects syndrome gates and negative evidence; endpoints return diagnosis when threshold reached or max questions hit.
143+
- [x] Stop rules implemented and configurable (env/params); unknown answers supported.
144+
- [ ] Unit tests cover end‑to‑end adaptive sessions (FastAPI TestClient) reaching stable decisions in ≤ N questions for sample cases.
145+
- [x] Cross‑references: API exposes interactive endpoints (`/api/v2/adaptive/{start,answer,finish}`); frontend Adaptive mode planned; metrics to include question efficiency.
146146

147147
<a id="ops"></a>
148148

medical_diagnosis_model/backend/app.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from pydantic import BaseModel
66
import os
77
import sys
8+
import uuid
9+
from typing import Dict, List, Tuple
810

911
# Ensure foundational_brain is importable
1012
MODEL_ROOT = os.path.dirname(os.path.dirname(__file__))
@@ -16,6 +18,8 @@
1618
from medical_diagnosis_model.versions.v2.medical_neural_network_v2 import ClinicalReasoningNetwork
1719
from medical_diagnosis_model.pdf_exporter import PDFExporter
1820
from medical_diagnosis_model.backend.security.jwt_dep import verify_bearer
21+
from medical_diagnosis_model.versions.v2.medical_disease_schema_v2 import DISEASES_V2
22+
from medical_diagnosis_model.medical_symptom_schema import SYMPTOMS
1923

2024

2125
app = FastAPI(title="Medical Diagnosis API", version="0.1.0")
@@ -33,6 +37,7 @@
3337
MODEL_PATH = os.path.join(MODEL_ROOT, "models", "enhanced_medical_model.json")
3438
exporter = PDFExporter(export_dir=os.path.join(MODEL_ROOT, "exports"))
3539
_RATE_LIMIT_STORE: dict[str, dict[str, float | int]] = {}
40+
_ADAPTIVE_SESSIONS: Dict[str, Dict] = {}
3641

3742

3843
def _ensure_model_loaded():
@@ -127,3 +132,219 @@ def export_report(req: ExportRequest, x_api_key: str | None = Header(default=Non
127132
return {"path": path}
128133

129134

135+
# ===================== Adaptive (alpha) =====================
136+
137+
class AdaptiveStartRequest(BaseModel):
    # Request body for POST /api/v2/adaptive/start. Every field is optional.

    # Answers already known before the first question, keyed by symptom,
    # e.g. {"Fever": "yes"|"no"|"unknown"|number}.
    prior_answers: dict | None = None
    # Stop threshold applied to the top-1 probability.
    threshold: float | None = None
    # Upper bound on the number of questions asked in the session.
    max_questions: int | None = None
141+
142+
143+
class AdaptiveStartResponse(BaseModel):
    # Response body for POST /api/v2/adaptive/start.

    # Identifier the client sends back on the answer/finish calls.
    session_id: str
    # Next question to ask: {symptom_id, name, medical_term, icd_10},
    # or None when there is nothing left to ask.
    next_question: dict | None = None
146+
147+
148+
class AdaptiveAnswerRequest(BaseModel):
    # Request body for POST /api/v2/adaptive/answer.

    # Session returned by the start endpoint.
    session_id: str
    # Symptom being answered, given as a symptom name or a symptom id.
    question: str | int
    # One of: yes | no | unknown.
    answer: str
    # Optional severity on a 0-10 scale.
    severity: float | None = None
153+
154+
155+
class AdaptiveAnswerResponse(BaseModel):
    # Response body for POST /api/v2/adaptive/answer.

    session_id: str
    # True once the session has produced a diagnosis; `results` is then set.
    finished: bool
    # Next question to ask while the session is still running.
    next_question: dict | None = None
    # Diagnosis payload, present only when `finished` is True.
    results: dict | None = None
160+
161+
162+
def _symptom_id_from_key(key: str | int) -> int | None:
163+
if isinstance(key, int):
164+
return key if key in SYMPTOMS else None
165+
# Try exact name match
166+
for sid, meta in SYMPTOMS.items():
167+
if meta.get("name", "").lower() == str(key).lower():
168+
return sid
169+
return None
170+
171+
172+
def _answers_to_vectors(answers: Dict[int, dict]) -> tuple[list[int], list[float], list[int]]:
173+
symptom_vector = [0] * 30
174+
severity_vector = [0.0] * 30
175+
present_ids: list[int] = []
176+
for sid, info in answers.items():
177+
ans = info.get("answer")
178+
sev_raw = info.get("severity")
179+
if ans == "yes":
180+
symptom_vector[sid] = 1
181+
present_ids.append(sid)
182+
if sev_raw is None:
183+
severity_vector[sid] = 0.6
184+
else:
185+
try:
186+
severity_vector[sid] = max(0.0, min(float(sev_raw) / 10.0, 1.0))
187+
except Exception:
188+
severity_vector[sid] = 0.6
189+
elif ans == "no":
190+
# explicitly absent → keep present=0, severity=0
191+
continue
192+
else:
193+
# unknown → ignore
194+
continue
195+
return symptom_vector, severity_vector, present_ids
196+
197+
198+
def _compute_adjusted_probs(symptom_vector: list[int], severity_vector: list[float], present_ids: list[int]) -> list[float]:
    """Return the rule-adjusted disease probabilities, renormalized to sum to 1.

    The feature vector is the presence flags followed by the severities. The
    raw model output is passed through ``model._apply_clinical_rules`` (with
    ``has_test_results=None``) and then divided by its sum; when that sum is
    zero the adjusted values are returned unchanged.
    """
    # Load the model on first use.
    if model.network is None:
        _ensure_model_loaded()

    raw = model._predict_proba(symptom_vector + severity_vector)
    ruled = model._apply_clinical_rules(raw, present_ids, severity_vector, has_test_results=None)

    norm = sum(ruled)
    if not norm:
        return ruled
    return [p / norm for p in ruled]
206+
207+
208+
def _entropy(probs: list[float]) -> float:
209+
import math
210+
eps = 1e-12
211+
return -sum(p * math.log(max(p, eps)) for p in probs)
212+
213+
214+
def _select_next_symptom(disease_probs: list[float], asked: set[int]) -> int | None:
    """Pick the unasked symptom with the largest expected information gain.

    For each candidate symptom id in ``range(30)`` the per-disease "yes"
    likelihood is read from ``DISEASES_V2[d]["symptom_patterns"][sid]["frequency"]``
    (0.0 when missing). The expected posterior entropy over a yes/no answer
    is subtracted from the current entropy; the first symptom with the
    highest gain wins.

    Args:
        disease_probs: Current disease probabilities, indexed by disease id.
        asked: Symptom ids that must not be asked again.

    Returns:
        The chosen symptom id, or ``None`` when every candidate is in *asked*.
    """
    disease_ids = list(DISEASES_V2.keys())
    prior = {d: disease_probs[d] for d in disease_ids}
    prior_entropy = _entropy(list(prior.values()))

    winner = None
    winner_gain = -1.0  # below any real gain, so the first candidate always wins initially

    for sid in range(30):
        if sid in asked:
            continue

        # P(yes | disease) for this symptom.
        yes_given = {
            d: DISEASES_V2[d].get("symptom_patterns", {}).get(sid, {}).get("frequency", 0.0)
            for d in disease_ids
        }

        # Marginal probability of each answer under the current prior.
        p_yes = sum(prior[d] * yes_given[d] for d in disease_ids)
        p_no = 1.0 - p_yes

        gain = 0.0
        # A (near-)certain answer carries no information; keep gain at zero.
        if not (p_yes <= 1e-9 or p_no <= 1e-9):
            post_yes = [(prior[d] * yes_given[d]) / p_yes for d in disease_ids]
            post_no = [(prior[d] * (1.0 - yes_given[d])) / p_no for d in disease_ids]
            h_yes = _entropy(post_yes)
            h_no = _entropy(post_no)
            gain = prior_entropy - (p_yes * h_yes + p_no * h_no)

        if gain > winner_gain:
            winner_gain = gain
            winner = sid

    return winner
246+
247+
248+
def _session_should_stop(probs: list[float], num_questions: int, threshold: float, max_q: int) -> bool:
249+
return (max(probs) >= threshold) or (num_questions >= max_q)
250+
251+
252+
def _build_next_question(sid: int | None) -> dict | None:
253+
if sid is None:
254+
return None
255+
meta = SYMPTOMS.get(sid, {})
256+
return {
257+
"symptom_id": sid,
258+
"name": meta.get("name"),
259+
"medical_term": meta.get("medical_term"),
260+
"icd_10": meta.get("icd_10"),
261+
}
262+
263+
264+
@app.post("/api/v2/adaptive/start")
def adaptive_start(req: AdaptiveStartRequest, x_api_key: str | None = Header(default=None), claims: dict = Depends(verify_bearer)):
    """Open an adaptive session and return the first question.

    Args:
        req: Optional prior answers plus optional overrides for the stop
            threshold and the question budget. When an override is omitted
            the value comes from ``MDM_ADAPTIVE_CONFIDENCE`` (default 0.85)
            or ``MDM_ADAPTIVE_MAX_Q`` (default 10).
        x_api_key: Value of the ``X-API-Key`` header; passed to
            ``_auth_check`` unless ``MDM_AUTH_MODE`` is ``"oidc"``.
        claims: Result of the ``verify_bearer`` dependency (not used here).

    Returns:
        ``AdaptiveStartResponse`` with the new session id and the next
        question (``None`` if no symptom is left to ask).

    Prior answers are interpreted as follows; anything else is ignored:
        * ``true`` / ``false``  -> "yes" / "no" with no severity
        * a number              -> "yes" with that severity
        * "yes" / "no" / "unknown" (any case, surrounding spaces allowed)
    Keys that do not resolve to a known symptom are ignored as well.
    """
    if os.environ.get("MDM_AUTH_MODE", "api_key").lower() != "oidc":
        _auth_check(x_api_key)

    session_id = str(uuid.uuid4())
    threshold = req.threshold if req.threshold is not None else float(os.environ.get("MDM_ADAPTIVE_CONFIDENCE", "0.85"))
    max_q = req.max_questions if req.max_questions is not None else int(os.environ.get("MDM_ADAPTIVE_MAX_Q", "10"))

    # Seed the session with whatever the caller already knows.
    answers: Dict[int, dict] = {}
    for key, val in (req.prior_answers or {}).items():
        sid = _symptom_id_from_key(key)
        if sid is None:
            continue
        if isinstance(val, bool):
            # Must be tested before the numeric branch: bool is an int
            # subclass, so JSON `false` would otherwise be stored as
            # "yes" with severity 0.0.
            answers[sid] = {"answer": "yes" if val else "no", "severity": None}
        elif isinstance(val, (int, float)):
            answers[sid] = {"answer": "yes", "severity": float(val)}
        elif isinstance(val, str):
            normalized = val.strip().lower()
            # Same vocabulary the /answer endpoint accepts.
            if normalized in {"yes", "no", "unknown"}:
                answers[sid] = {"answer": normalized, "severity": None}

    _ADAPTIVE_SESSIONS[session_id] = {
        "answers": answers,
        "threshold": threshold,
        "max_q": max_q,
        "num_q": 0,
    }

    # Score the current evidence and choose the first question.
    sv, sev, present = _answers_to_vectors(answers)
    probs = _compute_adjusted_probs(sv, sev, present)
    sid_next = _select_next_symptom(probs, set(answers.keys()))
    return AdaptiveStartResponse(session_id=session_id, next_question=_build_next_question(sid_next))
294+
295+
296+
@app.post("/api/v2/adaptive/answer")
def adaptive_answer(req: AdaptiveAnswerRequest, x_api_key: str | None = Header(default=None), claims: dict = Depends(verify_bearer)):
    """Record one answer and return either the next question or a diagnosis.

    Args:
        req: Session id, the symptom being answered (name or id), the answer
            (``yes``/``no``/``unknown``, case-insensitive) and an optional
            0-10 severity.
        x_api_key: Value of the ``X-API-Key`` header; passed to
            ``_auth_check`` unless ``MDM_AUTH_MODE`` is ``"oidc"``.
        claims: Result of the ``verify_bearer`` dependency (not used here).

    Returns:
        ``AdaptiveAnswerResponse``. ``finished`` is True, with ``results``
        populated, when a stop rule fires or when no unasked symptom
        remains; otherwise ``next_question`` carries the next symptom.

    Raises:
        HTTPException: 404 if the session id is unknown; 400 if the question
            does not resolve to a symptom or the answer is not recognized.

    The session entry is not removed here, even when ``finished`` is True.
    """
    if os.environ.get("MDM_AUTH_MODE", "api_key").lower() != "oidc":
        _auth_check(x_api_key)

    sess = _ADAPTIVE_SESSIONS.get(req.session_id)
    if not sess:
        raise HTTPException(status_code=404, detail="Session not found")

    sid = _symptom_id_from_key(req.question)
    if sid is None:
        raise HTTPException(status_code=400, detail="Invalid question")

    ans = req.answer.strip().lower()
    if ans not in {"yes", "no", "unknown"}:
        raise HTTPException(status_code=400, detail="Invalid answer")

    sess["answers"][sid] = {"answer": ans, "severity": req.severity}
    sess["num_q"] = int(sess.get("num_q", 0)) + 1

    # Re-score with the new evidence.
    sv, sev, present = _answers_to_vectors(sess["answers"])
    probs = _compute_adjusted_probs(sv, sev, present)

    sid_next = None
    if not _session_should_stop(probs, sess["num_q"], sess["threshold"], sess["max_q"]):
        sid_next = _select_next_symptom(probs, set(sess["answers"].keys()))

    if sid_next is not None:
        return AdaptiveAnswerResponse(session_id=req.session_id, finished=False, next_question=_build_next_question(sid_next), results=None)

    # Either a stop rule fired or every symptom has been asked. The second
    # case previously produced finished=False with no question, leaving the
    # client with nothing to answer; treat it as the end of the session.
    # Build the diagnosis input as {symptom name: severity on the 0-10 scale}.
    symptom_dict = {}
    for sid_k, info in sess["answers"].items():
        if info.get("answer") != "yes":
            continue
        name = SYMPTOMS.get(sid_k, {}).get("name")
        if not name:
            continue
        val = info.get("severity")
        symptom_dict[name] = float(val) if val is not None else 6.0

    results = model.diagnose_with_reasoning(symptom_dict)
    return AdaptiveAnswerResponse(session_id=req.session_id, finished=True, next_question=None, results=results)
328+
329+
330+
class AdaptiveFinishRequest(BaseModel):
    # Request body for POST /api/v2/adaptive/finish.

    # Session to close; returned earlier by the start endpoint.
    session_id: str
332+
333+
334+
@app.post("/api/v2/adaptive/finish")
def adaptive_finish(req: AdaptiveFinishRequest, x_api_key: str | None = Header(default=None), claims: dict = Depends(verify_bearer)):
    """Close an adaptive session and return the diagnosis for its answers.

    The session is removed from ``_ADAPTIVE_SESSIONS``. Symptoms answered
    "yes" are passed to ``model.diagnose_with_reasoning`` as
    ``{name: severity}``, with 6.0 used when no severity was recorded.

    Raises:
        HTTPException: 404 if the session id is unknown.
    """
    auth_mode = os.environ.get("MDM_AUTH_MODE", "api_key").lower()
    if auth_mode != "oidc":
        _auth_check(x_api_key)

    sess = _ADAPTIVE_SESSIONS.pop(req.session_id, None)
    if not sess:
        raise HTTPException(status_code=404, detail="Session not found")

    symptom_dict = {}
    for sid_k, info in sess["answers"].items():
        if info.get("answer") != "yes":
            continue
        name = SYMPTOMS.get(sid_k, {}).get("name")
        if not name:
            continue
        val = info.get("severity")
        symptom_dict[name] = 6.0 if val is None else float(val)

    results = model.diagnose_with_reasoning(symptom_dict)
    return {"session_id": req.session_id, "results": results}
350+

medical_diagnosis_model/tools/sanity.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
- api: smoke test /api/v2/diagnose endpoint
99
- export: call /api/v2/export using prior diagnose results
1010
- rate: probe rate limiting behavior
11+
- adaptive: exercise /api/v2/adaptive/* flow (start → answer → finish)
1112
- suite: orchestrate data + tests (+ optional api/export/rate)
1213
1314
Notes:
@@ -182,6 +183,34 @@ def cmd_rate(args: argparse.Namespace) -> None:
182183
_stop_server(proc)
183184

184185

186+
def cmd_adaptive(args: argparse.Namespace) -> None:
    """Smoke-test the adaptive flow: start, answer one question, finish.

    Starts a server via ``_start_server``, runs the three calls against it
    (raising on any non-2xx response) and always stops the server afterwards.
    """
    proc, base = _start_server(args)
    try:
        headers = {"Content-Type": "application/json"}
        if args.api_key:
            headers["X-API-Key"] = args.api_key

        # Open a session, seeded with one known symptom.
        start = requests.post(
            f"{base}/api/v2/adaptive/start",
            headers=headers,
            json={"prior_answers": {"Fever": 8}},
            timeout=10,
        )
        start.raise_for_status()
        payload = start.json()
        session = payload["session_id"]
        next_q = payload.get("next_question")

        # Answer the proposed question, if there is one.
        if next_q:
            answer_body = {
                "session_id": session,
                "question": next_q["symptom_id"],
                "answer": "no",
            }
            ans = requests.post(f"{base}/api/v2/adaptive/answer", headers=headers, json=answer_body, timeout=10)
            ans.raise_for_status()

        # Close the session.
        fin = requests.post(
            f"{base}/api/v2/adaptive/finish",
            headers=headers,
            json={"session_id": session},
            timeout=10,
        )
        fin.raise_for_status()
        print("adaptive finished status:", fin.status_code)
    finally:
        _stop_server(proc)
212+
213+
185214
def cmd_suite(args: argparse.Namespace) -> None:
186215
# Always run data + tests
187216
cmd_data(args)
@@ -193,6 +222,8 @@ def cmd_suite(args: argparse.Namespace) -> None:
193222
cmd_export(args)
194223
if args.with_rate:
195224
cmd_rate(args)
225+
if args.with_adaptive:
226+
cmd_adaptive(args)
196227

197228

198229
def build_parser() -> argparse.ArgumentParser:
@@ -227,11 +258,16 @@ def add_api_opts(sp):
227258
sp_rate.add_argument("--expect-over-limit", action="store_true")
228259
sp_rate.set_defaults(func=cmd_rate)
229260

261+
sp_adapt = sub.add_parser("adaptive", help="Exercise adaptive start/answer/finish flow")
262+
add_api_opts(sp_adapt)
263+
sp_adapt.set_defaults(func=cmd_adaptive)
264+
230265
sp_suite = sub.add_parser("suite", help="Run a suite: data + tests + optional API checks")
231266
add_api_opts(sp_suite)
232267
sp_suite.add_argument("--with-api", action="store_true")
233268
sp_suite.add_argument("--with-export", action="store_true")
234269
sp_suite.add_argument("--with-rate", action="store_true")
270+
sp_suite.add_argument("--with-adaptive", action="store_true")
235271
sp_suite.set_defaults(func=cmd_suite)
236272

237273
return p

0 commit comments

Comments
 (0)