Skip to content

Commit 78a90b5

Browse files
committed
Implement custom roots for LAN clients in communication.
1 parent 49ee9b9 commit 78a90b5

15 files changed

Lines changed: 431 additions & 43 deletions

File tree

Cargo.lock

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ base64 = "0.22.1"
99
bon = "3.8.0"
1010
chrono = { version = "0.4.42", features = ["serde"] }
1111
futures-util = "0.3.31"
12+
generic-array = { version = "=0.14.7", features = ["serde"] }
1213
hmac = "0.12.1"
1314
http = "1.4.0"
1415
hyper = "1.8.1"

s2energy-connection/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ axum.workspace = true
88
axum-extra.workspace = true
99
base64.workspace = true
1010
futures-util.workspace = true
11+
generic-array.workspace = true
1112
hmac.workspace = true
1213
http.workspace = true
1314
hyper.workspace = true

s2energy-connection/examples/communication-client.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ impl ClientPairing for &mut MemoryPairing {
3535
&self.communication_url
3636
}
3737

38+
fn certificate_hash(&self) -> Option<s2energy_connection::CertificateHash> {
39+
None
40+
}
41+
3842
async fn set_access_tokens(&mut self, tokens: Vec<AccessToken>) -> Result<(), Self::Error> {
3943
self.tokens = tokens;
4044
Ok(())

s2energy-connection/examples/pairing-client.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ async fn main() {
2929
},
3030
vec![MessageVersion("v1".into())],
3131
)
32-
.with_connection_initiate_url("client.example.com".into())
32+
.with_connection_initiate_url("https://client.example.com".into())
3333
.build()
3434
.unwrap();
3535

@@ -61,7 +61,7 @@ async fn main() {
6161
let pair_result = rx.await.unwrap();
6262

6363
match pair_result.role {
64-
s2energy_connection::pairing::PairingRole::CommunicationClient { initiate_url } => {
64+
s2energy_connection::pairing::PairingRole::CommunicationClient { initiate_url, .. } => {
6565
println!("Paired as client, url: {initiate_url}, token: {}", pair_result.token.0)
6666
}
6767
s2energy_connection::pairing::PairingRole::CommunicationServer => println!("Paired as server, token: {}", pair_result.token.0),

s2energy-connection/examples/pairing-server.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ async fn main() {
3535
},
3636
vec![MessageVersion("v1".into())],
3737
)
38-
.with_connection_initiate_url("test.example.com".into())
38+
.with_connection_initiate_url("https://test.example.com".into())
3939
.build()
4040
.unwrap();
4141
let app = server.get_router();

s2energy-connection/src/communication/client.rs

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ use tokio_tungstenite::{Connector, connect_async_tls_with_config, tungstenite::C
77
use tracing::{debug, trace};
88

99
use crate::{
10-
AccessToken, CommunicationProtocol, EndpointDescription, NodeId,
10+
AccessToken, CertificateHash, CommunicationProtocol, EndpointDescription, NodeId,
1111
common::negotiate_version,
1212
communication::{
1313
CommunicationResult, ConnectionInfo, Error, ErrorKind, NodeConfig, WebSocketTransport,
14+
transport::hash_checking_http_client,
1415
wire::{
1516
CommunicationDetails, CommunicationDetailsErrorMessage, InitiateConnectionRequest, InitiateConnectionResponse, UnpairRequest,
1617
},
@@ -52,6 +53,8 @@ pub trait ClientPairing: Send {
5253
fn access_tokens(&self) -> impl AsRef<[AccessToken]>;
5354
/// The communication url the client can use to contact the server.
5455
fn communication_url(&self) -> impl AsRef<str>;
56+
/// Hash of the root certificate the server uses.
57+
fn certificate_hash(&self) -> Option<CertificateHash>;
5558

5659
/// Store a new set of access tokens for the pairing.
5760
fn set_access_tokens(&mut self, tokens: Vec<AccessToken>) -> impl Future<Output = Result<(), Self::Error>> + Send;
@@ -71,14 +74,17 @@ impl Client {
7174
/// upon success.
7275
#[tracing::instrument(skip_all, fields(client = %pairing.client_id(), server = %pairing.server_id()), level = tracing::Level::ERROR)]
7376
pub async fn unpair(&self, pairing: impl ClientPairing) -> CommunicationResult<()> {
74-
let client = reqwest::Client::builder()
75-
.tls_certs_merge(
76-
self.additional_certificates
77-
.iter()
78-
.filter_map(|v| reqwest::Certificate::from_der(v).ok()),
79-
)
80-
.build()
81-
.map_err(|e| Error::new(ErrorKind::TransportFailed, e))?;
77+
let client = match pairing.certificate_hash() {
78+
Some(hash) => hash_checking_http_client(hash)?,
79+
None => reqwest::Client::builder()
80+
.tls_certs_merge(
81+
self.additional_certificates
82+
.iter()
83+
.filter_map(|v| reqwest::Certificate::from_der(v).ok()),
84+
)
85+
.build()
86+
.map_err(|e| Error::new(ErrorKind::TransportFailed, e))?,
87+
};
8288

8389
let communication_url = Url::parse(pairing.communication_url().as_ref()).map_err(|e| Error::new(ErrorKind::InvalidUrl, e))?;
8490

@@ -301,7 +307,7 @@ mod tests {
301307
use tokio::net::TcpListener;
302308

303309
use crate::{
304-
AccessToken, CommunicationProtocol, EndpointDescription, MessageVersion, NodeId, Role,
310+
AccessToken, CertificateHash, CommunicationProtocol, EndpointDescription, MessageVersion, NodeId, Role,
305311
common::wire::test::{UUID_A, UUID_B, basic_node_description},
306312
communication::{
307313
self, Client, ClientConfig, ClientPairing, ErrorKind, NodeConfig, PairingLookup, Server, ServerConfig, ServerPairing,
@@ -370,6 +376,7 @@ mod tests {
370376
server: NodeId,
371377
tokens: Arc<Mutex<Vec<AccessToken>>>,
372378
url: String,
379+
certificate_hash: Option<CertificateHash>,
373380
}
374381

375382
impl ClientPairing for &TestPairing {
@@ -391,6 +398,10 @@ mod tests {
391398
&self.url
392399
}
393400

401+
fn certificate_hash(&self) -> Option<CertificateHash> {
402+
self.certificate_hash.clone()
403+
}
404+
394405
async fn set_access_tokens(&mut self, tokens: Vec<AccessToken>) -> Result<(), Self::Error> {
395406
*self.tokens.lock().unwrap() = tokens;
396407
Ok(())
@@ -451,6 +462,7 @@ mod tests {
451462
server: UUID_B.into(),
452463
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
453464
url: format!("https://localhost:{}/", addr.port()),
465+
certificate_hash: None,
454466
};
455467

456468
assert!(client.unpair(&pairing).await.is_ok());
@@ -483,6 +495,7 @@ mod tests {
483495
server: UUID_B.into(),
484496
tokens: Arc::new(Mutex::new(vec![AccessToken("invalidtoken".into())])),
485497
url: format!("https://localhost:{}/", addr.port()),
498+
certificate_hash: None,
486499
};
487500

488501
let error = client.unpair(&pairing).await.unwrap_err();
@@ -511,6 +524,7 @@ mod tests {
511524
server: UUID_B.into(),
512525
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
513526
url: format!("https://localhost:{}/", addr.port()),
527+
certificate_hash: None,
514528
};
515529

516530
let mut client_connection = client.connect(&pairing).await.unwrap();
@@ -577,6 +591,7 @@ mod tests {
577591
server: UUID_B.into(),
578592
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
579593
url: format!("https://localhost:{}/", addr.port()),
594+
certificate_hash: None,
580595
};
581596

582597
let mut client_connection = client.connect(&pairing).await.unwrap();
@@ -637,6 +652,7 @@ mod tests {
637652
server: UUID_B.into(),
638653
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
639654
url: format!("https://localhost:{}/", addr.port()),
655+
certificate_hash: None,
640656
};
641657

642658
let mut client_connection = client.connect(&pairing).await.unwrap();
@@ -692,6 +708,7 @@ mod tests {
692708
server: UUID_B.into(),
693709
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
694710
url: format!("https://localhost:{}/", addr.port()),
711+
certificate_hash: None,
695712
};
696713

697714
let mut client_connection = client.connect(&pairing).await.unwrap();
@@ -740,6 +757,7 @@ mod tests {
740757
server: UUID_B.into(),
741758
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
742759
url: format!("https://localhost:{}/", addr.port()),
760+
certificate_hash: None,
743761
};
744762

745763
let mut client_connection = client.connect(&pairing).await.unwrap();
@@ -792,6 +810,7 @@ mod tests {
792810
server: UUID_B.into(),
793811
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
794812
url: format!("https://localhost:{}/", addr.port()),
813+
certificate_hash: None,
795814
};
796815

797816
assert_eq!(client.connect(&pairing).await.unwrap_err().kind(), ErrorKind::NoSupportedVersion);
@@ -830,6 +849,7 @@ mod tests {
830849
server: UUID_B.into(),
831850
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
832851
url: format!("https://localhost:{}/", addr.port()),
852+
certificate_hash: None,
833853
};
834854

835855
assert_eq!(client.connect(&pairing).await.unwrap_err().kind(), ErrorKind::NoSupportedVersion);
@@ -876,6 +896,7 @@ mod tests {
876896
server: UUID_B.into(),
877897
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
878898
url: format!("https://localhost:{}/", addr.port()),
899+
certificate_hash: None,
879900
};
880901

881902
assert_eq!(client.connect(&pairing).await.unwrap_err().kind(), ErrorKind::NoSupportedVersion);
@@ -922,6 +943,7 @@ mod tests {
922943
server: UUID_B.into(),
923944
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
924945
url: format!("https://localhost:{}/", addr.port()),
946+
certificate_hash: None,
925947
};
926948

927949
assert_eq!(client.connect(&pairing).await.unwrap_err().kind(), ErrorKind::NoSupportedVersion);
@@ -973,6 +995,10 @@ mod tests {
973995
&self.url
974996
}
975997

998+
fn certificate_hash(&self) -> Option<CertificateHash> {
999+
None
1000+
}
1001+
9761002
async fn set_access_tokens(&mut self, _tokens: Vec<AccessToken>) -> Result<(), Self::Error> {
9771003
Err(std::io::ErrorKind::Other.into())
9781004
}
@@ -995,6 +1021,7 @@ mod tests {
9951021
server: UUID_B.into(),
9961022
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
9971023
url: format!("https://localhost:{}/", addr.port()),
1024+
certificate_hash: None,
9981025
};
9991026

10001027
let mut client_connection = client.connect(&pairing).await.unwrap();
@@ -1042,6 +1069,7 @@ mod tests {
10421069
server: UUID_B.into(),
10431070
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
10441071
url: format!("https://localhost:{}/", addr.port()),
1072+
certificate_hash: None,
10451073
};
10461074

10471075
assert_eq!(client.connect(&pairing).await.unwrap_err().kind(), ErrorKind::ProtocolError);
@@ -1076,6 +1104,7 @@ mod tests {
10761104
server: UUID_B.into(),
10771105
tokens: Arc::new(Mutex::new(vec![AccessToken("testtoken".into())])),
10781106
url: format!("https://localhost:{}/", addr.port()),
1107+
certificate_hash: None,
10791108
};
10801109

10811110
assert_eq!(client.connect(&pairing).await.unwrap_err().kind(), ErrorKind::ProtocolError);

s2energy-connection/src/communication/mod.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,13 @@
4444
//! # use std::sync::Arc;
4545
//! # use std::convert::Infallible;
4646
//! # use s2energy_connection::communication::{NodeConfig, Client, ClientConfig, ClientPairing};
47-
//! # use s2energy_connection::{MessageVersion, AccessToken, NodeId};
47+
//! # use s2energy_connection::{MessageVersion, AccessToken, NodeId, CertificateHash};
4848
//! struct MemoryClientPairing {
4949
//! client_id: NodeId,
5050
//! server_id: NodeId,
5151
//! communication_url: String,
5252
//! access_tokens: Vec<AccessToken>,
53+
//! certificate_hash: Option<CertificateHash>,
5354
//! }
5455
//!
5556
//! impl ClientPairing for MemoryClientPairing {
@@ -71,6 +72,10 @@
7172
//! &self.access_tokens
7273
//! }
7374
//!
75+
//! fn certificate_hash(&self) -> Option<CertificateHash> {
76+
//! self.certificate_hash.clone()
77+
//! }
78+
//!
7479
//! async fn set_access_tokens(&mut self, tokens: Vec<AccessToken>) -> Result<(), Infallible> {
7580
//! self.access_tokens = tokens;
7681
//! Ok(())
@@ -84,6 +89,7 @@
8489
//! server_id: NodeId::try_from("67e55044-10b1-426f-9247-bb680e5fe0c6").unwrap(),
8590
//! communication_url: "https://example.com".into(),
8691
//! access_tokens: vec![AccessToken("some-token-value".into())],
92+
//! certificate_hash: None,
8793
//! });
8894
//! ```
8995
//!
@@ -210,6 +216,7 @@ use crate::{EndpointDescription, MessageVersion, NodeDescription};
210216
mod client;
211217
mod error;
212218
mod server;
219+
mod transport;
213220
mod websocket;
214221
mod wire;
215222

0 commit comments

Comments
 (0)