Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ members = [
"crates/shepherd-common",
"crates/shepherd-mqtt",
"crates/shepherd-run",
"crates/shepherd-watch",
"crates/shepherd-ws",
]

Expand Down
6 changes: 3 additions & 3 deletions crates/shepherd-app/src/control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async fn start(

state
.mqttc
.publish(state.robot_control, msg)
.publish(state.robot_control, msg, false)
.await
.map_err(|e| {
ShepherdError(
Expand All @@ -53,7 +53,7 @@ async fn stop(State(state): State<ControlState>) -> ShepherdResult<()> {

state
.mqttc
.publish(state.robot_control, msg)
.publish(state.robot_control, msg, false)
.await
.map_err(|e| {
ShepherdError(
Expand All @@ -74,7 +74,7 @@ async fn reset(State(state): State<ControlState>) -> ShepherdResult<()> {

state
.mqttc
.publish(state.robot_control, msg)
.publish(state.robot_control, msg, false)
.await
.map_err(|e| {
ShepherdError(
Expand Down
2 changes: 1 addition & 1 deletion crates/shepherd-app/src/upload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ async fn upload_file(

state
.mqttc
.publish(state.robot_control, msg)
.publish(state.robot_control, msg, false)
.await
.map_err(|e| {
ShepherdError(
Expand Down
44 changes: 44 additions & 0 deletions crates/shepherd-common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ pub struct Config {
#[serde(default)]
pub ws: WsConfig,
#[serde(default)]
pub watch: WatchConfig,
#[serde(default)]
pub channel: ChannelConfig,
#[serde(default)]
pub path: PathConfig,
Expand Down Expand Up @@ -182,6 +184,36 @@ impl Default for WsConfig {
}
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WatchConfig {
#[serde(default = "default_watch_service_id")]
pub service_id: String,
#[serde(default = "default_watch_host")]
pub host: String,
#[serde(default = "default_watch_port")]
pub port: u16,
}

fn default_watch_service_id() -> String {
"shepherd-watch".to_string()
}
fn default_watch_host() -> String {
"0.0.0.0".to_string()
}
fn default_watch_port() -> u16 {
1010
}

impl Default for WatchConfig {
fn default() -> Self {
Self {
service_id: default_watch_service_id(),
host: default_watch_host(),
port: default_watch_port(),
}
}
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChannelConfig {
#[serde(default = "default_channel_robot_control")]
Expand All @@ -190,6 +222,10 @@ pub struct ChannelConfig {
pub robot_log: String,
#[serde(default = "default_channel_camera")]
pub camera: String,
#[serde(default = "default_user_status")]
pub user_state: String,
#[serde(default = "default_status")]
pub status: String,
}

fn default_channel_robot_control() -> String {
Expand All @@ -201,13 +237,21 @@ fn default_channel_robot_log() -> String {
fn default_channel_camera() -> String {
"camera".to_string()
}
fn default_user_status() -> String {
"user/state".to_string()
}
fn default_status() -> String {
"status".to_string()
}

impl Default for ChannelConfig {
fn default() -> Self {
Self {
robot_control: default_channel_robot_control(),
robot_log: default_channel_robot_log(),
camera: default_channel_camera(),
user_state: default_user_status(),
status: default_status(),
}
}
}
Expand Down
8 changes: 0 additions & 8 deletions crates/shepherd-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,6 @@ use serde::{Deserialize, Serialize};
pub mod args;
pub mod config;

/// Generate a status channel name from a service ID
pub fn status_for<S>(service_id: S) -> String
where
S: AsRef<str>,
{
format!("{}/status", service_id.as_ref())
}

#[derive(Debug, Default, PartialEq, Serialize, Deserialize, Copy, Clone)]
#[serde(rename_all = "lowercase")]
pub enum RunState {
Expand Down
52 changes: 43 additions & 9 deletions crates/shepherd-mqtt/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ use std::{collections::HashMap, pin::Pin, sync::Arc, time::Duration};

use bytes::Bytes;
use futures::future::join_all;
use rumqttc::{AsyncClient, Event, EventLoop, MqttOptions, Packet, QoS};
use rumqttc::{AsyncClient, Event, EventLoop, LastWill, MqttOptions, Packet, QoS};
use tokio::sync::Mutex;
use tracing::{debug, warn};

use crate::{Wildcard, messages::MqttMessage};
use crate::{
Wildcard,
messages::{MqttMessage, ServiceStatus, StatusMessage},
status_for,
};

pub type MqttHandler = Box<
dyn Fn(String, Bytes) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send>> + Send + Sync,
Expand Down Expand Up @@ -76,26 +80,26 @@ impl MqttAsyncClient {
Ok(())
}

pub async fn publish<T, S>(&self, topic: S, msg: T) -> anyhow::Result<()>
pub async fn publish<T, S>(&self, topic: S, msg: T, retain: bool) -> anyhow::Result<()>
where
T: MqttMessage,
S: AsRef<str>,
{
let b = serde_json::to_vec(&msg)
.map_err(|e| anyhow::anyhow!("failed to serialize message: {e}"))?;

self.publish_raw(topic, b).await?;
self.publish_raw(topic, b, retain).await?;

Ok(())
}

pub async fn publish_raw<S, V>(&self, topic: S, msg: V) -> anyhow::Result<()>
pub async fn publish_raw<S, V>(&self, topic: S, msg: V, retain: bool) -> anyhow::Result<()>
where
S: AsRef<str>,
V: Into<Vec<u8>>,
{
self.client
.publish(topic.as_ref(), QoS::AtLeastOnce, false, msg)
.publish(topic.as_ref(), QoS::AtLeastOnce, retain, msg)
.await?;

debug!("client published to topic '{}'", topic.as_ref());
Expand All @@ -106,6 +110,8 @@ impl MqttAsyncClient {

pub struct MqttEventLoop {
event_loop: EventLoop,
client: AsyncClient,
service_id: String,
registry: Arc<Mutex<Registry>>,
}

Expand Down Expand Up @@ -147,8 +153,24 @@ impl MqttEventLoop {
Self::dispatch(registry, topic, payload).await;
});
}
Event::Incoming(Packet::Connect(c)) => {
debug!("mqtt client connected with id '{}'", c.client_id);
Event::Incoming(Packet::ConnAck(_)) => {
debug!("mqtt client connected");

// generate a birth message
let birth_topic = status_for(&self.service_id);
let birth_message = serde_json::to_vec(&StatusMessage {
service: self.service_id.clone(),
status: ServiceStatus::Online,
})
.expect("birth message generation failed");

if let Err(e) = self
.client
.publish(birth_topic, QoS::AtLeastOnce, true, birth_message)
.await
{
warn!("failed to send birth message: {e}");
}
}
Event::Incoming(Packet::Disconnect) => {
debug!("mqtt client disconnected");
Expand All @@ -169,8 +191,18 @@ impl MqttClient {
where
S: AsRef<str>,
{
// generate a last will for this client
let last_will_topic = status_for(service_id.as_ref());
let last_will_message = serde_json::to_vec(&StatusMessage {
service: service_id.as_ref().to_string(),
status: ServiceStatus::Offline,
})
.expect("last will generation failed"); // this should never be able to fail
let last_will = LastWill::new(last_will_topic, last_will_message, QoS::AtLeastOnce, true);

let mut mqttoptions = MqttOptions::new(service_id.as_ref(), hostname.as_ref(), port);
mqttoptions.set_keep_alive(Duration::from_secs(5));
mqttoptions.set_last_will(last_will);

let (client, event_loop) = AsyncClient::new(mqttoptions, 10);

Expand All @@ -179,12 +211,14 @@ impl MqttClient {
debug!("initialised new mqtt client");

let wc = MqttAsyncClient {
client,
client: client.clone(),
registry: registry.clone(),
};

let we = MqttEventLoop {
event_loop,
client,
service_id: service_id.as_ref().to_string(),
registry: registry.clone(),
};

Expand Down
8 changes: 8 additions & 0 deletions crates/shepherd-mqtt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,11 @@ mod util;

pub use client::*;
pub use util::*;

/// Generate a status channel name from a service ID
pub fn status_for<S>(service_id: S) -> String
where
S: AsRef<str>,
{
format!("{}/status", service_id.as_ref())
}
18 changes: 18 additions & 0 deletions crates/shepherd-mqtt/src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,21 @@ pub struct ControlMessage {
pub struct RunStatusMessage {
pub state: shepherd_common::RunState,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ServiceStatus {
Online,
Offline,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StatusMessage {
pub service: String,
pub status: ServiceStatus,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StatusSummary {
pub statuses: Vec<StatusMessage>,
}
5 changes: 3 additions & 2 deletions crates/shepherd-run/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{path::PathBuf, sync::Arc, time::Duration};
use anyhow::{Result, anyhow};
use base64::Engine;
use hopper::{Pipe, PipeMode};
use shepherd_common::{Mode, RunState, Zone, config::Config, status_for};
use shepherd_common::{Mode, RunState, Zone, config::Config};
use shepherd_mqtt::{
MqttAsyncClient, MqttClient,
messages::{ControlMessage, ControlMessageType, RunStatusMessage},
Expand Down Expand Up @@ -184,8 +184,9 @@ impl Runner {
// could be used to tell when robot is started/stopped
mqttc
.publish(
status_for(&self.config.run.service_id),
&self.config.channel.user_state,
RunStatusMessage { state: next },
false,
)
.await?;

Expand Down
16 changes: 16 additions & 0 deletions crates/shepherd-watch/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[package]
name = "shepherd-watch"
version = "0.1.0"
edition.workspace = true
authors.workspace = true
license.workspace = true

[dependencies]
anyhow.workspace = true
futures-util.workspace = true
serde_json.workspace = true
shepherd-common = { path = "../shepherd-common" }
shepherd-mqtt = { path = "../shepherd-mqtt" }
tokio.workspace = true
tokio-tungstenite.workspace = true
tracing.workspace = true
Loading
Loading