Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ base64 = "0.22.1"
bon = "3.8.0"
chrono = { version = "0.4.42", features = ["serde"] }
futures-util = "0.3.31"
generic-array = { version = "=0.14.7", features = ["serde"] }
hmac = "0.12.1"
http = "1.4.0"
hyper = "1.8.1"
Expand Down
1 change: 1 addition & 0 deletions s2energy-connection/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ axum.workspace = true
axum-extra.workspace = true
base64.workspace = true
futures-util.workspace = true
generic-array.workspace = true
hmac.workspace = true
http.workspace = true
hyper.workspace = true
Expand Down
4 changes: 4 additions & 0 deletions s2energy-connection/examples/communication-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ impl ClientPairing for &mut MemoryPairing {
&self.communication_url
}

fn certificate_hash(&self) -> Option<s2energy_connection::CertificateHash> {
None
}

async fn set_access_tokens(&mut self, tokens: Vec<AccessToken>) -> Result<(), Self::Error> {
self.tokens = tokens;
Ok(())
Expand Down
8 changes: 4 additions & 4 deletions s2energy-connection/examples/pairing-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use uuid::uuid;

use s2energy_connection::{
Deployment, EndpointDescription, MessageVersion, NodeDescription, Role,
pairing::{Client, ClientConfig, NodeConfig, NodeIdAlias, PairingRemote},
pairing::{Client, ClientConfig, NodeConfig, NodeIdAlias, PairingRemote, RemoteNodeIdentifier},
};
use tracing_subscriber::{EnvFilter, fmt, prelude::*};

Expand All @@ -29,7 +29,7 @@ async fn main() {
},
vec![MessageVersion("v1".into())],
)
.with_connection_initiate_url("client.example.com".into())
.with_connection_initiate_url("https://client.example.com".into())
.build()
.unwrap();

Expand All @@ -48,7 +48,7 @@ async fn main() {
&config,
PairingRemote {
url: "https://localhost:8005".into(),
id: Some(NodeIdAlias("ninechars".into())),
id: RemoteNodeIdentifier::Alias(NodeIdAlias("ninechars".into())),
},
PAIRING_TOKEN,
async |pairing| {
Expand All @@ -61,7 +61,7 @@ async fn main() {
let pair_result = rx.await.unwrap();

match pair_result.role {
s2energy_connection::pairing::PairingRole::CommunicationClient { initiate_url } => {
s2energy_connection::pairing::PairingRole::CommunicationClient { initiate_url, .. } => {
println!("Paired as client, url: {initiate_url}, token: {}", pair_result.token.0)
}
s2energy_connection::pairing::PairingRole::CommunicationServer => println!("Paired as server, token: {}", pair_result.token.0),
Expand Down
2 changes: 1 addition & 1 deletion s2energy-connection/examples/pairing-server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async fn main() {
},
vec![MessageVersion("v1".into())],
)
.with_connection_initiate_url("test.example.com".into())
.with_connection_initiate_url("https://test.example.com".into())
.build()
.unwrap();
let app = server.get_router();
Expand Down
49 changes: 39 additions & 10 deletions s2energy-connection/src/communication/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ use tokio_tungstenite::{Connector, connect_async_tls_with_config, tungstenite::C
use tracing::{debug, trace};

use crate::{
AccessToken, CommunicationProtocol, EndpointDescription, NodeId,
AccessToken, CertificateHash, CommunicationProtocol, EndpointDescription, NodeId,
common::negotiate_version,
communication::{
CommunicationResult, ConnectionInfo, Error, ErrorKind, NodeConfig, WebSocketTransport,
transport::hash_checking_http_client,
wire::{
CommunicationDetails, CommunicationDetailsErrorMessage, InitiateConnectionRequest, InitiateConnectionResponse, UnpairRequest,
},
Expand Down Expand Up @@ -52,6 +53,8 @@ pub trait ClientPairing: Send {
fn access_tokens(&self) -> impl AsRef<[AccessToken]>;
/// The communication url the client can use to contact the server.
fn communication_url(&self) -> impl AsRef<str>;
/// Hash of the root certificate the server uses.
fn certificate_hash(&self) -> Option<CertificateHash>;

/// Store a new set of access tokens for the pairing.
fn set_access_tokens(&mut self, tokens: Vec<AccessToken>) -> impl Future<Output = Result<(), Self::Error>> + Send;
Expand All @@ -71,14 +74,17 @@ impl Client {
/// upon success.
#[tracing::instrument(skip_all, fields(client = %pairing.client_id(), server = %pairing.server_id()), level = tracing::Level::ERROR)]
pub async fn unpair(&self, pairing: impl ClientPairing) -> CommunicationResult<()> {
let client = reqwest::Client::builder()
.tls_certs_merge(
self.additional_certificates
.iter()
.filter_map(|v| reqwest::Certificate::from_der(v).ok()),
)
.build()
.map_err(|e| Error::new(ErrorKind::TransportFailed, e))?;
let client = match pairing.certificate_hash() {
Some(hash) => hash_checking_http_client(hash)?,
None => reqwest::Client::builder()
.tls_certs_merge(
self.additional_certificates
.iter()
.filter_map(|v| reqwest::Certificate::from_der(v).ok()),
)
.build()
.map_err(|e| Error::new(ErrorKind::TransportFailed, e))?,
};

let communication_url = Url::parse(pairing.communication_url().as_ref()).map_err(|e| Error::new(ErrorKind::InvalidUrl, e))?;

Expand Down Expand Up @@ -301,7 +307,7 @@ mod tests {
use tokio::net::TcpListener;

use crate::{
AccessToken, CommunicationProtocol, EndpointDescription, MessageVersion, NodeId, Role,
AccessToken, CertificateHash, CommunicationProtocol, EndpointDescription, MessageVersion, NodeId, Role,
common::wire::test::{UUID_A, UUID_B, basic_node_description},
communication::{
self, Client, ClientConfig, ClientPairing, ErrorKind, NodeConfig, PairingLookup, Server, ServerConfig, ServerPairing,
Expand Down Expand Up @@ -370,6 +376,7 @@ mod tests {
server: NodeId,
tokens: Arc<Mutex<Vec<AccessToken>>>,
url: String,
certificate_hash: Option<CertificateHash>,
}

impl ClientPairing for &TestPairing {
Expand All @@ -391,6 +398,10 @@ mod tests {
&self.url
}

fn certificate_hash(&self) -> Option<CertificateHash> {
self.certificate_hash.clone()
}

async fn set_access_tokens(&mut self, tokens: Vec<AccessToken>) -> Result<(), Self::Error> {
*self.tokens.lock().unwrap() = tokens;
Ok(())
Expand Down Expand Up @@ -451,6 +462,7 @@ mod tests {
server: UUID_B.into(),
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
url: format!("https://localhost:{}/", addr.port()),
certificate_hash: None,
};

assert!(client.unpair(&pairing).await.is_ok());
Expand Down Expand Up @@ -483,6 +495,7 @@ mod tests {
server: UUID_B.into(),
tokens: Arc::new(Mutex::new(vec![AccessToken("invalidtoken".into())])),
url: format!("https://localhost:{}/", addr.port()),
certificate_hash: None,
};

let error = client.unpair(&pairing).await.unwrap_err();
Expand Down Expand Up @@ -511,6 +524,7 @@ mod tests {
server: UUID_B.into(),
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
url: format!("https://localhost:{}/", addr.port()),
certificate_hash: None,
};

let mut client_connection = client.connect(&pairing).await.unwrap();
Expand Down Expand Up @@ -577,6 +591,7 @@ mod tests {
server: UUID_B.into(),
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
url: format!("https://localhost:{}/", addr.port()),
certificate_hash: None,
};

let mut client_connection = client.connect(&pairing).await.unwrap();
Expand Down Expand Up @@ -637,6 +652,7 @@ mod tests {
server: UUID_B.into(),
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
url: format!("https://localhost:{}/", addr.port()),
certificate_hash: None,
};

let mut client_connection = client.connect(&pairing).await.unwrap();
Expand Down Expand Up @@ -692,6 +708,7 @@ mod tests {
server: UUID_B.into(),
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
url: format!("https://localhost:{}/", addr.port()),
certificate_hash: None,
};

let mut client_connection = client.connect(&pairing).await.unwrap();
Expand Down Expand Up @@ -740,6 +757,7 @@ mod tests {
server: UUID_B.into(),
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
url: format!("https://localhost:{}/", addr.port()),
certificate_hash: None,
};

let mut client_connection = client.connect(&pairing).await.unwrap();
Expand Down Expand Up @@ -792,6 +810,7 @@ mod tests {
server: UUID_B.into(),
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
url: format!("https://localhost:{}/", addr.port()),
certificate_hash: None,
};

assert_eq!(client.connect(&pairing).await.unwrap_err().kind(), ErrorKind::NoSupportedVersion);
Expand Down Expand Up @@ -830,6 +849,7 @@ mod tests {
server: UUID_B.into(),
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
url: format!("https://localhost:{}/", addr.port()),
certificate_hash: None,
};

assert_eq!(client.connect(&pairing).await.unwrap_err().kind(), ErrorKind::NoSupportedVersion);
Expand Down Expand Up @@ -876,6 +896,7 @@ mod tests {
server: UUID_B.into(),
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
url: format!("https://localhost:{}/", addr.port()),
certificate_hash: None,
};

assert_eq!(client.connect(&pairing).await.unwrap_err().kind(), ErrorKind::NoSupportedVersion);
Expand Down Expand Up @@ -922,6 +943,7 @@ mod tests {
server: UUID_B.into(),
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
url: format!("https://localhost:{}/", addr.port()),
certificate_hash: None,
};

assert_eq!(client.connect(&pairing).await.unwrap_err().kind(), ErrorKind::NoSupportedVersion);
Expand Down Expand Up @@ -973,6 +995,10 @@ mod tests {
&self.url
}

fn certificate_hash(&self) -> Option<CertificateHash> {
None
}

async fn set_access_tokens(&mut self, _tokens: Vec<AccessToken>) -> Result<(), Self::Error> {
Err(std::io::ErrorKind::Other.into())
}
Expand All @@ -995,6 +1021,7 @@ mod tests {
server: UUID_B.into(),
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
url: format!("https://localhost:{}/", addr.port()),
certificate_hash: None,
};

let mut client_connection = client.connect(&pairing).await.unwrap();
Expand Down Expand Up @@ -1042,6 +1069,7 @@ mod tests {
server: UUID_B.into(),
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
url: format!("https://localhost:{}/", addr.port()),
certificate_hash: None,
};

assert_eq!(client.connect(&pairing).await.unwrap_err().kind(), ErrorKind::ProtocolError);
Expand Down Expand Up @@ -1076,6 +1104,7 @@ mod tests {
server: UUID_B.into(),
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
url: format!("https://localhost:{}/", addr.port()),
certificate_hash: None,
};

assert_eq!(client.connect(&pairing).await.unwrap_err().kind(), ErrorKind::ProtocolError);
Expand Down
9 changes: 8 additions & 1 deletion s2energy-connection/src/communication/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@
//! # use std::sync::Arc;
//! # use std::convert::Infallible;
//! # use s2energy_connection::communication::{NodeConfig, Client, ClientConfig, ClientPairing};
//! # use s2energy_connection::{MessageVersion, AccessToken, NodeId};
//! # use s2energy_connection::{MessageVersion, AccessToken, NodeId, CertificateHash};
//! struct MemoryClientPairing {
//! client_id: NodeId,
//! server_id: NodeId,
//! communication_url: String,
//! access_tokens: Vec<AccessToken>,
//! certificate_hash: Option<CertificateHash>,
//! }
//!
//! impl ClientPairing for MemoryClientPairing {
Expand All @@ -71,6 +72,10 @@
//! &self.access_tokens
//! }
//!
//! fn certificate_hash(&self) -> Option<CertificateHash> {
//! self.certificate_hash.clone()
//! }
//!
//! async fn set_access_tokens(&mut self, tokens: Vec<AccessToken>) -> Result<(), Infallible> {
//! self.access_tokens = tokens;
//! Ok(())
Expand All @@ -84,6 +89,7 @@
//! server_id: NodeId::try_from("67e55044-10b1-426f-9247-bb680e5fe0c6").unwrap(),
//! communication_url: "https://example.com".into(),
//! access_tokens: vec![AccessToken("some-token-value".into())],
//! certificate_hash: None,
//! });
//! ```
//!
Expand Down Expand Up @@ -210,6 +216,7 @@ use crate::{EndpointDescription, MessageVersion, NodeDescription};
mod client;
mod error;
mod server;
mod transport;
mod websocket;
mod wire;

Expand Down
Loading
Loading