diff --git a/Cargo.lock b/Cargo.lock index a79c1f9..3b0b4f2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1310,6 +1310,20 @@ dependencies = [ "tracing", ] +[[package]] +name = "shepherd-watch" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures-util", + "serde_json", + "shepherd-common", + "shepherd-mqtt", + "tokio", + "tokio-tungstenite", + "tracing", +] + [[package]] name = "shepherd-ws" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index a9477f4..9836d8c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "crates/shepherd-common", "crates/shepherd-mqtt", "crates/shepherd-run", + "crates/shepherd-watch", "crates/shepherd-ws", ] diff --git a/crates/shepherd-app/src/control.rs b/crates/shepherd-app/src/control.rs index d4952f6..bf19c8f 100644 --- a/crates/shepherd-app/src/control.rs +++ b/crates/shepherd-app/src/control.rs @@ -32,7 +32,7 @@ async fn start( state .mqttc - .publish(state.robot_control, msg) + .publish(state.robot_control, msg, false) .await .map_err(|e| { ShepherdError( @@ -53,7 +53,7 @@ async fn stop(State(state): State) -> ShepherdResult<()> { state .mqttc - .publish(state.robot_control, msg) + .publish(state.robot_control, msg, false) .await .map_err(|e| { ShepherdError( @@ -74,7 +74,7 @@ async fn reset(State(state): State) -> ShepherdResult<()> { state .mqttc - .publish(state.robot_control, msg) + .publish(state.robot_control, msg, false) .await .map_err(|e| { ShepherdError( diff --git a/crates/shepherd-app/src/upload.rs b/crates/shepherd-app/src/upload.rs index f0af70b..5307c39 100644 --- a/crates/shepherd-app/src/upload.rs +++ b/crates/shepherd-app/src/upload.rs @@ -141,7 +141,7 @@ async fn upload_file( state .mqttc - .publish(state.robot_control, msg) + .publish(state.robot_control, msg, false) .await .map_err(|e| { ShepherdError( diff --git a/crates/shepherd-common/src/config.rs b/crates/shepherd-common/src/config.rs index 378fd01..8d888f5 100644 --- a/crates/shepherd-common/src/config.rs +++ b/crates/shepherd-common/src/config.rs @@ -19,6 +19,8 @@ pub struct Config { #[serde(default)] pub ws: WsConfig, #[serde(default)] + pub watch: WatchConfig, + #[serde(default)] pub channel: ChannelConfig, #[serde(default)] pub path: PathConfig, @@ -182,6 +184,36 @@ impl Default for WsConfig { } } +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct WatchConfig { + #[serde(default = "default_watch_service_id")] + pub service_id: String, + #[serde(default = "default_watch_host")] + pub host: String, + #[serde(default = "default_watch_port")] + pub port: u16, +} + +fn default_watch_service_id() -> String { + "shepherd-watch".to_string() +} +fn default_watch_host() -> String { + "0.0.0.0".to_string() +} +fn default_watch_port() -> u16 { + 1010 +} + +impl Default for WatchConfig { + fn default() -> Self { + Self { + service_id: default_watch_service_id(), + host: default_watch_host(), + port: default_watch_port(), + } + } +} + #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ChannelConfig { #[serde(default = "default_channel_robot_control")] @@ -190,6 +222,10 @@ pub struct ChannelConfig { pub robot_log: String, #[serde(default = "default_channel_camera")] pub camera: String, + #[serde(default = "default_user_status")] + pub user_state: String, + #[serde(default = "default_status")] + pub status: String, } fn default_channel_robot_control() -> String { @@ -201,6 +237,12 @@ fn default_channel_robot_log() -> String { fn default_channel_camera() -> String { "camera".to_string() } +fn default_user_status() -> String { + "user/state".to_string() +} +fn default_status() -> String { + "status".to_string() +} impl Default for ChannelConfig { fn default() -> Self { @@ -208,6 +250,8 @@ impl Default for ChannelConfig { robot_control: default_channel_robot_control(), robot_log: default_channel_robot_log(), camera: default_channel_camera(), + user_state: default_user_status(), + status: default_status(), } } } diff --git a/crates/shepherd-common/src/lib.rs b/crates/shepherd-common/src/lib.rs index d733c80..fb455e9 100644 --- a/crates/shepherd-common/src/lib.rs +++ b/crates/shepherd-common/src/lib.rs @@ -3,14 +3,6 @@ use serde::{Deserialize, Serialize}; pub mod args; pub mod config; -/// Generate a status channel name from a service ID -pub fn status_for(service_id: S) -> String -where - S: AsRef, -{ - format!("{}/status", service_id.as_ref()) -} - #[derive(Debug, Default, PartialEq, Serialize, Deserialize, Copy, Clone)] #[serde(rename_all = "lowercase")] pub enum RunState { diff --git a/crates/shepherd-mqtt/src/client.rs b/crates/shepherd-mqtt/src/client.rs index 404ee81..a43c0a7 100644 --- a/crates/shepherd-mqtt/src/client.rs +++ b/crates/shepherd-mqtt/src/client.rs @@ -2,11 +2,15 @@ use std::{collections::HashMap, pin::Pin, sync::Arc, time::Duration}; use bytes::Bytes; use futures::future::join_all; -use rumqttc::{AsyncClient, Event, EventLoop, MqttOptions, Packet, QoS}; +use rumqttc::{AsyncClient, Event, EventLoop, LastWill, MqttOptions, Packet, QoS}; use tokio::sync::Mutex; use tracing::{debug, warn}; -use crate::{Wildcard, messages::MqttMessage}; +use crate::{ + Wildcard, + messages::{MqttMessage, ServiceStatus, StatusMessage}, + status_for, +}; pub type MqttHandler = Box< dyn Fn(String, Bytes) -> Pin> + Send>> + Send + Sync, @@ -76,7 +80,7 @@ impl MqttAsyncClient { Ok(()) } - pub async fn publish(&self, topic: S, msg: T) -> anyhow::Result<()> + pub async fn publish(&self, topic: S, msg: T, retain: bool) -> anyhow::Result<()> where T: MqttMessage, S: AsRef, @@ -84,18 +88,18 @@ impl MqttAsyncClient { let b = serde_json::to_vec(&msg) .map_err(|e| anyhow::anyhow!("failed to serialize message: {e}"))?; - self.publish_raw(topic, b).await?; + self.publish_raw(topic, b, retain).await?; Ok(()) } - pub async fn publish_raw(&self, topic: S, msg: V) -> anyhow::Result<()> + pub async fn publish_raw(&self, topic: S, msg: V, retain: bool) -> anyhow::Result<()> where S: AsRef, V: Into>, { self.client - .publish(topic.as_ref(), QoS::AtLeastOnce, false, msg) + .publish(topic.as_ref(), QoS::AtLeastOnce, retain, msg) .await?; debug!("client published to topic '{}'", topic.as_ref()); @@ -106,6 +110,8 @@ impl MqttAsyncClient { pub struct MqttEventLoop { event_loop: EventLoop, + client: AsyncClient, + service_id: String, registry: Arc>, } @@ -147,8 +153,24 @@ impl MqttEventLoop { Self::dispatch(registry, topic, payload).await; }); } - Event::Incoming(Packet::Connect(c)) => { - debug!("mqtt client connected with id '{}'", c.client_id); + Event::Incoming(Packet::ConnAck(_)) => { + debug!("mqtt client connected"); + + // generate a birth message + let birth_topic = status_for(&self.service_id); + let birth_message = serde_json::to_vec(&StatusMessage { + service: self.service_id.clone(), + status: ServiceStatus::Online, + }) + .expect("birth message generation failed"); + + if let Err(e) = self + .client + .publish(birth_topic, QoS::AtLeastOnce, true, birth_message) + .await + { + warn!("failed to send birth message: {e}"); + } } Event::Incoming(Packet::Disconnect) => { debug!("mqtt client disconnected"); @@ -169,8 +191,18 @@ impl MqttClient { where S: AsRef, { + // generate a last will for this client + let last_will_topic = status_for(service_id.as_ref()); + let last_will_message = serde_json::to_vec(&StatusMessage { + service: service_id.as_ref().to_string(), + status: ServiceStatus::Offline, + }) + .expect("last will generation failed"); // this should never be able to fail + let last_will = LastWill::new(last_will_topic, last_will_message, QoS::AtLeastOnce, true); + let mut mqttoptions = MqttOptions::new(service_id.as_ref(), hostname.as_ref(), port); mqttoptions.set_keep_alive(Duration::from_secs(5)); + mqttoptions.set_last_will(last_will); let (client, event_loop) = AsyncClient::new(mqttoptions, 10); @@ -179,12 +211,14 @@ impl MqttClient { debug!("initialised new mqtt client"); let wc = MqttAsyncClient { - client, + client: client.clone(), registry: registry.clone(), }; let we = MqttEventLoop { event_loop, + client, + service_id: service_id.as_ref().to_string(), registry: registry.clone(), }; diff --git a/crates/shepherd-mqtt/src/lib.rs b/crates/shepherd-mqtt/src/lib.rs index e26f555..4cd0821 100644 --- a/crates/shepherd-mqtt/src/lib.rs +++ b/crates/shepherd-mqtt/src/lib.rs @@ -4,3 +4,11 @@ mod util; pub use client::*; pub use util::*; + +/// Generate a status channel name from a service ID +pub fn status_for(service_id: S) -> String +where + S: AsRef, +{ + format!("{}/status", service_id.as_ref()) +} diff --git a/crates/shepherd-mqtt/src/messages.rs b/crates/shepherd-mqtt/src/messages.rs index 396a41a..410850c 100644 --- a/crates/shepherd-mqtt/src/messages.rs +++ b/crates/shepherd-mqtt/src/messages.rs @@ -23,3 +23,21 @@ pub struct ControlMessage { pub struct RunStatusMessage { pub state: shepherd_common::RunState, } + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ServiceStatus { + Online, + Offline, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct StatusMessage { + pub service: String, + pub status: ServiceStatus, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct StatusSummary { + pub statuses: Vec, +} diff --git a/crates/shepherd-run/src/runner.rs b/crates/shepherd-run/src/runner.rs index 8a2688d..58a162b 100644 --- a/crates/shepherd-run/src/runner.rs +++ b/crates/shepherd-run/src/runner.rs @@ -3,7 +3,7 @@ use std::{path::PathBuf, sync::Arc, time::Duration}; use anyhow::{Result, anyhow}; use base64::Engine; use hopper::{Pipe, PipeMode}; -use shepherd_common::{Mode, RunState, Zone, config::Config, status_for}; +use shepherd_common::{Mode, RunState, Zone, config::Config}; use shepherd_mqtt::{ MqttAsyncClient, MqttClient, messages::{ControlMessage, ControlMessageType, RunStatusMessage}, @@ -184,8 +184,9 @@ impl Runner { // could be used to tell when robot is started/stopped mqttc .publish( - status_for(&self.config.run.service_id), + &self.config.channel.user_state, RunStatusMessage { state: next }, + false, ) .await?; diff --git a/crates/shepherd-watch/Cargo.toml b/crates/shepherd-watch/Cargo.toml new file mode 100644 index 0000000..30289ca --- /dev/null +++ b/crates/shepherd-watch/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "shepherd-watch" +version = "0.1.0" +edition.workspace = true +authors.workspace = true +license.workspace = true + +[dependencies] +anyhow.workspace = true +futures-util.workspace = true +serde_json.workspace = true +shepherd-common = { path = "../shepherd-common" } +shepherd-mqtt = { path = "../shepherd-mqtt" } +tokio.workspace = true +tokio-tungstenite.workspace = true +tracing.workspace = true diff --git a/crates/shepherd-watch/src/main.rs b/crates/shepherd-watch/src/main.rs new file mode 100644 index 0000000..030edc0 --- /dev/null +++ b/crates/shepherd-watch/src/main.rs @@ -0,0 +1,192 @@ +use std::{collections::HashMap, sync::Arc}; + +use anyhow::Result; +use futures_util::{SinkExt, StreamExt}; +use shepherd_common::{args::call_with_args, config::Config}; +use shepherd_mqtt::{ + MqttAsyncClient, MqttClient, + messages::{ServiceStatus, StatusMessage, StatusSummary}, +}; +use tokio::{ + net::{TcpListener, TcpStream}, + sync::{ + Mutex, + broadcast::{self, Receiver, Sender}, + }, +}; +use tokio_tungstenite::{accept_async, tungstenite::Message}; +use tracing::{debug, error, info, warn}; + +async fn create_summary(statuses: Arc>>) -> Result { + let statuses = statuses.lock().await; + let status_arr: Vec = statuses + .iter() + .map(|(service, status)| StatusMessage { + service: service.clone(), + status: status.clone(), + }) + .collect(); + Ok(serde_json::to_string(&StatusSummary { + statuses: status_arr, + })?) +} + +async fn handle_websocket( + stream: TcpStream, + summary: String, + mut status_receiver: Receiver, +) -> Result<()> { + let addr = stream.peer_addr()?; + debug!("new websocket connection from {:?}", addr); + let (mut ws_tx, mut ws_rx) = accept_async(stream).await?.split(); + + // send an initial summary + ws_tx.send(Message::text(summary)).await?; + + loop { + tokio::select! { + // forward summary messages + msg = status_receiver.recv() => { + match msg { + Ok(s) => ws_tx.send(Message::text(s)).await?, + Err(e) => return Err(e)?, + } + } + + msg = ws_rx.next() => { + // detect if connection has been closed + match msg { + Some(Ok(msg)) => if msg.is_close() { + info!("closed connection from {:?}", addr); + return Ok(()) + } + None => { + info!("closed connection from {:?}", addr); + return Ok(()) + } + _ => {} + } + } + } + } +} + +async fn handle_status_message( + statuses: Arc>>, + status_sender: Sender, + mqtt_client: MqttAsyncClient, + status_topic: String, + message: StatusMessage, +) -> Result<()> { + info!("status for {}: {:?}", message.service, message.status); + + // update status table, generate summary array + let status_arr: Vec = { + let mut statuses = statuses.lock().await; + statuses.insert(message.service, message.status); + + statuses + .iter() + .map(|(service, status)| StatusMessage { + service: service.clone(), + status: status.clone(), + }) + .collect() + + // drop the lock here before sending + }; + + let summary = StatusSummary { + statuses: status_arr, + }; + + match serde_json::to_string(&summary) { + Ok(summary) => { + let _ = status_sender.send(summary.clone()); + } + Err(e) => { + warn!("failed to serialise status summary: {e}"); + } + } + + let _ = mqtt_client.publish(status_topic, summary, true).await; + + Ok(()) +} + +async fn _main(config: Config) -> Result<()> { + let (status_sender, _) = broadcast::channel::(64); + let statuses: Arc>> = Arc::new(Mutex::new(HashMap::new())); + + let (mut mqtt_client, mut mqtt_event_loop) = MqttClient::new( + &config.watch.service_id, + &config.mqtt.broker, + config.mqtt.port, + ); + + // TODO: wrap these in a context object + let mqtt_statuses = statuses.clone(); + let mqtt_status_sender = status_sender.clone(); + let mqtt_status = config.channel.status.clone(); + let mqtt_mqtt_client = mqtt_client.clone(); + mqtt_client + .subscribe("+/status", move |_, v| { + let mqtt_statuses = mqtt_statuses.clone(); + let mqtt_status_sender = mqtt_status_sender.clone(); + let mqtt_status = mqtt_status.clone(); + let mqtt_mqtt_client = mqtt_mqtt_client.clone(); + + async move { + handle_status_message( + mqtt_statuses, + mqtt_status_sender, + mqtt_mqtt_client, + mqtt_status, + v, + ) + .await + } + }) + .await?; + + // run mqtt event loop independently + let mqtt_loop = tokio::spawn(async move { + loop { + if let Err(e) = mqtt_event_loop.run().await { + error!("mqtt loop exited: {e}"); + } + } + }); + + let listener = + TcpListener::bind(format!("{}:{}", &config.watch.host, config.watch.port)).await?; + + tokio::select! { + res = async { + loop { + match listener.accept().await { + Ok((stream, _)) => { + // default to a blank string if serialisation failed + let summary = create_summary(statuses.clone()).await.unwrap_or("".to_string()); + tokio::spawn(handle_websocket(stream, summary, status_sender.subscribe()) ); + } + Err(e) => return Err(e), + } + } + } => { + warn!("websocket handler exited {:?}", res); + res? + } + + _ = mqtt_loop => { + error!("mqtt client exited?"); + } + } + + Ok(()) +} + +#[tokio::main] +async fn main() { + call_with_args("shepherd-watch", _main).await; +} diff --git a/crates/shepherd-ws/src/main.rs b/crates/shepherd-ws/src/main.rs index b81bf40..2c12a1c 100644 --- a/crates/shepherd-ws/src/main.rs +++ b/crates/shepherd-ws/src/main.rs @@ -62,19 +62,14 @@ async fn _main(config: Config) -> Result<()> { // set up subscription for all mqtt messages let mqtt_sender = msg_sender.clone(); let mqtt_log_handle = log_handle.clone(); + let mqtt_user_status = config.channel.user_state.clone(); mqtt_client .subscribe_raw("#", move |t, v| { let mqtt_sender = mqtt_sender.clone(); let mqtt_log_handle = mqtt_log_handle.clone(); + let mqtt_user_status = mqtt_user_status.clone(); async move { - dispatch_mqtt_message( - mqtt_sender, - mqtt_log_handle, - t, - "shepherd-run/status".to_string(), - v, - ) - .await + dispatch_mqtt_message(mqtt_sender, mqtt_log_handle, t, mqtt_user_status, v).await } }) .await?;