Skip to content

Commit 7278986

Browse files
Nick Vaccarello authored and committed
feat(data): add v0.2 generator with explicit negatives and balance; feat(train): train_from_jsonl and pipeline script; backend: prefer v02 model by default if present
1 parent 160430a commit 7278986

File tree

4 files changed

+241
-4
lines changed

4 files changed

+241
-4
lines changed

medical_diagnosis_model/backend/app.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@
3434
allow_headers=["*"],
3535
)
3636
model = ClinicalReasoningNetwork(hidden_neurons=25, learning_rate=0.3, epochs=1000)
37-
MODEL_PATH = os.path.join(MODEL_ROOT, "models", "enhanced_medical_model.json")
37+
# Prefer v0.2 model if present; allow env override
38+
DEFAULT_MODEL = os.path.join(MODEL_ROOT, "models", "enhanced_medical_model.json")
39+
V02_MODEL = os.path.join(MODEL_ROOT, "models", "enhanced_medical_model_v02.json")
40+
MODEL_PATH = os.environ.get("MDM_MODEL_PATH") or (V02_MODEL if os.path.exists(V02_MODEL) else DEFAULT_MODEL)
3841
exporter = PDFExporter(export_dir=os.path.join(MODEL_ROOT, "exports"))
3942
_RATE_LIMIT_STORE: dict[str, dict[str, float | int]] = {}
4043
_ADAPTIVE_SESSIONS: Dict[str, Dict] = {}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Generate v0.2 balanced JSONL training data with explicit negative evidence.
4+
5+
Format per line:
6+
{
7+
"symptoms": {"Cough": 6, "Runny Nose": 5, ...},
8+
"label_name": "Viral Upper Respiratory Infection"
9+
}
10+
"""
11+
from __future__ import annotations
12+
13+
import json
14+
import os
15+
import random
16+
from pathlib import Path
17+
from typing import Dict, Tuple, List
18+
19+
20+
def _load_schema():
    """Make the project packages importable, then return both schema tables.

    The repository root and the model root (the parent of this file's
    directory) are appended to ``sys.path`` when they are not already
    listed.  The schema modules are imported only after that, which is why
    the imports live inside the function.

    Returns:
        The tuple ``(DISEASES_V2, SYMPTOMS)``.
    """
    import sys

    script_dir = Path(__file__).resolve().parent
    package_dir = script_dir.parent
    # Same order as before: repository root first, then the model root.
    for candidate in map(str, (package_dir.parent, package_dir)):
        if candidate in sys.path:
            continue
        sys.path.append(candidate)

    from versions.v2.medical_disease_schema_v2 import DISEASES_V2
    from medical_symptom_schema import SYMPTOMS

    return DISEASES_V2, SYMPTOMS
32+
33+
34+
def _sample_case(disease_id: int, diseases: dict, symptoms: dict, explicit_neg: bool) -> Tuple[Dict[str, float], str]:
35+
dis = diseases[disease_id]
36+
name = dis["name"]
37+
pats = dis.get("symptom_patterns", {})
38+
out: Dict[str, float] = {}
39+
# Positive sampling from patterns
40+
for sid, pat in pats.items():
41+
if sid not in symptoms:
42+
continue
43+
freq = pat.get("frequency", 0.0)
44+
sev_lo, sev_hi = pat.get("severity_range", (0.2, 0.6))
45+
if random.random() < freq:
46+
sev = random.uniform(sev_lo, sev_hi)
47+
out[symptoms[sid]["name"]] = round(min(max(sev * 10.0, 0.0), 10.0), 1)
48+
49+
# Mild/early tweak: 30% chance reduce severities
50+
if random.random() < 0.3:
51+
for k in list(out.keys()):
52+
out[k] = round(out[k] * random.uniform(0.5, 0.8), 1)
53+
54+
# Explicit negatives across syndromes
55+
if explicit_neg:
56+
# For respiratory: ensure GU keys absent; for GU: reduce respiratory signals
57+
if name in ("Viral Upper Respiratory Infection", "Influenza-like Illness", "COVID-19-like Illness", "Viral Syndrome", "Pneumonia Syndrome"):
58+
for sid in (26, 27): # Frequency, Dysuria
59+
out.setdefault(symptoms[sid]["name"], 0.0)
60+
if name == "Urinary Tract Infection":
61+
for sid in (3, 7, 8): # Cough, Rhinorrhea, Congestion
62+
out.setdefault(symptoms[sid]["name"], 0.0)
63+
64+
return out, name
65+
66+
67+
def generate_balanced(per_disease: int = 200, seed: int = 42) -> List[Dict]:
    """Build a shuffled dataset with ``per_disease`` cases per target disease.

    The module-level ``random`` generator is seeded with ``seed`` before any
    sampling.  Target diseases are the four respiratory labels plus the
    urinary tract infection label listed below; every case is sampled with
    explicit negatives enabled.

    Returns:
        A list of ``{"symptoms": ..., "label_name": ...}`` dicts in shuffled
        order.
    """
    random.seed(seed)
    diseases, symptoms = _load_schema()

    # Focus set: common respiratory + GU UTI
    wanted = {
        "Viral Upper Respiratory Infection",
        "Influenza-like Illness",
        "COVID-19-like Illness",
        "Viral Syndrome",
        "Urinary Tract Infection",
    }
    selected_ids = [did for did, entry in diseases.items() if entry["name"] in wanted]

    # Lazily draw cases disease by disease, `per_disease` draws each.
    draws = (
        _sample_case(did, diseases, symptoms, explicit_neg=True)
        for did in selected_ids
        for _ in range(per_disease)
    )
    cases: List[Dict] = [
        {"symptoms": sampled, "label_name": label} for sampled, label in draws
    ]
    random.shuffle(cases)
    return cases
86+
87+
88+
def main() -> int:
    """Generate 150 cases per target disease and write them as JSONL.

    The output goes to ``data/v02/cases_v02.jsonl`` under the parent of this
    file's directory; the directory is created when missing.

    Returns:
        ``0``, used as the process exit status.
    """
    target_dir = Path(__file__).resolve().parents[1] / "data" / "v02"
    target_dir.mkdir(parents=True, exist_ok=True)
    path = target_dir / "cases_v02.jsonl"

    data = generate_balanced(per_disease=150)
    with path.open("w", encoding="utf-8") as handle:
        # One JSON object per line.
        handle.writelines(json.dumps(row) + "\n" for row in data)

    print(f"Wrote {len(data)} cases to {path}")
    return 0
99+
100+
101+
if __name__ == "__main__":
102+
raise SystemExit(main())
103+
104+
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#!/usr/bin/env python3
2+
"""
3+
One-shot training pipeline:
4+
- Generate v0.2 balanced dataset (explicit negatives)
5+
- Train v2 from JSONL
6+
- Calibrate and save model to models/enhanced_medical_model_v02.json
7+
- (Planned; not implemented in this script) Optionally run a quick confusion summary on held-out set (counts only)
8+
"""
9+
from __future__ import annotations
10+
11+
import argparse
12+
import json
13+
from pathlib import Path
14+
15+
16+
def _setup_paths() -> None:
    """Append the repository root and the model root to ``sys.path``.

    The model root is the parent of this file's directory and the repository
    root is its parent.  Each path is appended only when it is not already
    present, so calling this function repeatedly is harmless.
    """
    import sys

    here = Path(__file__).resolve().parent
    model_root = here.parent
    repo_root = model_root.parent
    for p in (str(repo_root), str(model_root)):
        if p not in sys.path:
            sys.path.append(p)
24+
25+
26+
def generate_dataset(per_disease: int) -> Path:
    """Generate the v0.2 dataset and write it to ``data/v02/cases_v02.jsonl``.

    Args:
        per_disease: Number of cases requested from ``generate_balanced`` for
            each target disease.

    Returns:
        Path of the JSONL file that was written.
    """
    from data.generate_v02 import generate_balanced

    jsonl_path = Path(__file__).resolve().parents[1] / "data" / "v02" / "cases_v02.jsonl"
    jsonl_path.parent.mkdir(parents=True, exist_ok=True)

    cases = generate_balanced(per_disease=per_disease)
    with jsonl_path.open("w", encoding="utf-8") as handle:
        # One JSON object per line.
        handle.writelines(f"{json.dumps(case)}\n" for case in cases)
    return jsonl_path
37+
38+
39+
def train_model(jsonl_path: Path, epochs: int) -> Path:
    """Train a ``ClinicalReasoningNetwork`` from a JSONL file and save it.

    Args:
        jsonl_path: Dataset file handed to ``train_from_jsonl``.
        epochs: Epoch count passed to the network constructor.

    Returns:
        Path of the saved model,
        ``models/enhanced_medical_model_v02.json`` under the parent of this
        file's directory.
    """
    from versions.v2.medical_neural_network_v2 import ClinicalReasoningNetwork

    destination = Path(__file__).resolve().parents[1] / "models" / "enhanced_medical_model_v02.json"
    network = ClinicalReasoningNetwork(hidden_neurons=25, learning_rate=0.3, epochs=epochs)
    network.train_from_jsonl(str(jsonl_path), verbose=False)
    network.save_model(str(destination))
    return destination
46+
47+
48+
def main() -> int:
    """Run the pipeline: set up import paths, generate data, train, save.

    Command-line options:
        --per-disease  cases generated per disease (default 200)
        --epochs       training epochs (default 5000)

    Returns:
        ``0``, used as the process exit status.
    """
    _setup_paths()

    parser = argparse.ArgumentParser(description="Train v0.2 model from generated data")
    for flag, default in (("--per-disease", 200), ("--epochs", 5000)):
        parser.add_argument(flag, type=int, default=default)
    options = parser.parse_args()

    dataset_path = generate_dataset(options.per_disease)
    print(f"Generated dataset: {dataset_path}")

    saved_to = train_model(dataset_path, options.epochs)
    print(f"Saved model: {saved_to}")
    print("Set MDM_MODEL_PATH to use this model in the API if not picked by default.")
    return 0
61+
62+
63+
if __name__ == "__main__":
64+
raise SystemExit(main())
65+
66+

medical_diagnosis_model/versions/v2/medical_neural_network_v2.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,15 @@ def diagnose_with_reasoning(self, symptoms_dict, has_test_results=None):
162162
for symptom_name, severity in symptoms_dict.items():
163163
sid, symptom = get_symptom_by_name(symptom_name)
164164
if sid is not None and sid < self.num_symptoms:
165-
symptom_vector[sid] = 1
166-
severity_vector[sid] = severity / 10.0
167-
symptom_ids.append(sid)
165+
try:
166+
sev_norm = float(severity) / 10.0
167+
except Exception:
168+
sev_norm = 0.0
169+
# Treat zero (or negative) severity as absent
170+
if sev_norm > 0.0:
171+
symptom_vector[sid] = 1
172+
severity_vector[sid] = sev_norm
173+
symptom_ids.append(sid)
168174

169175
# Determine syndrome
170176
syndrome = get_syndrome_from_symptoms(symptom_ids)
@@ -462,6 +468,61 @@ def load_model(self, filename="models/enhanced_medical_model.json"):
462468
self.network.append(rebuilt)
463469
print(f"Model loaded from {filename}")
464470

471+
# ===== Training from JSONL (v0.2) =====
472+
def train_from_jsonl(self, jsonl_path: str, seed: int = 42, verbose: bool = True):
    """Train the network from a JSONL file of labelled symptom cases.

    Each non-blank line must be a JSON object with a ``"symptoms"`` mapping
    (symptom name -> severity on a 0-10 scale) and a ``"label_name"``
    matching a disease name in ``DISEASES_V2``.

    Args:
        jsonl_path: Path of the JSONL file to read (UTF-8).
        seed: Seed applied to the module-level ``random`` generator before
            the dataset is shuffled.
        verbose: Forwarded to the training routine; also prints the selected
            calibration temperature.

    Returns:
        Whatever ``_train_softmax_cross_entropy`` returns (the training
        history).

    Side effects:
        Replaces ``self.network`` with a freshly initialised network and
        sets ``self.temperature`` from the validation split.

    Rows are skipped when their label is missing, is not a string, or is not
    a known disease name.  Symptoms that are unknown, out of range, or have
    a severity that is non-numeric or not greater than zero are treated as
    absent.
    """
    import json
    import random

    random.seed(seed)

    # Load data: one JSON object per non-blank line.
    rows = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(json.loads(line))

    # Resolve label names once instead of rescanning DISEASES_V2 per row.
    # setdefault keeps the first id for a repeated name, matching the
    # first-match behaviour of a linear scan.
    label_ids = {}
    for did, d in DISEASES_V2.items():
        disease_name = d.get("name")
        if isinstance(disease_name, str):
            label_ids.setdefault(disease_name, did)

    # Build vectors: [presence flags] + [normalised severities] + [label id]
    dataset = []
    for obj in rows:
        label_name = obj.get("label_name")
        label_id = label_ids.get(label_name) if isinstance(label_name, str) else None
        if label_id is None:
            continue
        # `"symptoms": null` is treated like a missing mapping.
        sym = obj.get("symptoms") or {}
        symptom_vector = [0] * self.num_symptoms
        severity_vector = [0.0] * self.num_symptoms
        for name, sev in sym.items():
            sid, _ = get_symptom_by_name(name)
            if sid is None or sid >= self.num_symptoms:
                continue
            try:
                sevn = float(sev) / 10.0
            except (TypeError, ValueError, OverflowError):
                # Non-numeric severity: treat the symptom as absent.
                sevn = 0.0
            # Zero or negative severity is explicit negative evidence and is
            # encoded the same way as an absent symptom.
            if sevn > 0.0:
                symptom_vector[sid] = 1
                severity_vector[sid] = sevn
        dataset.append(symptom_vector + severity_vector + [label_id])

    # Shuffle and split 80/20 into training and validation sets.
    random.shuffle(dataset)
    split = int(0.8 * len(dataset))
    train_set = dataset[:split]
    val_set = dataset[split:]

    # Init and train
    self.network = initialize_network(self.num_features, self.hidden_neurons, self.num_diseases)
    history = self._train_softmax_cross_entropy(self.network, train_set, val_set, verbose=verbose)
    self.temperature = self._calibrate_temperature(val_set)
    if verbose:
        print(f"Calibration: selected T={self.temperature:.2f}")
    return history
525+
465526
def _apply_clinical_rules(self, nn_outputs, symptom_ids, severity_vector, has_test_results):
466527
"""Apply clinical decision rules to adjust probabilities"""
467528
adjusted = nn_outputs.copy()
@@ -572,6 +633,9 @@ def _idx(name: str):
572633
logits[uri_idx] += 1.5
573634
if fever < 0.6 and myalgia < 0.6:
574635
logits[uri_idx] += 0.5
636+
# Negative evidence: GU keys absent with strong URI pattern present
637+
if uti_idx is not None and (rhinorrhea > 0.3 or congestion > 0.3) and (26 not in symptom_ids) and (27 not in symptom_ids):
638+
logits[uti_idx] -= 6.0
575639

576640
# ILI: high fever + myalgia (+/- severe fatigue)
577641
if ili_idx is not None:

0 commit comments

Comments
 (0)