-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun.py
More file actions
114 lines (93 loc) · 3.18 KB
/
run.py
File metadata and controls
114 lines (93 loc) · 3.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import json
import csv
import argparse
import requests
from dotenv import load_dotenv
from db import get_db_conn, fetch_notes_by_query, fetch_all_notes
load_dotenv("config.env")
MODEL_SERVE_URL = "http://" + os.getenv("MODEL_SERVE_URL", "127.0.0.1:8000")
MODEL_SERVE_TIMEOUT = int(os.getenv("MODEL_SERVE_TIMEOUT", 60))
DEFAULT_BATCH_SIZE = int(os.getenv("BATCH_SIZE", 100))
def batched(iterable, batch_size):
batch = []
for item in iterable:
batch.append(item)
if len(batch) >= batch_size:
yield batch
batch = []
if batch:
yield batch
def run_model(
model: str,
query: str | None,
batch_size: int,
output_csv: str,
):
model_url = f"{MODEL_SERVE_URL.rstrip('/')}/process_bulk"
def modelserve_custom(texts):
headers = {
"accept": "application/json",
"Content-Type": "application/json"
}
response = requests.post(
model_url,
headers=headers,
json=texts
)
response.raise_for_status()
return response.json()
results = []
processed = 0
with get_db_conn() as conn:
notes = fetch_notes_by_query(conn, query) if query else fetch_all_notes(conn)
notes = list(notes)
total = len(notes)
print(f"Running model '{model}' on {total} notes...")
for batch in batched(notes, batch_size):
texts = [n.get("note_text", "") or "" for n in batch]
note_ids = [n.get("note_id") for n in batch]
try:
outputs = modelserve_custom(texts)
for note_id, output in zip(note_ids, outputs):
results.append({
"note_id": note_id,
"modelserve_output": json.dumps(output),
"error": ""
})
except Exception as e:
for note_id in note_ids:
results.append({
"note_id": note_id,
"modelserve_output": "",
"error": str(e)
})
processed += len(batch)
print(f"Processed {processed}/{total} notes")
os.makedirs(os.path.dirname(output_csv), exist_ok=True)
with open(output_csv, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(
f,
fieldnames=["note_id", "modelserve_output", "error"]
)
writer.writeheader()
for r in results:
writer.writerow(r)
print(f"Completed successfully. Results written to {output_csv}")
def main():
parser = argparse.ArgumentParser(
description="Run NLP model over OMOP notes and write results to CSV"
)
parser.add_argument("--model", required=True)
parser.add_argument("--query", required=True)
parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE)
parser.add_argument("--output_csv", required=True)
args = parser.parse_args()
run_model(
model=args.model,
query=args.query,
batch_size=args.batch_size,
output_csv=args.output_csv
)
if __name__ == "__main__":
main()