From 591971d2f63e20b71f383c22bc6e14f918cdbc1c Mon Sep 17 00:00:00 2001 From: Jun Wu Wang Date: Fri, 25 Apr 2025 14:59:50 +0200 Subject: [PATCH 1/5] Refactor MQTT client --- src/server/communication/mqtt_client.py | 143 ++++++++++ src/server/communication/request_handler.py | 1 - src/server/edge/run_edge.py | 18 +- src/server/mqtt_client/__init__.py | 0 src/server/mqtt_client/mqtt_client.py | 267 ------------------ src/server/mqtt_client/mqtt_configs.py | 37 --- src/server/mqtt_client/mqtt_custom_message.py | 129 --------- src/server/settings.yaml | 13 +- 8 files changed, 164 insertions(+), 444 deletions(-) create mode 100644 src/server/communication/mqtt_client.py delete mode 100644 src/server/mqtt_client/__init__.py delete mode 100644 src/server/mqtt_client/mqtt_client.py delete mode 100644 src/server/mqtt_client/mqtt_configs.py delete mode 100644 src/server/mqtt_client/mqtt_custom_message.py diff --git a/src/server/communication/mqtt_client.py b/src/server/communication/mqtt_client.py new file mode 100644 index 0000000..3df1700 --- /dev/null +++ b/src/server/communication/mqtt_client.py @@ -0,0 +1,143 @@ +import json +import random +import time + +import ntplib +import paho.mqtt.client as mqtt +import threading +import queue + +from server.logger.log import logger +from server.communication.request_handler import RequestHandler + + +class MqttClient: + def __init__( + self, + broker_url: str, + broker_port: int, + client_id: str, + protocol: str, + subscribed_topics: list, + ntp_server: str, + last_offloading_layer: int, + request_handler: RequestHandler + ): + self.broker_url = broker_url + self.broker_port = broker_port + self.client_id = client_id + + # Create the client with the specific MQTT protocol version + self.client = mqtt.Client(client_id=client_id, protocol=protocol) + + # Attach callbacks + self.client.on_connect = self.on_connect + self.client.on_message = self.on_message + + # Set up topics + self.subscribed_topics = subscribed_topics + + # Set up NTP client + self.ntp_client = ntplib.NTPClient() + self.ntp_server = ntp_server + self.offset = self.sync_with_ntp() + self.start_timestamp = self.get_current_time() + + self.request_handler = request_handler + self.best_offloading_layer = last_offloading_layer + + # Set up helper thread + self.task_queue = queue.Queue() + self.thread = threading.Thread(target=self._worker, daemon=True) + self.thread.start() + + @staticmethod + def create_random_payload(): + """Creates a random payload for testing.""" + message = json.dumps({"id": random.randint(1, 1000)}) + return message + + def publish(self, topic: str, message: str, qos: int = 2): + """Publishes a message to a topic.""" + logger.debug(f"Publishing message to {topic}: {message}") + try: + self.client.publish(topic, message, qos=qos, retain=False) + except Exception as e: + logger.debug(f"Error publishing message: {e}") + + def subscribe(self, topic: str): + """Subscribes to a topic.""" + logger.debug(f"Subscribing to topic: {topic}") + self.client.subscribe(topic) + + def run(self): + """Connect to the broker and start the MQTT client loop.""" + self.client.connect(self.broker_url, self.broker_port, 60) + self.client.loop_forever() + + def stop(self): + """Stops the MQTT client loop and disconnects.""" + logger.debug("Disconnecting MQTT client") + self.client.disconnect() + + def on_connect(self, client, userdata, flags, rc): + if rc == 0: + logger.debug(f"Connected to {self.broker_url}:{self.broker_port} with client ID {self.client_id}") + for topic in self.subscribed_topics.values(): + self.subscribe(topic) + logger.debug(f"Initial NTP timestamp from NTP server {self.ntp_server}: {self.start_timestamp}") + else: + logger.debug(f"Connection failed with code {rc}") + + def sync_with_ntp(self) -> float: + ntp_timestamp = None + while ntp_timestamp is None: + try: + response = self.ntp_client.request(self.ntp_server) + # Get the offset between local clock time and ntp server time (seconds since 1900) + offset = response.offset + logger.debug(f"Synchronized with NTP server. Offset: {offset} seconds") + return offset + except ntplib.NTPException as _: + time.sleep(1) + threading.Timer(600, self.sync_with_ntp).start() + + def get_current_time(self) -> float: + return time.time() + self.offset + + def _worker(self): + while True: + task = self.task_queue.get() + task() + self.task_queue.task_done() + + def on_message(self, client, userdata, message): + received_timestamp = self.get_current_time() + + def task(): + self.handle_message_task(message, received_timestamp) + + self.task_queue.put(task) # Submit the task to the worker thread queue + + def handle_message_task(self, message, received_timestamp): + if message.topic == self.subscribed_topics['device_inference_result']: # Updates device time, runs offloading algorithm and sends best offloading layer + logger.debug('Device inference result received') + self.best_offloading_layer = self.request_handler.handle_device_inference_result(body=message.payload, received_timestamp=received_timestamp) + cleaned_offloading_layer_index = self.request_handler.handle_offloading_layer(best_offloading_layer=self.best_offloading_layer) + message_data = {'offloading_layer_index': cleaned_offloading_layer_index} + self.publish(self.subscribed_topics['offloading_layer'], json.dumps(message_data)) + logger.debug('Best offloading layer sent') + elif message.topic == self.subscribed_topics['device_input']: # Save input image + logger.debug('Device input received') + self.request_handler.handle_device_input(message.payload) + logger.debug('Device input saved') + elif message.topic == self.subscribed_topics['registration']: # Sends best offloading layer + logger.debug('Registration request received') + decoded_payload = message.payload.decode() + json_data = json.loads(decoded_payload) + cleaned_device_id = self.request_handler.handle_registration(json_data["device_id"]) + cleaned_offloading_layer_index = self.request_handler.handle_offloading_layer(best_offloading_layer=self.best_offloading_layer) + message_data = {'offloading_layer_index': cleaned_offloading_layer_index} + self.publish(self.subscribed_topics['offloading_layer'], json.dumps(message_data)) + logger.debug('Best offloading layer sent') + diff --git a/src/server/communication/request_handler.py b/src/server/communication/request_handler.py index 9901b91..0e34827 100644 --- a/src/server/communication/request_handler.py +++ b/src/server/communication/request_handler.py @@ -23,7 +23,6 @@ def handle_device_input(self, rgb565_image): image_array = InputData.make_array(rgb565_image=rgb565_image) image = Image.fromarray(image_array, 'RGB') image.save(InputDataFiles.input_data_file_path) - logger.debug("Input image saved") return def handle_device_inference_result(self, body, received_timestamp): diff --git a/src/server/edge/run_edge.py b/src/server/edge/run_edge.py index f24a8c7..f2b4de0 100644 --- a/src/server/edge/run_edge.py +++ b/src/server/edge/run_edge.py @@ -1,11 +1,11 @@ from server.edge.edge_initialization import Edge from server.logger.log import logger -from server.mqtt_client.mqtt_client import MqttClient -from server.mqtt_client.mqtt_configs import MqttClientConfig import yaml from server.communication.websocket_server import WebsocketServer from server.communication.http_server import HttpServer +from server.communication.mqtt_client import MqttClient +from paho.mqtt import client as mqtt from server.communication.request_handler import RequestHandler from server.commons import ConfigurationFiles @@ -44,11 +44,15 @@ http_server.run() if 'mqtt' in config['communication']['mode']: + mqtt_config = config['communication']['mqtt'] mqtt_client = MqttClient( - broker_url=MqttClientConfig.broker_url, - broker_port=MqttClientConfig.broker_port, - client_id=MqttClientConfig.client_id, - protocol=MqttClientConfig.protocol, - subscribed_topics=MqttClientConfig.subscribe_topics + broker_url=mqtt_config['broker_url'], + broker_port=mqtt_config['broker_port'], + client_id=mqtt_config['client_id'], + protocol=mqtt.MQTTv311, + subscribed_topics=mqtt_config['topics'], + ntp_server=http_config['ntp_server'], + last_offloading_layer=mqtt_config['last_offloading_layer'], + request_handler=RequestHandler() ) mqtt_client.run() diff --git a/src/server/mqtt_client/__init__.py b/src/server/mqtt_client/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/server/mqtt_client/mqtt_client.py b/src/server/mqtt_client/mqtt_client.py deleted file mode 100644 index 636efe1..0000000 --- a/src/server/mqtt_client/mqtt_client.py +++ /dev/null @@ -1,267 +0,0 @@ -import json -import random -import time - -import ntplib -import numpy as np -import paho.mqtt.client as mqtt -from server.edge.edge_initialization import Edge -from server.device.device_initialization import Device -from server.offloading_algo.offloading_algo import OffloadingAlgo -from PIL import Image -import threading -import queue - -from server.commons import OffloadingDataFiles -from server.commons import EvaluationFiles -from server.commons import InputData -from server.commons import InputDataFiles -from server.logger.log import logger -from server.mqtt_client.mqtt_configs import MqttClientConfig, Topics, DefaultMessages -from server.mqtt_client.mqtt_custom_message import MqttMessageData - - -class MqttClient: - def __init__( - self, - broker_url: str = MqttClientConfig.broker_url, - broker_port: int = MqttClientConfig.broker_port, - client_id: str = MqttClientConfig.client_id, - protocol: str = MqttClientConfig.protocol, - subscribed_topics: list = None, - ntp_server: str = MqttClientConfig.ntp_server - ): - self.broker_url = broker_url - self.broker_port = broker_port - self.client_id = client_id - - # Create the client with the specific MQTT protocol version - self.client = mqtt.Client(client_id=client_id, protocol=protocol) - - # Attach callbacks - self.client.on_connect = self.on_connect - self.client.on_message = self.on_message - - # Set up topics - self.subscribed_topics = subscribed_topics - - # Set up NTP client - self.ntp_client = ntplib.NTPClient() - self.ntp_server = ntp_server - self.offset = self.sync_with_ntp() - self.start_timestamp = self.get_current_time() - - # Stats - self.layers_sizes = [] - self.edge_inference_times = [] - self.device_inference_times = [] - self.load_stats() - - # Set up helper thread - self.task_queue = queue.Queue() - self.thread = threading.Thread(target=self._worker, daemon=True) - self.thread.start() - - @staticmethod - def create_random_payload(): - """Creates a random payload for testing.""" - message = json.dumps({"id": random.randint(1, 1000)}) - return message - - def publish(self, topic: str, message: str, qos: int = 2): - """Publishes a message to a topic.""" - logger.debug(f"Publishing message to {topic}: {message}") - try: - self.client.publish(topic, message, qos=qos, retain=False) - except Exception as e: - logger.debug(f"Error publishing message: {e}") - - def subscribe(self, topic: str): - """Subscribes to a topic.""" - logger.debug(f"Subscribing to topic: {topic}") - self.client.subscribe(topic) - - def run(self): - """Connect to the broker and start the MQTT client loop.""" - self.client.connect(self.broker_url, self.broker_port, 60) - self.client.loop_forever() - - def stop(self): - """Stops the MQTT client loop and disconnects.""" - logger.debug("Disconnecting MQTT client") - self.client.disconnect() - - def on_connect(self, client, userdata, flags, rc): - if rc == 0: - logger.debug(f"Connected to {self.broker_url}:{self.broker_port} with client ID {self.client_id}") - for topic in self.subscribed_topics: - self.subscribe(topic) - logger.debug(f"Initial NTP timestamp from NTP server {self.ntp_server}: {self.start_timestamp}") - else: - logger.debug(f"Connection failed with code {rc}") - - def sync_with_ntp(self) -> float: - ntp_timestamp = None - while ntp_timestamp is None: - try: - response = self.ntp_client.request(self.ntp_server) - # Get the offset between local clock time and ntp server time (seconds since 1900) - offset = response.offset - logger.debug(f"Synchronized with NTP server. Offset: {offset} seconds") - return offset - except ntplib.NTPException as _: - time.sleep(1) - threading.Timer(600, self.sync_with_ntp).start() - - def get_current_time(self) -> float: - return time.time() + self.offset - - def _worker(self): - while True: - task = self.task_queue.get() - task() - self.task_queue.task_done() - - def on_message(self, client, userdata, message): - received_timestamp = self.get_current_time() - - def task(): - self.handle_message_task(message, received_timestamp) - - self.task_queue.put(task) # Submit the task to the worker thread queue - - def handle_message_task(self, message, received_timestamp): - # Save input image - if message.topic == Topics.device_input.value: - image_array = InputData.make_array(message.payload) - image = Image.fromarray(image_array, 'RGB') - image.save(InputDataFiles.input_data_file_path) - message_data = MqttMessageData( - topic=message.topic, - payload="InputImage", - device_id="device_01", - message_id=None, - message_content="InputImage", - timestamp=None, - ) - MqttMessageData.save_to_file(EvaluationFiles.evaluation_file_path, message_data.to_dict()) - logger.debug("Input image saved") - return - - # obtain message data if the message is JSON valid - try: - message_data = MqttMessageData.from_raw(message.topic, message.payload) - except json.JSONDecodeError: - logger.error(f"Failed to decode JSON from payload on topic {message.topic}") - return - - # check if the message is valid - sent after the edge mqtt client is started - if float(message_data.timestamp) <= float(self.start_timestamp): - return - logger.debug(f"Received a valid message") - - # Extend message data - message_data = self.extend_message_data(message_data, received_timestamp, message.payload) - # Save message data to file - MqttMessageData.save_to_file(EvaluationFiles.evaluation_file_path, message_data.to_dict()) - - # run offloading algorithm and ask for prediction after the device sends the registration message - if message_data.topic == Topics.registration.value: - Device.initialization() - # run offloading algorithm - offloading_algo = OffloadingAlgo( - avg_speed=message_data.avg_speed, - num_layers=len(self.layers_sizes), - layers_sizes=list(self.layers_sizes), - inference_time_device=list(self.device_inference_times), - inference_time_edge=list(self.edge_inference_times) - ) - best_offloading_layer = offloading_algo.static_offloading() - # send best offloading layer index - self.send_offloading_layer_index( - ask_device_id=message_data.device_id, - message_id=message_data.message_id, - best_offloading_layer=best_offloading_layer, - ) - - # ends the computation after receiving the inference result - if message_data.topic == Topics.device_inference_result.value: - # update device inference time - with open(OffloadingDataFiles.data_file_path_device, 'r') as f: - device_inference_times = json.load(f) - for l_id, inference_time in enumerate(message_data.device_layers_inference_time): - device_inference_times[f"layer_{l_id}"] = inference_time - with open(OffloadingDataFiles.data_file_path_device, 'w') as f: - json.dump(device_inference_times, f, indent=4) - # finish inference - prediction = Edge.run_inference(message_data.offloading_layer_index, np.array(message_data.layer_output, dtype=np.float32)) - logger.debug(f"Prediction: {prediction.tolist()}") - MqttMessageData.save_to_file(EvaluationFiles.web_file_path, message_data.to_dict()) - # run offloading algorithm - offloading_algo = OffloadingAlgo( - avg_speed=message_data.avg_speed, - num_layers=len(self.layers_sizes), - layers_sizes=list(self.layers_sizes), - inference_time_device=list(self.device_inference_times), - inference_time_edge=list(self.edge_inference_times) - ) - best_offloading_layer = offloading_algo.static_offloading() - # send best offloading layer index - self.send_offloading_layer_index( - ask_device_id=message_data.device_id, - message_id=message_data.message_id, - best_offloading_layer=best_offloading_layer, - ) - - def send_offloading_layer_index(self, ask_device_id, message_id, best_offloading_layer: int): - logger.debug(f"Sending offloading layer index to {ask_device_id}") - message_data = DefaultMessages.offloading_layer_msg - message_data["timestamp"] = self.get_current_time() - message_data['message_id'] = message_id - message_data['offloading_layer_index'] = best_offloading_layer - self.publish(Topics.offloading_layer.value, json.dumps(message_data)) - - def load_stats(self): - """ Loads the offloading stats from the JSON files """ - with open(OffloadingDataFiles.data_file_path_device, 'r') as file: - self.device_inference_times = json.load(file) - self.device_inference_times = list({k: v for k, v in self.device_inference_times.items()}.values()) - - with open(OffloadingDataFiles.data_file_path_edge, 'r') as file: - self.edge_inference_times = json.load(file) - self.edge_inference_times = list({k: v for k, v in self.edge_inference_times.items()}.values()) - - with open(OffloadingDataFiles.data_file_path_sizes, 'r') as file: - self.layers_sizes = json.load(file) - self.layers_sizes = list({k: v for k, v in self.layers_sizes.items()}.values()) - - logger.debug(f"Loaded stats data") - - @staticmethod - def extend_message_data(message_data: MqttMessageData, received_timestamp: float, payload) -> MqttMessageData: - """Extend the message data with additional information. - - Args: - message_data (MqttMessageData): The message data to extend. - received_timestamp (float): The timestamp of the message reception. - - Returns: - MqttMessageData: The extended message data. - """ - # update stats info - message_data.received_timestamp = received_timestamp - message_data.payload_size = MqttMessageData.get_bytes_size(payload) - message_data.synthetic_latency = MqttMessageData.get_synthetic_latency() - message_data.latency = MqttMessageData.get_latency(message_data.timestamp, message_data.received_timestamp) - message_data.avg_speed = MqttMessageData.get_avg_speed( - message_data.payload_size, - message_data.latency, - message_data.synthetic_latency - ) - # update offloading info - ( - message_data.offloading_layer_index, - message_data.layer_output, - message_data.device_layers_inference_time - ) = MqttMessageData.get_offloading_info(message_data.message_content) - return message_data diff --git a/src/server/mqtt_client/mqtt_configs.py b/src/server/mqtt_client/mqtt_configs.py deleted file mode 100644 index 86df804..0000000 --- a/src/server/mqtt_client/mqtt_configs.py +++ /dev/null @@ -1,37 +0,0 @@ -import enum -from dataclasses import dataclass - -from paho.mqtt import client as mqtt - - -class Topics(enum.Enum): - registration = "devices/" - offloading_layer = "device_01/offloading_layer" - device_input = "device_01/input_data" - device_inference_result = "device_01/model_inference_result" - - -@dataclass -class MqttClientConfig: - broker_url: str = "hostname.local" - broker_port: int = 1883 - client_id: str = "edge" - subscribe_topics: list = ( - Topics.registration.value, - Topics.offloading_layer.value, - Topics.device_input.value, - Topics.device_inference_result.value - ) - ntp_server: str = "0.it.pool.ntp.org" - protocol: mqtt.MQTTv311 = mqtt.MQTTv311 - - -@dataclass -class DefaultMessages: - offloading_layer_msg = { - "device_id": "edge", - "message_id": "edge", - "timestamp": None, - "message_content": "OffloadingLayer", - "offloading_layer_index": None, - } \ No newline at end of file diff --git a/src/server/mqtt_client/mqtt_custom_message.py b/src/server/mqtt_client/mqtt_custom_message.py deleted file mode 100644 index 3d48203..0000000 --- a/src/server/mqtt_client/mqtt_custom_message.py +++ /dev/null @@ -1,129 +0,0 @@ -import json -import os -from dataclasses import dataclass - -import pandas as pd -import struct - -from server.logger.log import logger -from server.mqtt_client.mqtt_configs import Topics - - -@dataclass -class MqttMessageData: - topic: str - payload: str - device_id: int - message_id: int - message_content: str - timestamp: str - - received_timestamp = None - avg_speed = None - latency = None - synthetic_latency = None - payload_size = None - offloading_layer_index = None - layer_output = None - device_layers_inference_time = None - - @staticmethod - def from_raw(topic: str, payload: bytes): - """Parse the raw message payload into a MessageDataInput instance.""" - try: - if topic == Topics.device_inference_result.value: - message_data = {} - message_content = {} - - # decode the payload from bytes to values and parse as JSON - message_data["timestamp"] = struct.unpack('d', payload[:8])[0] - offset = 8 - message_data["device_id"] = payload[offset:offset+9].decode() - offset += 9 - message_data["message_id"] = payload[offset:offset+4].decode() - offset += 4 - message_content["offloading_layer_index"] = struct.unpack('i', payload[offset:offset+4])[0] - offset += 4 - layer_output_size = struct.unpack('I', payload[offset:offset+4])[0] - offset += 4 - message_content["layer_output"] = struct.unpack(f'<{int(layer_output_size/4)}f', payload[offset:offset+layer_output_size]) - offset += layer_output_size - layers_inference_time_size = struct.unpack('i', payload[offset:offset+4])[0] - offset += 4 - message_content["layers_inference_time"] = struct.unpack(f'<{int(layers_inference_time_size/4)}f', payload[offset:offset+layers_inference_time_size]) - message_data["message_content"] = message_content - - decoded_payload = json.dumps(message_data) - else: - # decode the payload from bytes to string and parse as JSON - decoded_payload = payload.decode() - message_data = json.loads(decoded_payload) - - message_content = message_data["message_content"] - # return an instance of MessageDataInput with extracted fields - return MqttMessageData( - topic=topic, - payload=decoded_payload, - device_id=message_data["device_id"], - message_id=message_data["message_id"], - message_content=message_content, - timestamp=message_data["timestamp"], - ) - except json.JSONDecodeError: - # handles payload that cannot be parsed as JSON - raise - - def to_dict(self): - return self.__dict__ - - @staticmethod - def save_to_file(file_path: str, data_dict: dict): - # check if the file already exists - file_exists = os.path.isfile(file_path) - try: - # create a DataFrame from the data dictionary - df = pd.DataFrame.from_dict([data_dict]) - # append to the CSV file; write header only if file does not exist - df.to_csv(file_path, mode='a', header=not file_exists, index=False) - logger.debug(f"Data saved to {file_path}") - except Exception as e: - logger.error(f"Failed to save data to {file_path}: {e}") - - @staticmethod - def get_latency(timestamp: str, received_timestamp: str) -> tuple[float, dict]: - # NTP timestamps as strings (representing seconds since 1900) - # convert the NTP timestamps from string to float - ntp_timestamp_1 = float(timestamp) - ntp_timestamp_2 = float(received_timestamp) - # calculate the duration between the two NTP timestamps - duration_seconds = ntp_timestamp_2 - ntp_timestamp_1 - # convert the duration to a readable format - return duration_seconds - - @staticmethod - def get_bytes_size(payload) -> int: - return len(payload) - - @staticmethod - def get_synthetic_latency() -> float: - return 1 - - @staticmethod - def get_avg_speed(payload_size: float, latency: float, synthetic_latency: float) -> float: - message_latency = latency * synthetic_latency - try: - avg_speed = payload_size / message_latency - except ZeroDivisionError: - avg_speed = 0 - return avg_speed - - @staticmethod - def get_offloading_info(message_content: dict) -> tuple: - # check if layer_output and offloading_layer_index exist in message_content - try: - layer_output = message_content.get("layer_output", None) - offloading_layer_index = message_content.get("offloading_layer_index", None) - device_layers_inference_time = message_content.get("layers_inference_time", None) - return offloading_layer_index, layer_output, device_layers_inference_time - except Exception as _: - return None, None, None diff --git a/src/server/settings.yaml b/src/server/settings.yaml index 8ccff81..d3bc7ff 100644 --- a/src/server/settings.yaml +++ b/src/server/settings.yaml @@ -20,6 +20,13 @@ communication: ntp_server: "0.it.pool.ntp.org" last_offloading_layer: 58 mqtt: - broker: "hostname.local" - port: 1883 - + broker_url: "hostname.local" + broker_port: 1883 + client_id: "edge" + topics: + registration: "devices/" + offloading_layer: "device_01/offloading_layer" + device_input: "device_01/input_data" + device_inference_result: "device_01/model_inference_result" + ntp_server: "0.it.pool.ntp.org" + last_offloading_layer: 58 From 0d43cf2e5a8feb0d264844c1d77f4174d1b58c2b Mon Sep 17 00:00:00 2001 From: Jun Wu Wang Date: Mon, 28 Apr 2025 21:56:39 +0200 Subject: [PATCH 2/5] Refactor InputData --- src/server/commons.py | 38 +----------------- src/server/communication/http_server.py | 16 ++++++-- src/server/communication/mqtt_client.py | 11 ++++- src/server/communication/request_handler.py | 13 +++--- src/server/communication/websocket_server.py | 11 ++++- src/server/edge/edge_initialization.py | 7 ++-- src/server/edge/run_edge.py | 21 +++++++--- src/server/models/model_input_converter.py | 31 ++++++++++++++ .../{input_data.png => test_image.png} | Bin src/server/settings.yaml | 12 ++++-- 10 files changed, 98 insertions(+), 62 deletions(-) create mode 100644 src/server/models/model_input_converter.py rename src/server/models/test/test_model/pred_data/{input_data.png => test_image.png} (100%) diff --git a/src/server/commons.py b/src/server/commons.py index 2038006..a5ee8cf 100644 --- a/src/server/commons.py +++ b/src/server/commons.py @@ -1,8 +1,4 @@ from pathlib import Path -import struct - -import numpy as np -from tensorflow.keras.preprocessing.image import load_img, img_to_array BASE_DIR = Path(__file__).resolve().parent @@ -16,7 +12,6 @@ class OffloadingDataFiles: class EvaluationFiles: evaluation_file_path: str = str(BASE_DIR / "evaluations/evaluations.csv") - web_file_path: str = str(BASE_DIR / "evaluations/web.csv") class ConfigurationFiles: @@ -28,36 +23,5 @@ class ModelFiles: class InputDataFiles: - test_data_file_path: str = str(BASE_DIR / "models/test/test_model/pred_data/input_data.png") # Path to test image + test_data_file_path: str = str(BASE_DIR / "models/test/test_model/pred_data/test_image.png") # Path to test image input_data_file_path: str = str(BASE_DIR / "input_data.png") # Input image save path - - -class InputData: - height = 96 - width = 96 - - def __init__(self, image_path=InputDataFiles.test_data_file_path, color_mode="rgb", - target_size=(height, width)): # Model input configuration - input_image = load_img(image_path, color_mode=color_mode, target_size=target_size) - image_array = img_to_array(input_image) - self.image_array = np.array([image_array]) - - @staticmethod - def make_array(rgb565_image, h=height, w=width): - image_array = [] - - for i in range(h): - row = [] - s = rgb565_image[i * w * 2:(i + 1) * w * 2] - pixels = struct.unpack(f'>{w}H', s) - for p in pixels: - r = p >> 11 - g = (p >> 5) & 0x3f - b = p & 0x1f - r = (r * 255) / 31.0 - g = (g * 255) / 63.0 - b = (b * 255) / 31.0 - row.append([int(round(x)) for x in [r, g, b]]) - image_array.append(row) - - return np.array(image_array, dtype=np.uint8) diff --git a/src/server/communication/http_server.py b/src/server/communication/http_server.py index f952e1e..80adb51 100644 --- a/src/server/communication/http_server.py +++ b/src/server/communication/http_server.py @@ -11,6 +11,8 @@ def __init__( port: int, endpoints: dict, ntp_server: str, + input_height: int, + input_width: int, last_offloading_layer: int, request_handler: RequestHandler ): @@ -18,11 +20,17 @@ def __init__( self.host = host self.port = port self.endpoints = endpoints - - self.request_handler = request_handler - self.best_offloading_layer = last_offloading_layer + self.devices = set() + # Set up model + self.input_height = input_height + self.input_width = input_width + self.best_offloading_layer = last_offloading_layer + + # Set up request handler + self.request_handler = request_handler + # Set up NTP client self.ntp_client = ntplib.NTPClient() self.ntp_server = ntp_server @@ -59,7 +67,7 @@ async def registration(data: dict): async def device_input(request: Request): try: body = await request.body() # Reads raw bytes - self.request_handler.handle_device_input(body) + self.request_handler.handle_device_input(body, self.input_height, self.input_width) return {'message': 'Success'} except Exception as e: raise HTTPException(status_code=404, detail=str(e)) diff --git a/src/server/communication/mqtt_client.py b/src/server/communication/mqtt_client.py index 3df1700..08d4724 100644 --- a/src/server/communication/mqtt_client.py +++ b/src/server/communication/mqtt_client.py @@ -20,6 +20,8 @@ def __init__( protocol: str, subscribed_topics: list, ntp_server: str, + input_height: int, + input_width: int, last_offloading_layer: int, request_handler: RequestHandler ): @@ -43,9 +45,14 @@ def __init__( self.offset = self.sync_with_ntp() self.start_timestamp = self.get_current_time() - self.request_handler = request_handler + # Set up model + self.input_height = input_height + self.input_width = input_width self.best_offloading_layer = last_offloading_layer + # Set up request handler + self.request_handler = request_handler + # Set up helper thread self.task_queue = queue.Queue() self.thread = threading.Thread(target=self._worker, daemon=True) @@ -129,7 +136,7 @@ def handle_message_task(self, message, received_timestamp): logger.debug('Best offloading layer sent') elif message.topic == self.subscribed_topics['device_input']: # Save input image logger.debug('Device input received') - self.request_handler.handle_device_input(message.payload) + self.request_handler.handle_device_input(message.payload, self.input_height, self.input_width) logger.debug('Device input saved') elif message.topic == self.subscribed_topics['registration']: # Sends best offloading layer logger.debug('Registration request received') diff --git a/src/server/communication/request_handler.py b/src/server/communication/request_handler.py index 0e34827..6f818ad 100644 --- a/src/server/communication/request_handler.py +++ b/src/server/communication/request_handler.py @@ -1,17 +1,20 @@ import json - import numpy as np +from PIL import Image + from server.edge.edge_initialization import Edge from server.offloading_algo.offloading_algo import OffloadingAlgo -from PIL import Image from server.commons import OffloadingDataFiles from server.commons import EvaluationFiles -from server.commons import InputData from server.commons import InputDataFiles + from server.logger.log import logger + from server.communication.message_data import MessageData +from server.models.model_input_converter import ModelInputConverter + import struct @@ -19,8 +22,8 @@ class RequestHandler(): def handle_registration(self, device_id): return device_id - def handle_device_input(self, rgb565_image): - image_array = InputData.make_array(rgb565_image=rgb565_image) + def handle_device_input(self, rgb565_image, height, width): + image_array = ModelInputConverter.convert_rgb565_to_nparray(rgb565_image, height, width) image = Image.fromarray(image_array, 'RGB') image.save(InputDataFiles.input_data_file_path) return diff --git a/src/server/communication/websocket_server.py b/src/server/communication/websocket_server.py index bf2dbc9..0549033 100644 --- a/src/server/communication/websocket_server.py +++ b/src/server/communication/websocket_server.py @@ -15,6 +15,8 @@ def __init__( port: int, endpoint: str, ntp_server: str, + input_height: int, + input_width: int, last_offloading_layer: int, request_handler: RequestHandler ): @@ -23,9 +25,14 @@ def __init__( self.port = port self.endpoint = endpoint - self.request_handler = request_handler + # Set up model + self.input_height = input_height + self.input_width = input_width self.best_offloading_layer = last_offloading_layer + # Set up request handler + self.request_handler = request_handler + # Set up NTP client self.ntp_client = ntplib.NTPClient() self.ntp_server = ntp_server @@ -74,7 +81,7 @@ async def websocket_endpoint(websocket: WebSocket): binary_data = message['bytes'] if len(binary_data) == 18432: logger.debug('Device input received') - self.request_handler.handle_device_input(binary_data) + self.request_handler.handle_device_input(binary_data, self.input_height, self.input_width) logger.debug('Device input saved') else: logger.debug('Device inference result received') diff --git a/src/server/edge/edge_initialization.py b/src/server/edge/edge_initialization.py index 7cdd7b8..899dcf8 100644 --- a/src/server/edge/edge_initialization.py +++ b/src/server/edge/edge_initialization.py @@ -3,9 +3,10 @@ import numpy as np import tensorflow as tf -from server.commons import InputData from server.commons import OffloadingDataFiles +from server.commons import InputDataFiles from server.models.model_manager import ModelManager +from server.models.model_input_converter import ModelInputConverter class Edge: @@ -68,9 +69,9 @@ def run_inference(offloading_layer_index: int, offloading_layer_output: np.array return predictions[layers_to_use[num_of_layers - start_layer_index - 1]] @staticmethod - def initialization(): + def initialization(input_height, input_width): # original array - image_array = InputData().image_array + image_array = ModelInputConverter.convert_png_to_nparray(InputDataFiles.test_data_file_path, input_height, input_width) image_array = image_array / 255.0 # Normalize pixel values # check the shape and dtype diff --git a/src/server/edge/run_edge.py b/src/server/edge/run_edge.py index f2b4de0..782a302 100644 --- a/src/server/edge/run_edge.py +++ b/src/server/edge/run_edge.py @@ -13,38 +13,45 @@ if __name__ == "__main__": logger.info("Starting the [EDGE] MQTT client") - # initialize edge inference times - Edge.initialization() - with open(ConfigurationFiles.server_configuration_file_path, "r") as f: config = yaml.safe_load(f) if 'websocket' in config['communication']['mode']: websocket_config = config['communication']['websocket'] + model_config = config['model'][websocket_config['model']] + Edge.initialization(input_height=model_config['input_height'], input_width=model_config['input_width']) # initialize edge inference times websocket_server = WebsocketServer( host=websocket_config['host'], port=websocket_config['port'], endpoint=websocket_config['endpoint'], ntp_server=websocket_config['ntp_server'], - last_offloading_layer=websocket_config['last_offloading_layer'], + input_height=model_config['input_height'], + input_width=model_config['input_width'], + last_offloading_layer=model_config['last_offloading_layer'], request_handler=RequestHandler() ) websocket_server.run() if 'http' in config['communication']['mode']: http_config = config['communication']['http'] + model_config = config['model'][http_config['model']] + Edge.initialization(input_height=model_config['input_height'], input_width=model_config['input_width']) # initialize edge inference times http_server = HttpServer( host=http_config['host'], port=http_config['port'], endpoints=http_config['endpoints'], ntp_server=http_config['ntp_server'], - last_offloading_layer=http_config['last_offloading_layer'], + input_height=model_config['input_height'], + input_width=model_config['input_width'], + last_offloading_layer=model_config['last_offloading_layer'], request_handler=RequestHandler() ) http_server.run() if 'mqtt' in config['communication']['mode']: mqtt_config = config['communication']['mqtt'] + model_config = config['model'][mqtt_config['model']] + Edge.initialization(input_height=model_config['input_height'], input_width=model_config['input_width']) # initialize edge inference times mqtt_client = MqttClient( broker_url=mqtt_config['broker_url'], broker_port=mqtt_config['broker_port'], @@ -52,7 +59,9 @@ protocol=mqtt.MQTTv311, subscribed_topics=mqtt_config['topics'], ntp_server=http_config['ntp_server'], - last_offloading_layer=mqtt_config['last_offloading_layer'], + input_height=model_config['input_height'], + input_width=model_config['input_width'], + last_offloading_layer=model_config['last_offloading_layer'], request_handler=RequestHandler() ) mqtt_client.run() diff --git a/src/server/models/model_input_converter.py b/src/server/models/model_input_converter.py new file mode 100644 index 0000000..fc080b3 --- /dev/null +++ b/src/server/models/model_input_converter.py @@ -0,0 +1,31 @@ +import struct +import numpy as np +from tensorflow.keras.preprocessing.image import load_img, img_to_array + + +class ModelInputConverter: + @staticmethod + def convert_png_to_nparray(png_image_path, height, width, color_mode="rgb"): + png_image = load_img(png_image_path, color_mode=color_mode, target_size=(height, width)) + image_array = img_to_array(png_image) + return np.array([image_array]) + + @staticmethod + def convert_rgb565_to_nparray(rgb565_image, height, width): + image_array = [] + + for i in range(height): + row = [] + s = rgb565_image[i * width * 2:(i + 1) * width * 2] + pixels = struct.unpack(f'>{width}H', s) + for p in pixels: + r = p >> 11 + g = (p >> 5) & 0x3f + b = p & 0x1f + r = (r * 255) / 31.0 + g = (g * 255) / 63.0 + b = (b * 255) / 31.0 + row.append([int(round(x)) for x in [r, g, b]]) + image_array.append(row) + + return np.array(image_array, dtype=np.uint8) \ No newline at end of file diff --git a/src/server/models/test/test_model/pred_data/input_data.png b/src/server/models/test/test_model/pred_data/test_image.png similarity index 100% rename from src/server/models/test/test_model/pred_data/input_data.png rename to src/server/models/test/test_model/pred_data/test_image.png diff --git a/src/server/settings.yaml b/src/server/settings.yaml index d3bc7ff..ed37ef2 100644 --- a/src/server/settings.yaml +++ b/src/server/settings.yaml @@ -8,7 +8,7 @@ communication: port: 8080 endpoint: "/ws" ntp_server: "0.it.pool.ntp.org" - last_offloading_layer: 58 + model: "fomo_96x96" http: host: "0.0.0.0" port: 8000 @@ -18,7 +18,7 @@ communication: device_inference_result: "/api/device_inference_result" # post offloading_layer: "/api/offloading_layer" # get ntp_server: "0.it.pool.ntp.org" - last_offloading_layer: 58 + model: "fomo_96x96" mqtt: broker_url: "hostname.local" broker_port: 1883 @@ -29,4 +29,10 @@ communication: device_input: "device_01/input_data" device_inference_result: "device_01/model_inference_result" ntp_server: "0.it.pool.ntp.org" - last_offloading_layer: 58 + model: "fomo_96x96" + +model: + fomo_96x96: + input_height: 96 + input_width: 96 + last_offloading_layer: 58 \ No newline at end of file From a78649a9e0faddf462b04feae808af0f0800ba4e Mon Sep 17 00:00:00 2001 From: Jun Wu Wang Date: Mon, 28 Apr 2025 22:02:30 +0200 Subject: [PATCH 3/5] Change typing --- src/server/web/webpage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/server/web/webpage.py b/src/server/web/webpage.py index efb1364..9388748 100644 --- a/src/server/web/webpage.py +++ b/src/server/web/webpage.py @@ -31,7 +31,7 @@ quadcol = st.columns(4) -quadcol[0].metric(label="Best offloading layer", value = df_last_row['offloading_layer_index']) +quadcol[0].metric(label="Best offloading layer", value = f"{int(df_last_row['offloading_layer_index'].iloc[0]):,}") quadcol[1].metric(label="Layer size", value = f"{int(df_last_row['payload_size'].iloc[0]):,} Bytes") quadcol[2].metric(label="Latency", value = f"{float(df_last_row['latency'].iloc[0]):,.4f} s") quadcol[3].metric(label="Network speed", value = f"{float(df_last_row['avg_speed'].iloc[0]):,.2f} Bytes/s") From 95822d98eb40162c4475bf43b6379733a8f9ac7e Mon Sep 17 00:00:00 2001 From: Jun Wu Wang Date: Mon, 28 Apr 2025 22:03:02 +0200 Subject: [PATCH 4/5] Update README.md --- README.md | 80 ++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 76 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 769b0fd..c39a49e 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,84 @@ # SCIoT project The Split Computing on IoT (SCIoT) project provides tools to use Edge Impulse models in ESP32 devices, using split computing techniques. ---- -![Unit Tests](https://github.com/UBICO/SCIoT/actions/workflows/codecov.yml/badge.svg)
-[![Coverage](https://codecov.io/github/UBICO/SCIoT//coverage.svg?branch=main)](https://codecov.io/gh/UBICO/SCIoT) -[![Powered by UBICO](https://img.shields.io/badge/powered%20by-UBICO-orange.svg?style=flat&colorA=E1523D&colorB=007D8A)]() +![Unit Tests](https://github.com/UBICO/SCIoT/actions/workflows/codecov.yml/badge.svg) [![Coverage](https://codecov.io/github/UBICO/SCIoT//coverage.svg?branch=main)](https://codecov.io/gh/UBICO/SCIoT) [![Powered by UBICO](https://img.shields.io/badge/powered%20by-UBICO-orange.svg?style=flat&colorA=E1523D&colorB=007D8A)]() ## Publications If you use this work, please consider citing our work: - F. Bove, S. Colli and L. Bedogni, "Performance Evaluation of Split Computing with TinyML on IoT Devices," 2024 IEEE 21st Consumer Communications & Networking Conference (CCNC), Las Vegas, NV, USA, 2024, pp. 1-6, [DOI Link](http://dx.doi.org/10.1109/CCNC51664.2024.10454775). - F. Bove and L. Bedogni, "Smart Split: Leveraging TinyML and Split Computing for Efficient Edge AI," 2024 IEEE/ACM Symposium on Edge Computing (SEC), Rome, Italy, 2024, pp. 456-460, [DOI Link](http://dx.doi.org/10.1109/SEC62691.2024.00052). + +## Configuration +Clone and go in the repository: + +```sh +git clone https://github.com/UBICO/SCIoT.git +``` + +Install python 3.11: + +```sh +pyenv install 3.11 +``` + +- Newer versions of python don't support the tensorflow version used in the project + +Switch to python 3.11: + +```sh +pyenv global 3.11 +``` + +- You can switch back after the configuration process by running `pyenv system global` + +Create a virtual environment: + +```sh +python3 -m venv venv +``` + +Activate the virtual environment: + +```sh +source venv/bin/activate +``` + +Install the project's dependencies: + +```sh +pip3 install . +``` + +Configure the absolute path to the project's `src` directory (e.g. `/home/username/Documents/SCIoT/src/`): + +```sh +cd $(python3 -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())") +echo /absolute/path/to/project/src/ > project.pth +``` +### Model setup +- Save your keras model as `test_model.h5` in `src/server/models/test/test_model/` +- Save your test image as `test_image.png` in `src/server/models/test/test_model/pred_data/` +- Split the model by running `python3 model_split.py` in `src/server/models/` +- Configure the paths as needed using `src/server/commons.py` + +### Server setup +- Configure the server using `src/server/settings.yaml` + +## Usage +In root directory, run the MQTT broker: + +```sh +docker compose up +``` + +In `src/server/edge`, run the edge server: + +```sh +python3 run_edge.py +``` + +In `src/server/web`, run the webpage: + +```sh +streamlit run webpage.py +``` From f384e931ec5928a40f2b01a8a3404bafaaaba07d Mon Sep 17 00:00:00 2001 From: Jun Wu Wang Date: Mon, 28 Apr 2025 23:08:23 +0200 Subject: [PATCH 5/5] Fix tests --- tests/fixtures/mqtt_client_fixture.py | 56 +++++++++++++++++++++++++-- tests/test_edge/test_edge.py | 4 +- 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/tests/fixtures/mqtt_client_fixture.py b/tests/fixtures/mqtt_client_fixture.py index 5a7bc37..b3e474c 100644 --- a/tests/fixtures/mqtt_client_fixture.py +++ b/tests/fixtures/mqtt_client_fixture.py @@ -3,8 +3,10 @@ import pytest from server.commons import OffloadingDataFiles -from server.mqtt_client.mqtt_client import MqttClient +from server.communication.mqtt_client import MqttClient from tests.commons import TestSamples +from paho.mqtt import client as mqtt +from server.communication.request_handler import RequestHandler @pytest.fixture @@ -24,9 +26,57 @@ def offloading_data_fixture(monkeypatch): @pytest.fixture def mqtt_client_fixture(offloading_data_fixture): """ Fixture to create an MQTT client with overridden file paths. """ - return MqttClient() + broker_url = 'hostname.local' + broker_port = 1883 + client_id = 'edge' + topics = { + 'registration': 'devices/', + 'offloading_layer': 'device_01/offloading_layer', + 'device_input': 'device_01/input_data', + 'device_inference_result': 'device_01/model_inference_result' + } + ntp_server = '0.it.pool.ntp.org' + input_height = 96 + input_width = 96 + last_offloading_layer = 58 + return MqttClient( + broker_url=broker_url, + broker_port=broker_port, + client_id=client_id, + protocol=mqtt.MQTTv311, + subscribed_topics=topics, + ntp_server=ntp_server, + input_height=input_height, + input_width=input_width, + last_offloading_layer=last_offloading_layer, + request_handler=RequestHandler() + ) @pytest.fixture def device_fixture(mqtt_client_fixture): - return MqttClient() + broker_url = 'hostname.local' + broker_port = 1883 + client_id = 'edge' + topics = { + 'registration': 'devices/', + 'offloading_layer': 'device_01/offloading_layer', + 'device_input': 'device_01/input_data', + 'device_inference_result': 'device_01/model_inference_result' + } + ntp_server = '0.it.pool.ntp.org' + input_height = 96 + input_width = 96 + last_offloading_layer = 58 + return MqttClient( + broker_url=broker_url, + broker_port=broker_port, + client_id=client_id, + protocol=mqtt.MQTTv311, + subscribed_topics=topics, + ntp_server=ntp_server, + input_height=input_height, + input_width=input_width, + last_offloading_layer=last_offloading_layer, + request_handler=RequestHandler() + ) diff --git a/tests/test_edge/test_edge.py b/tests/test_edge/test_edge.py index b2a9fbd..276e29f 100644 --- a/tests/test_edge/test_edge.py +++ b/tests/test_edge/test_edge.py @@ -2,4 +2,6 @@ def test_edge(): e = Edge() - e.initialization() \ No newline at end of file + height = 96 + width = 96 + e.initialization(input_height=height, input_width=width) \ No newline at end of file