@@ -7,10 +7,11 @@ use tokio_tungstenite::{Connector, connect_async_tls_with_config, tungstenite::C
77use tracing:: { debug, trace} ;
88
99use crate :: {
10- AccessToken , CommunicationProtocol , EndpointDescription , NodeId ,
10+ AccessToken , CertificateHash , CommunicationProtocol , EndpointDescription , NodeId ,
1111 common:: negotiate_version,
1212 communication:: {
1313 CommunicationResult , ConnectionInfo , Error , ErrorKind , NodeConfig , WebSocketTransport ,
14+ transport:: hash_checking_http_client,
1415 wire:: {
1516 CommunicationDetails , CommunicationDetailsErrorMessage , InitiateConnectionRequest , InitiateConnectionResponse , UnpairRequest ,
1617 } ,
@@ -52,6 +53,8 @@ pub trait ClientPairing: Send {
5253 fn access_tokens ( & self ) -> impl AsRef < [ AccessToken ] > ;
5354 /// The communication url the client can use to contact the server.
5455 fn communication_url ( & self ) -> impl AsRef < str > ;
56+ /// Hash of the root certificate the server uses.
57+ fn certificate_hash ( & self ) -> Option < CertificateHash > ;
5558
5659 /// Store a new set of access tokens for the pairing.
5760 fn set_access_tokens ( & mut self , tokens : Vec < AccessToken > ) -> impl Future < Output = Result < ( ) , Self :: Error > > + Send ;
@@ -71,14 +74,17 @@ impl Client {
7174 /// upon success.
7275 #[ tracing:: instrument( skip_all, fields( client = %pairing. client_id( ) , server = %pairing. server_id( ) ) , level = tracing:: Level :: ERROR ) ]
7376 pub async fn unpair ( & self , pairing : impl ClientPairing ) -> CommunicationResult < ( ) > {
74- let client = reqwest:: Client :: builder ( )
75- . tls_certs_merge (
76- self . additional_certificates
77- . iter ( )
78- . filter_map ( |v| reqwest:: Certificate :: from_der ( v) . ok ( ) ) ,
79- )
80- . build ( )
81- . map_err ( |e| Error :: new ( ErrorKind :: TransportFailed , e) ) ?;
77+ let client = match pairing. certificate_hash ( ) {
78+ Some ( hash) => hash_checking_http_client ( hash) ?,
79+ None => reqwest:: Client :: builder ( )
80+ . tls_certs_merge (
81+ self . additional_certificates
82+ . iter ( )
83+ . filter_map ( |v| reqwest:: Certificate :: from_der ( v) . ok ( ) ) ,
84+ )
85+ . build ( )
86+ . map_err ( |e| Error :: new ( ErrorKind :: TransportFailed , e) ) ?,
87+ } ;
8288
8389 let communication_url = Url :: parse ( pairing. communication_url ( ) . as_ref ( ) ) . map_err ( |e| Error :: new ( ErrorKind :: InvalidUrl , e) ) ?;
8490
@@ -301,7 +307,7 @@ mod tests {
301307 use tokio:: net:: TcpListener ;
302308
303309 use crate :: {
304- AccessToken , CommunicationProtocol , EndpointDescription , MessageVersion , NodeId , Role ,
310+ AccessToken , CertificateHash , CommunicationProtocol , EndpointDescription , MessageVersion , NodeId , Role ,
305311 common:: wire:: test:: { UUID_A , UUID_B , basic_node_description} ,
306312 communication:: {
307313 self , Client , ClientConfig , ClientPairing , ErrorKind , NodeConfig , PairingLookup , Server , ServerConfig , ServerPairing ,
@@ -370,6 +376,7 @@ mod tests {
370376 server : NodeId ,
371377 tokens : Arc < Mutex < Vec < AccessToken > > > ,
372378 url : String ,
379+ certificate_hash : Option < CertificateHash > ,
373380 }
374381
375382 impl ClientPairing for & TestPairing {
@@ -391,6 +398,10 @@ mod tests {
391398 & self . url
392399 }
393400
401+ fn certificate_hash ( & self ) -> Option < CertificateHash > {
402+ self . certificate_hash . clone ( )
403+ }
404+
394405 async fn set_access_tokens ( & mut self , tokens : Vec < AccessToken > ) -> Result < ( ) , Self :: Error > {
395406 * self . tokens . lock ( ) . unwrap ( ) = tokens;
396407 Ok ( ( ) )
@@ -451,6 +462,7 @@ mod tests {
451462 server : UUID_B . into ( ) ,
452463 tokens : Arc :: new ( Mutex :: new ( vec ! [ AccessToken ( "testtoken" . into( ) ) ] ) ) ,
453464 url : format ! ( "https://localhost:{}/" , addr. port( ) ) ,
465+ certificate_hash : None ,
454466 } ;
455467
456468 assert ! ( client. unpair( & pairing) . await . is_ok( ) ) ;
@@ -483,6 +495,7 @@ mod tests {
483495 server : UUID_B . into ( ) ,
484496 tokens : Arc :: new ( Mutex :: new ( vec ! [ AccessToken ( "invalidtoken" . into( ) ) ] ) ) ,
485497 url : format ! ( "https://localhost:{}/" , addr. port( ) ) ,
498+ certificate_hash : None ,
486499 } ;
487500
488501 let error = client. unpair ( & pairing) . await . unwrap_err ( ) ;
@@ -511,6 +524,7 @@ mod tests {
511524 server : UUID_B . into ( ) ,
512525 tokens : Arc :: new ( Mutex :: new ( vec ! [ AccessToken ( "testtoken" . into( ) ) ] ) ) ,
513526 url : format ! ( "https://localhost:{}/" , addr. port( ) ) ,
527+ certificate_hash : None ,
514528 } ;
515529
516530 let mut client_connection = client. connect ( & pairing) . await . unwrap ( ) ;
@@ -577,6 +591,7 @@ mod tests {
577591 server : UUID_B . into ( ) ,
578592 tokens : Arc :: new ( Mutex :: new ( vec ! [ AccessToken ( "testtoken" . into( ) ) ] ) ) ,
579593 url : format ! ( "https://localhost:{}/" , addr. port( ) ) ,
594+ certificate_hash : None ,
580595 } ;
581596
582597 let mut client_connection = client. connect ( & pairing) . await . unwrap ( ) ;
@@ -637,6 +652,7 @@ mod tests {
637652 server : UUID_B . into ( ) ,
638653 tokens : Arc :: new ( Mutex :: new ( vec ! [ AccessToken ( "testtoken" . into( ) ) ] ) ) ,
639654 url : format ! ( "https://localhost:{}/" , addr. port( ) ) ,
655+ certificate_hash : None ,
640656 } ;
641657
642658 let mut client_connection = client. connect ( & pairing) . await . unwrap ( ) ;
@@ -692,6 +708,7 @@ mod tests {
692708 server : UUID_B . into ( ) ,
693709 tokens : Arc :: new ( Mutex :: new ( vec ! [ AccessToken ( "testtoken" . into( ) ) ] ) ) ,
694710 url : format ! ( "https://localhost:{}/" , addr. port( ) ) ,
711+ certificate_hash : None ,
695712 } ;
696713
697714 let mut client_connection = client. connect ( & pairing) . await . unwrap ( ) ;
@@ -740,6 +757,7 @@ mod tests {
740757 server : UUID_B . into ( ) ,
741758 tokens : Arc :: new ( Mutex :: new ( vec ! [ AccessToken ( "testtoken" . into( ) ) ] ) ) ,
742759 url : format ! ( "https://localhost:{}/" , addr. port( ) ) ,
760+ certificate_hash : None ,
743761 } ;
744762
745763 let mut client_connection = client. connect ( & pairing) . await . unwrap ( ) ;
@@ -792,6 +810,7 @@ mod tests {
792810 server : UUID_B . into ( ) ,
793811 tokens : Arc :: new ( Mutex :: new ( vec ! [ AccessToken ( "testtoken" . into( ) ) ] ) ) ,
794812 url : format ! ( "https://localhost:{}/" , addr. port( ) ) ,
813+ certificate_hash : None ,
795814 } ;
796815
797816 assert_eq ! ( client. connect( & pairing) . await . unwrap_err( ) . kind( ) , ErrorKind :: NoSupportedVersion ) ;
@@ -830,6 +849,7 @@ mod tests {
830849 server : UUID_B . into ( ) ,
831850 tokens : Arc :: new ( Mutex :: new ( vec ! [ AccessToken ( "testtoken" . into( ) ) ] ) ) ,
832851 url : format ! ( "https://localhost:{}/" , addr. port( ) ) ,
852+ certificate_hash : None ,
833853 } ;
834854
835855 assert_eq ! ( client. connect( & pairing) . await . unwrap_err( ) . kind( ) , ErrorKind :: NoSupportedVersion ) ;
@@ -876,6 +896,7 @@ mod tests {
876896 server : UUID_B . into ( ) ,
877897 tokens : Arc :: new ( Mutex :: new ( vec ! [ AccessToken ( "testtoken" . into( ) ) ] ) ) ,
878898 url : format ! ( "https://localhost:{}/" , addr. port( ) ) ,
899+ certificate_hash : None ,
879900 } ;
880901
881902 assert_eq ! ( client. connect( & pairing) . await . unwrap_err( ) . kind( ) , ErrorKind :: NoSupportedVersion ) ;
@@ -922,6 +943,7 @@ mod tests {
922943 server : UUID_B . into ( ) ,
923944 tokens : Arc :: new ( Mutex :: new ( vec ! [ AccessToken ( "testtoken" . into( ) ) ] ) ) ,
924945 url : format ! ( "https://localhost:{}/" , addr. port( ) ) ,
946+ certificate_hash : None ,
925947 } ;
926948
927949 assert_eq ! ( client. connect( & pairing) . await . unwrap_err( ) . kind( ) , ErrorKind :: NoSupportedVersion ) ;
@@ -973,6 +995,10 @@ mod tests {
973995 & self . url
974996 }
975997
998+ fn certificate_hash ( & self ) -> Option < CertificateHash > {
999+ None
1000+ }
1001+
9761002 async fn set_access_tokens ( & mut self , _tokens : Vec < AccessToken > ) -> Result < ( ) , Self :: Error > {
9771003 Err ( std:: io:: ErrorKind :: Other . into ( ) )
9781004 }
@@ -995,6 +1021,7 @@ mod tests {
9951021 server : UUID_B . into ( ) ,
9961022 tokens : Arc :: new ( Mutex :: new ( vec ! [ AccessToken ( "testtoken" . into( ) ) ] ) ) ,
9971023 url : format ! ( "https://localhost:{}/" , addr. port( ) ) ,
1024+ certificate_hash : None ,
9981025 } ;
9991026
10001027 let mut client_connection = client. connect ( & pairing) . await . unwrap ( ) ;
@@ -1042,6 +1069,7 @@ mod tests {
10421069 server : UUID_B . into ( ) ,
10431070 tokens : Arc :: new ( Mutex :: new ( vec ! [ AccessToken ( "testtoken" . into( ) ) ] ) ) ,
10441071 url : format ! ( "https://localhost:{}/" , addr. port( ) ) ,
1072+ certificate_hash : None ,
10451073 } ;
10461074
10471075 assert_eq ! ( client. connect( & pairing) . await . unwrap_err( ) . kind( ) , ErrorKind :: ProtocolError ) ;
@@ -1076,6 +1104,7 @@ mod tests {
10761104 server : UUID_B . into ( ) ,
10771105 tokens : Arc :: new ( Mutex :: new ( vec ! [ AccessToken ( "testtoken" . into( ) ) ] ) ) ,
10781106 url : format ! ( "https://localhost:{}/" , addr. port( ) ) ,
1107+ certificate_hash : None ,
10791108 } ;
10801109
10811110 assert_eq ! ( client. connect( & pairing) . await . unwrap_err( ) . kind( ) , ErrorKind :: ProtocolError ) ;
0 commit comments