diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml new file mode 100644 index 0000000..e85121a --- /dev/null +++ b/.github/workflows/pytest.yml @@ -0,0 +1,25 @@ +name: Run pytest + +on: + push: + branches: [main, dev] + pull_request: + branches: [main, dev] + types: ["opened", "reopened", "synchronize", "ready_for_review", "draft"] + workflow_dispatch: + +jobs: + build: + name: Run pytest + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v5 + - name: Install uv + uses: astral-sh/setup-uv@v7 + + - name: Install dependencies + run: uv sync --locked --all-extras --dev + + - name: Run the tests + run: uv run pytest tests diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a991089..f202a2e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: - id: ruff args: - --fix - - waveform_controller + - ./ - id: ruff-format # Type-checking python code. - repo: https://github.com/pre-commit/mirrors-mypy @@ -23,7 +23,7 @@ repos: "types-psycopg2", "types-pika" ] - files: waveform_controller/ + files: src/ # ---------- # Formats docstrings to comply with PEP257 - repo: https://github.com/PyCQA/docformatter diff --git a/pyproject.toml b/pyproject.toml index fc9c7bb..24d837c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,5 +10,8 @@ dependencies = [ "psycopg2-binary>=2.9.11", ] +[project.optional-dependencies] +dev = ["pytest>=9.0.2"] + [project.scripts] -emap-extract-waveform = "waveform_controller.controller:receiver" +emap-extract-waveform = "controller:receiver" diff --git a/waveform_controller/__init__.py b/src/__init__.py similarity index 100% rename from waveform_controller/__init__.py rename to src/__init__.py diff --git a/src/controller.py b/src/controller.py new file mode 100644 index 0000000..3c362f5 --- /dev/null +++ b/src/controller.py @@ -0,0 +1,129 @@ +""" +A script to receive messages in the waveform queue and write them to stdout, +based on 
https://www.rabbitmq.com/tutorials/tutorial-one-python +""" + +import json +from datetime import datetime +import threading +import queue +import logging +import pika +import db as db # type:ignore +import settings as settings # type:ignore +import csv_writer as writer # type:ignore + +max_threads = 1 # this needs to stay at 1 as pika is not thread safe. +logging.basicConfig(format="%(levelname)s:%(asctime)s: %(message)s") +logger = logging.getLogger(__name__) + + +worker_queue: queue.Queue = queue.Queue(maxsize=max_threads) + + +class waveform_message: + def __init__(self, ch, delivery_tag, body): + self.ch = ch + self.delivery_tag = delivery_tag + self.body = body + + +def ack_message(ch, delivery_tag): + """Note that `ch` must be the same pika channel instance via which the + message being ACKed was retrieved (AMQP protocol constraint).""" + if ch.is_open: + ch.basic_ack(delivery_tag) + else: + logger.warning("Attempting to acknowledge a message on a closed channel.") + + +def reject_message(ch, delivery_tag, requeue): + if ch.is_open: + ch.basic_reject(delivery_tag, requeue) + else: + logger.warning("Attempting to not acknowledge a message on a closed channel.") + + +def waveform_callback(): + emap_db = db.starDB() + emap_db.init_query() + emap_db.connect() + while True: + message = worker_queue.get() + if message is not None: + data = json.loads(message.body) + try: + location_string = data["mappedLocationString"] + observation_time = data["observationTime"] + except KeyError as e: + reject_message(message.ch, message.delivery_tag, False) + logger.error( + f"Waveform message {message.delivery_tag} is missing required data {e}." 
+ ) + worker_queue.task_done() + continue + + observation_time = datetime.fromtimestamp(observation_time) + lookup_success = True + try: + matched_mrn = emap_db.get_row(location_string, observation_time) + except ValueError as e: + lookup_success = False + logger.error(f"Ambiguous or non existent match: {e}") + matched_mrn = ("unmatched_mrn", "unmatched_nhs", "unmatched_csn") + + if writer.write_frame(data, matched_mrn[2], matched_mrn[0]): + if lookup_success: + ack_message(message.ch, message.delivery_tag) + else: + reject_message(message.ch, message.delivery_tag, False) + + worker_queue.task_done() + else: + logger.warning("No message in queue.") + + +def on_message(ch, method_frame, _header_frame, body): + wf_message = waveform_message(ch, method_frame.delivery_tag, body) + if not worker_queue.full(): + worker_queue.put(wf_message) + else: + logger.warning("Working queue is full.") + reject_message(ch, method_frame.delivery_tag, True) + + +def receiver(): + # set up database connection + rabbitmq_credentials = pika.PlainCredentials( + username=settings.RABBITMQ_USERNAME, password=settings.RABBITMQ_PASSWORD + ) + connection_parameters = pika.ConnectionParameters( + credentials=rabbitmq_credentials, + host=settings.RABBITMQ_HOST, + port=settings.RABBITMQ_PORT, + ) + connection = pika.BlockingConnection(connection_parameters) + channel = connection.channel() + channel.basic_qos(prefetch_count=1) + + threads = [] + # I just want one thread, but in theory this should work for more + worker_thread = threading.Thread(target=waveform_callback) + worker_thread.start() + threads.append(worker_thread) + + channel.basic_consume( + queue=settings.RABBITMQ_QUEUE, + auto_ack=False, + on_message_callback=on_message, + ) + try: + channel.start_consuming() + except KeyboardInterrupt: + channel.stop_consuming() + + # Wait for all to complete + for thread in threads: + thread.join() + + connection.close() diff --git a/waveform_controller/csv_writer.py b/src/csv_writer.py similarity 
index 77% rename from waveform_controller/csv_writer.py rename to src/csv_writer.py index 3b8a9b1..d99a9e1 100644 --- a/waveform_controller/csv_writer.py +++ b/src/csv_writer.py @@ -6,12 +6,14 @@ def create_file_name( - sourceSystem: str, observationTime: datetime, csn: str, units: str + sourceStreamId: str, observationTime: datetime, csn: str, units: str ) -> str: """Create a unique file name based on the patient contact serial number (csn) the date, and the source system.""" datestring = observationTime.strftime("%Y-%m-%d") - return f"{datestring}.{csn}.{sourceSystem}.{units}.csv" + units = units.replace("/", "p") + units = units.replace("%", "percent") + return f"{datestring}.{csn}.{sourceStreamId}.{units}.csv" def write_frame(waveform_message: dict, csn: str, mrn: str) -> bool: @@ -20,7 +22,7 @@ def write_frame(waveform_message: dict, csn: str, mrn: str) -> bool: :return: True if write was successful. """ - sourceSystem = waveform_message.get("sourceSystem", None) + sourceStreamId = waveform_message.get("sourceStreamId", None) observationTime = waveform_message.get("observationTime", False) if not observationTime: @@ -33,7 +35,7 @@ def write_frame(waveform_message: dict, csn: str, mrn: str) -> bool: Path(out_path).mkdir(exist_ok=True) filename = out_path + create_file_name( - sourceSystem, observation_datetime, csn, units + sourceStreamId, observation_datetime, csn, units ) with open(filename, "a") as fileout: wv_writer = csv.writer(fileout, delimiter=",") @@ -45,9 +47,11 @@ def write_frame(waveform_message: dict, csn: str, mrn: str) -> bool: [ csn, mrn, + sourceStreamId, units, waveform_message.get("samplingRate", ""), observationTime, + waveform_message.get("mappedLocationString", ""), waveform_data, ] ) diff --git a/src/db.py b/src/db.py new file mode 100644 index 0000000..396f442 --- /dev/null +++ b/src/db.py @@ -0,0 +1,53 @@ +from datetime import datetime +import psycopg2 +from psycopg2 import sql, pool +import logging + +import settings as settings # 
type:ignore + +logging.basicConfig(format="%(levelname)s:%(asctime)s: %(message)s") +logger = logging.getLogger(__name__) + + +class starDB: + sql_query: str = "" + connection_string: str = "dbname={} user={} password={} host={} port={}".format( + settings.UDS_DBNAME, # type:ignore + settings.UDS_USERNAME, # type:ignore + settings.UDS_PASSWORD, # type:ignore + settings.UDS_HOST, # type:ignore + settings.UDS_PORT, # type:ignore + ) + connection_pool: pool.ThreadedConnectionPool + + def connect(self): + self.connection_pool = pool.ThreadedConnectionPool(1, 1, self.connection_string) + + def init_query(self): + with open("src/sql/mrn_based_on_bed_and_datetime.sql", "r") as file: + self.sql_query = sql.SQL(file.read()) + self.sql_query = self.sql_query.format( + schema_name=sql.Identifier(settings.SCHEMA_NAME) + ) + + def get_row(self, location_string: str, observation_datetime: datetime): + parameters = { + "location_string": location_string, + "observation_datetime": observation_datetime, + } + try: + with self.connection_pool.getconn() as db_connection: + with db_connection.cursor() as curs: + curs.execute(self.sql_query, parameters) + rows = curs.fetchall() + self.connection_pool.putconn(db_connection) + except psycopg2.errors.UndefinedTable as e: + self.connection_pool.putconn(db_connection) + raise ConnectionError(f"Missing tables in database: {e}") + + if len(rows) != 1: + raise ValueError( + f"Wrong number of rows returned from database. 
{len(rows)} != 1, for {location_string}:{observation_datetime}" + ) + + return rows[0] diff --git a/waveform_controller/settings.py b/src/settings.py similarity index 100% rename from waveform_controller/settings.py rename to src/settings.py diff --git a/waveform_controller/sql/mrn_based_on_bed_and_datetime.sql b/src/sql/mrn_based_on_bed_and_datetime.sql similarity index 80% rename from waveform_controller/sql/mrn_based_on_bed_and_datetime.sql rename to src/sql/mrn_based_on_bed_and_datetime.sql index 2f998f4..e473be1 100644 --- a/waveform_controller/sql/mrn_based_on_bed_and_datetime.sql +++ b/src/sql/mrn_based_on_bed_and_datetime.sql @@ -14,5 +14,5 @@ INNER JOIN {schema_name}.location_visit lv INNER JOIN {schema_name}.location loc ON lv.location_id = loc.location_id WHERE loc.location_string = %(location_string)s - AND hv.valid_from BETWEEN %(start_datetime)s AND %(end_datetime)s -ORDER by hv.valid_from DESC + AND lv.admission_datetime <= %(observation_datetime)s + AND ( lv.discharge_datetime >= %(observation_datetime)s OR lv.discharge_datetime IS NULL ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_file_writer.py b/tests/test_file_writer.py new file mode 100644 index 0000000..112fda7 --- /dev/null +++ b/tests/test_file_writer.py @@ -0,0 +1,25 @@ +import pytest +from src.csv_writer import create_file_name +from datetime import datetime, timezone + + +@pytest.mark.parametrize( + "units, expected_filename", + [ + ("uV", "2025-01-01.12345678.11.uV.csv"), + ("mL/s", "2025-01-01.12345678.11.mLps.csv"), + ("%", "2025-01-01.12345678.11.percent.csv"), + ], +) +def test_create_file_name_handles_units(units, expected_filename, tmp_path): + sourceSystem = "11" + observationTime = datetime(2025, 1, 1, tzinfo=timezone.utc) + csn = "12345678" + + filename = create_file_name(sourceSystem, observationTime, csn, units) + + assert filename == expected_filename + + # check we can write to it + with 
open(f"{tmp_path}/{filename}", "w") as fileout: + fileout.write("Test string") diff --git a/uv.lock b/uv.lock index a953717..37c1376 100644 --- a/uv.lock +++ b/uv.lock @@ -11,6 +11,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/db/3c/33bac158f8ab7f89b2e59426d5fe2e4f63f7ed25df84c036890172b412b5/cfgv-3.5.0-py2.py3-none-any.whl", hash = "sha256:a8dc6b26ad22ff227d2634a65cb388215ce6cc96bbcc5cfde7641ae87e8dacc0", size = 7445, upload-time = "2025-11-19T20:55:50.744Z" }, ] +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, +] + [[package]] name = "distlib" version = "0.4.0" @@ -38,6 +47,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0f/1c/e5fd8f973d4f375adb21565739498e2e9a1e54c858a97b9a8ccfdc81da9b/identify-2.6.15-py2.py3-none-any.whl", hash = "sha256:1181ef7608e00704db228516541eb83a88a9f94433a8c80bb9b5bd54b1d81757", size = 99183, upload-time = "2025-10-02T17:43:39.137Z" }, ] +[[package]] +name = "iniconfig" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, +] + [[package]] name = "nodeenv" version = "1.9.1" @@ -47,6 +65,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314, upload-time = "2024-06-04T18:44:08.352Z" }, ] +[[package]] +name = "packaging" +version = "25.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, +] + [[package]] name = "pika" version = "1.3.2" @@ -65,6 +92,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/73/cb/ac7874b3e5d58441674fb70742e6c374b28b0c7cb988d37d991cde47166c/platformdirs-4.5.0-py3-none-any.whl", hash = "sha256:e578a81bb873cbb89a41fcc904c7ef523cc18284b7e3b3ccf06aca1403b7ebd3", size = 18651, upload-time = "2025-10-08T17:44:47.223Z" }, ] +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = 
"sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + [[package]] name = "pre-commit" version = "4.5.0" @@ -133,6 +169,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e1/36/9c0c326fe3a4227953dfb29f5d0c8ae3b8eb8c1cd2967aa569f50cb3c61f/psycopg2_binary-2.9.11-cp314-cp314-win_amd64.whl", hash = "sha256:4012c9c954dfaccd28f94e84ab9f94e12df76b4afb22331b1f0d3154893a6316", size = 2803913, upload-time = "2025-10-10T11:13:57.058Z" }, ] +[[package]] +name = "pygments" +version = "2.19.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, +] + +[[package]] +name = "pytest" +version = "9.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = 
"sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, +] + [[package]] name = "pyyaml" version = "6.0.3" @@ -212,9 +273,16 @@ dependencies = [ { name = "psycopg2-binary" }, ] +[package.optional-dependencies] +dev = [ + { name = "pytest" }, +] + [package.metadata] requires-dist = [ { name = "pika", specifier = ">=1.3.2" }, { name = "pre-commit", specifier = ">=4.5.0" }, { name = "psycopg2-binary", specifier = ">=2.9.11" }, + { name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.2" }, ] +provides-extras = ["dev"] diff --git a/waveform_controller.py b/waveform_controller.py deleted file mode 100644 index 5b5eacb..0000000 --- a/waveform_controller.py +++ /dev/null @@ -1,4 +0,0 @@ -import waveform_controller.controller as controller - -if __name__ == "__main__": - controller.receiver() diff --git a/waveform_controller/controller.py b/waveform_controller/controller.py deleted file mode 100644 index e5c73dc..0000000 --- a/waveform_controller/controller.py +++ /dev/null @@ -1,31 +0,0 @@ -""" -A script to receive messages in the waveform queue and write them to stdout, -based on https://www.rabbitmq.com/tutorials/tutorial-one-python -""" - -import pika -import waveform_controller.db as db -import waveform_controller.settings as settings - - -def receiver(): - # set up database connection - emap_db = db.starDB() - emap_db.init_query() - - rabbitmq_credentials = pika.PlainCredentials( - username=settings.RABBITMQ_USERNAME, password=settings.RABBITMQ_PASSWORD - ) - connection_parameters = pika.ConnectionParameters( - credentials=rabbitmq_credentials, - host=settings.RABBITMQ_HOST, - 
port=settings.RABBITMQ_PORT, - ) - connection = pika.BlockingConnection(connection_parameters) - channel = connection.channel() - channel.basic_consume( - queue=settings.RABBITMQ_QUEUE, - auto_ack=False, - on_message_callback=emap_db.waveform_callback, - ) - channel.start_consuming() diff --git a/waveform_controller/db.py b/waveform_controller/db.py deleted file mode 100644 index c03f6f7..0000000 --- a/waveform_controller/db.py +++ /dev/null @@ -1,59 +0,0 @@ -import psycopg2 -from psycopg2 import sql -import json -from datetime import datetime, timedelta - -import waveform_controller.settings as settings -import waveform_controller.csv_writer as writer - - -class starDB: - sql_query: str = "" - connection_string: str = "dbname={} user={} password={} host={} port={}".format( - settings.UDS_DBNAME, # type:ignore - settings.UDS_USERNAME, # type:ignore - settings.UDS_PASSWORD, # type:ignore - settings.UDS_HOST, # type:ignore - settings.UDS_PORT, # type:ignore - ) - - def init_query(self): - with open( - "waveform_controller/sql/mrn_based_on_bed_and_datetime.sql", "r" - ) as file: - self.sql_query = sql.SQL(file.read()) - self.sql_query = self.sql_query.format( - schema_name=sql.Identifier(settings.SCHEMA_NAME) - ) - - def get_row(self, location_string: str, start_datetime: str, end_datetime: str): - parameters = { - "location_string": location_string, - "start_datetime": start_datetime, - "end_datetime": end_datetime, - } - try: - with psycopg2.connect(self.connection_string) as db_connection: - with db_connection.cursor() as curs: - curs.execute(self.sql_query, parameters) - single_row = curs.fetchone() - except psycopg2.errors.UndefinedTable: - raise ConnectionError("There is no table in your data base") - - return single_row - - def waveform_callback(self, ch, method, properties, body): - data = json.loads(body) - location_string = data.get("mappedLocationString", "unknown") - observation_time = data.get("observationTime", "NaT") - observation_time = 
datetime.fromtimestamp(observation_time) - # I found in testing that to find the first patient I had to go back 7 months. I'm not sure this - # is expected, but I suppose an ICU patient could occupy a bed for a long time. Let's use - # 52 weeks for now. - start_time = observation_time - timedelta(weeks=52) - obs_time_str = observation_time.strftime("%Y-%m-%d:%H:%M:%S") - start_time_str = start_time.strftime("%Y-%m-%d:%H:%M:%S") - matched_mrn = self.get_row(location_string, start_time_str, obs_time_str) - - if writer.write_frame(data, matched_mrn[2], matched_mrn[0]): - ch.basic_ack(method.delivery_tag)