Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 76 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,84 @@
# SCIoT project
The Split Computing on IoT (SCIoT) project provides tools to use Edge Impulse models in ESP32 devices, using split computing techniques.

---
![Unit Tests](https://github.com/UBICO/SCIoT/actions/workflows/codecov.yml/badge.svg)<br>
[![Coverage](https://codecov.io/github/UBICO/SCIoT//coverage.svg?branch=main)](https://codecov.io/gh/UBICO/SCIoT)
[![Powered by UBICO](https://img.shields.io/badge/powered%20by-UBICO-orange.svg?style=flat&colorA=E1523D&colorB=007D8A)]()
![Unit Tests](https://github.com/UBICO/SCIoT/actions/workflows/codecov.yml/badge.svg) [![Coverage](https://codecov.io/github/UBICO/SCIoT//coverage.svg?branch=main)](https://codecov.io/gh/UBICO/SCIoT) [![Powered by UBICO](https://img.shields.io/badge/powered%20by-UBICO-orange.svg?style=flat&colorA=E1523D&colorB=007D8A)]()

## Publications
If you use this work, please consider citing our work:
- F. Bove, S. Colli and L. Bedogni, "Performance Evaluation of Split Computing with TinyML on IoT Devices," 2024 IEEE 21st Consumer Communications & Networking Conference (CCNC), Las Vegas, NV, USA, 2024, pp. 1-6, [DOI Link](http://dx.doi.org/10.1109/CCNC51664.2024.10454775).
- F. Bove and L. Bedogni, "Smart Split: Leveraging TinyML and Split Computing for Efficient Edge AI," 2024 IEEE/ACM Symposium on Edge Computing (SEC), Rome, Italy, 2024, pp. 456-460, [DOI Link](http://dx.doi.org/10.1109/SEC62691.2024.00052).

## Configuration
Clone and go in the repository:

```sh
git clone https://github.com/UBICO/SCIoT.git
```

Install python 3.11:

```sh
pyenv install 3.11
```

- Newer versions of python don't support the tensorflow version used in the project

Switch to python 3.11:

```sh
pyenv global 3.11
```

- You can switch back after the configuration process by running `pyenv system global`

Create a virtual environment:

```sh
python3 -m venv venv
```

Activate the virtual environment:

```sh
source venv/bin/activate
```

Install the project's dependencies:

```sh
pip3 install .
```

Configure the absolute path to the project's `src` directory (e.g. `/home/username/Documents/SCIoT/src/`):

```sh
cd $(python3 -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
echo /absolute/path/to/project/src/ > project.pth
```
### Model setup
- Save your keras model as `test_model.h5` in `src/server/models/test/test_model/`
- Save your test image as `test_image.png` in `src/server/models/test/test_model/pred_data/`
- Split the model by running `python3 model_split.py` in `src/server/models/`
- Configure the paths as needed using `src/server/commons.py`

### Server setup
- Configure the server using `src/server/settings.yaml`

## Usage
In root directory, run the MQTT broker:

```sh
docker compose up
```

In `src/server/edge`, run the edge server:

```sh
python3 run_edge.py
```

In `src/server/web`, run the webpage:

```sh
streamlit run webpage.py
```
38 changes: 1 addition & 37 deletions src/server/commons.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
from pathlib import Path
import struct

import numpy as np
from tensorflow.keras.preprocessing.image import load_img, img_to_array


BASE_DIR = Path(__file__).resolve().parent
Expand All @@ -16,7 +12,6 @@ class OffloadingDataFiles:

class EvaluationFiles:
evaluation_file_path: str = str(BASE_DIR / "evaluations/evaluations.csv")
web_file_path: str = str(BASE_DIR / "evaluations/web.csv")


class ConfigurationFiles:
Expand All @@ -28,36 +23,5 @@ class ModelFiles:


class InputDataFiles:
test_data_file_path: str = str(BASE_DIR / "models/test/test_model/pred_data/input_data.png") # Path to test image
test_data_file_path: str = str(BASE_DIR / "models/test/test_model/pred_data/test_image.png") # Path to test image
input_data_file_path: str = str(BASE_DIR / "input_data.png") # Input image save path


class InputData:
height = 96
width = 96

def __init__(self, image_path=InputDataFiles.test_data_file_path, color_mode="rgb",
target_size=(height, width)): # Model input configuration
input_image = load_img(image_path, color_mode=color_mode, target_size=target_size)
image_array = img_to_array(input_image)
self.image_array = np.array([image_array])

@staticmethod
def make_array(rgb565_image, h=height, w=width):
image_array = []

for i in range(h):
row = []
s = rgb565_image[i * w * 2:(i + 1) * w * 2]
pixels = struct.unpack(f'>{w}H', s)
for p in pixels:
r = p >> 11
g = (p >> 5) & 0x3f
b = p & 0x1f
r = (r * 255) / 31.0
g = (g * 255) / 63.0
b = (b * 255) / 31.0
row.append([int(round(x)) for x in [r, g, b]])
image_array.append(row)

return np.array(image_array, dtype=np.uint8)
16 changes: 12 additions & 4 deletions src/server/communication/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,26 @@ def __init__(
port: int,
endpoints: dict,
ntp_server: str,
input_height: int,
input_width: int,
last_offloading_layer: int,
request_handler: RequestHandler
):
self.app = FastAPI()
self.host = host
self.port = port
self.endpoints = endpoints

self.request_handler = request_handler
self.best_offloading_layer = last_offloading_layer

self.devices = set()

# Set up model
self.input_height = input_height
self.input_width = input_width
self.best_offloading_layer = last_offloading_layer

# Set up request handler
self.request_handler = request_handler

# Set up NTP client
self.ntp_client = ntplib.NTPClient()
self.ntp_server = ntp_server
Expand Down Expand Up @@ -59,7 +67,7 @@ async def registration(data: dict):
async def device_input(request: Request):
try:
body = await request.body() # Reads raw bytes
self.request_handler.handle_device_input(body)
self.request_handler.handle_device_input(body, self.input_height, self.input_width)
return {'message': 'Success'}
except Exception as e:
raise HTTPException(status_code=404, detail=str(e))
Expand Down
150 changes: 150 additions & 0 deletions src/server/communication/mqtt_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import json
import random
import time

import ntplib
import paho.mqtt.client as mqtt
import threading
import queue

from server.logger.log import logger
from server.communication.request_handler import RequestHandler


class MqttClient:
def __init__(
self,
broker_url: str,
broker_port: int,
client_id: str,
protocol: str,
subscribed_topics: list,
ntp_server: str,
input_height: int,
input_width: int,
last_offloading_layer: int,
request_handler: RequestHandler
):
self.broker_url = broker_url
self.broker_port = broker_port
self.client_id = client_id

# Create the client with the specific MQTT protocol version
self.client = mqtt.Client(client_id=client_id, protocol=protocol)

# Attach callbacks
self.client.on_connect = self.on_connect
self.client.on_message = self.on_message

# Set up topics
self.subscribed_topics = subscribed_topics

# Set up NTP client
self.ntp_client = ntplib.NTPClient()
self.ntp_server = ntp_server
self.offset = self.sync_with_ntp()
self.start_timestamp = self.get_current_time()

# Set up model
self.input_height = input_height
self.input_width = input_width
self.best_offloading_layer = last_offloading_layer

# Set up request handler
self.request_handler = request_handler

# Set up helper thread
self.task_queue = queue.Queue()
self.thread = threading.Thread(target=self._worker, daemon=True)
self.thread.start()

@staticmethod
def create_random_payload():
"""Creates a random payload for testing."""
message = json.dumps({"id": random.randint(1, 1000)})
return message

def publish(self, topic: str, message: str, qos: int = 2):
"""Publishes a message to a topic."""
logger.debug(f"Publishing message to {topic}: {message}")
try:
self.client.publish(topic, message, qos=qos, retain=False)
except Exception as e:
logger.debug(f"Error publishing message: {e}")

def subscribe(self, topic: str):
"""Subscribes to a topic."""
logger.debug(f"Subscribing to topic: {topic}")
self.client.subscribe(topic)

def run(self):
"""Connect to the broker and start the MQTT client loop."""
self.client.connect(self.broker_url, self.broker_port, 60)
self.client.loop_forever()

def stop(self):
"""Stops the MQTT client loop and disconnects."""
logger.debug("Disconnecting MQTT client")
self.client.disconnect()

def on_connect(self, client, userdata, flags, rc):
if rc == 0:
logger.debug(f"Connected to {self.broker_url}:{self.broker_port} with client ID {self.client_id}")
for topic in self.subscribed_topics.values():
self.subscribe(topic)
logger.debug(f"Initial NTP timestamp from NTP server {self.ntp_server}: {self.start_timestamp}")
else:
logger.debug(f"Connection failed with code {rc}")

def sync_with_ntp(self) -> float:
ntp_timestamp = None
while ntp_timestamp is None:
try:
response = self.ntp_client.request(self.ntp_server)
# Get the offset between local clock time and ntp server time (seconds since 1900)
offset = response.offset
logger.debug(f"Synchronized with NTP server. Offset: {offset} seconds")
return offset
except ntplib.NTPException as _:
time.sleep(1)
threading.Timer(600, self.sync_with_ntp).start()

def get_current_time(self) -> float:
return time.time() + self.offset

def _worker(self):
while True:
task = self.task_queue.get()
task()
self.task_queue.task_done()

def on_message(self, client, userdata, message):
received_timestamp = self.get_current_time()

def task():
self.handle_message_task(message, received_timestamp)

self.task_queue.put(task) # Submit the task to the worker thread queue

def handle_message_task(self, message, received_timestamp):
if message.topic == self.subscribed_topics['device_inference_result']: # Updates device time, runs offloading algorithm and sends best offloading layer
logger.debug('Device inference result received')
self.best_offloading_layer = self.request_handler.handle_device_inference_result(body=message.payload, received_timestamp=received_timestamp)
cleaned_offloading_layer_index = self.request_handler.handle_offloading_layer(best_offloading_layer=self.best_offloading_layer)
message_data = {'offloading_layer_index': cleaned_offloading_layer_index}
self.publish(self.subscribed_topics['offloading_layer'], json.dumps(message_data))
logger.debug('Best offloading layer sent')
elif message.topic == self.subscribed_topics['device_input']: # Save input image
logger.debug('Device input received')
self.request_handler.handle_device_input(message.payload, self.input_height, self.input_width)
logger.debug('Device input saved')
elif message.topic == self.subscribed_topics['registration']: # Sends best offloading layer
logger.debug('Registration request received')
decoded_payload = message.payload.decode()
json_data = json.loads(decoded_payload)
cleaned_device_id = self.request_handler.handle_registration(json_data["device_id"])
cleaned_offloading_layer_index = self.request_handler.handle_offloading_layer(best_offloading_layer=self.best_offloading_layer)
message_data = {'offloading_layer_index': cleaned_offloading_layer_index}
self.publish(self.subscribed_topics['offloading_layer'], json.dumps(message_data))
logger.debug('Best offloading layer sent')

14 changes: 8 additions & 6 deletions src/server/communication/request_handler.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
import json

import numpy as np
from PIL import Image

from server.edge.edge_initialization import Edge
from server.offloading_algo.offloading_algo import OffloadingAlgo
from PIL import Image

from server.commons import OffloadingDataFiles
from server.commons import EvaluationFiles
from server.commons import InputData
from server.commons import InputDataFiles

from server.logger.log import logger

from server.communication.message_data import MessageData

from server.models.model_input_converter import ModelInputConverter

import struct


class RequestHandler():
def handle_registration(self, device_id):
return device_id

def handle_device_input(self, rgb565_image):
image_array = InputData.make_array(rgb565_image=rgb565_image)
def handle_device_input(self, rgb565_image, height, width):
image_array = ModelInputConverter.convert_rgb565_to_nparray(rgb565_image, height, width)
image = Image.fromarray(image_array, 'RGB')
image.save(InputDataFiles.input_data_file_path)
logger.debug("Input image saved")
return

def handle_device_inference_result(self, body, received_timestamp):
Expand Down
Loading