@@ -393,6 +393,12 @@ struct Peer {
393393 /// We cache a `NodeId` here to avoid serializing peers' keys every time we forward gossip
394394 /// messages in `PeerManager`. Use `Peer::set_their_node_id` to modify this field.
395395 their_node_id : Option < ( PublicKey , NodeId ) > ,
396+ /// The features provided in the peer's [`msgs::Init`] message.
397+ ///
398+ /// This is set only after we've processed the [`msgs::Init`] message and called relevant
399+ /// `peer_connected` handler methods. Thus, this field is set *iff* we've finished our
400+ /// handshake and can talk to this peer normally (though use [`Peer::handshake_complete`] to
401+ /// check this.
396402 their_features : Option < InitFeatures > ,
397403 their_net_address : Option < NetAddress > ,
398404
@@ -424,6 +430,13 @@ struct Peer {
424430}
425431
426432impl Peer {
433+ /// True after we've processed the [`msgs::Init`] message and called relevant `peer_connected`
434+ /// handler methods. Thus, this implies we've finished our handshake and can talk to this peer
435+ /// normally.
436+ fn handshake_complete ( & self ) -> bool {
437+ self . their_features . is_some ( )
438+ }
439+
427440 /// Returns true if the channel announcements/updates for the given channel should be
428441 /// forwarded to this peer.
429442 /// If we are sending our routing table to this peer and we have not yet sent channel
@@ -1877,24 +1890,21 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
18771890 // thread can be holding the peer lock if we have the global write
18781891 // lock).
18791892
1880- if let Some ( mut descriptor) = self . node_id_to_descriptor . lock ( ) . unwrap ( ) . remove ( & node_id) {
1893+ let descriptor_opt = self . node_id_to_descriptor . lock ( ) . unwrap ( ) . remove ( & node_id) ;
1894+ if let Some ( mut descriptor) = descriptor_opt {
18811895 if let Some ( peer_mutex) = peers. remove ( & descriptor) {
1896+ let mut peer = peer_mutex. lock ( ) . unwrap ( ) ;
18821897 if let Some ( msg) = msg {
18831898 log_trace ! ( self . logger, "Handling DisconnectPeer HandleError event in peer_handler for node {} with message {}" ,
18841899 log_pubkey!( node_id) ,
18851900 msg. data) ;
1886- let mut peer = peer_mutex. lock ( ) . unwrap ( ) ;
18871901 self . enqueue_message ( & mut * peer, & msg) ;
18881902 // This isn't guaranteed to work, but if there is enough free
18891903 // room in the send buffer, put the error message there...
18901904 self . do_attempt_write_data ( & mut descriptor, & mut * peer, false ) ;
1891- } else {
1892- log_trace ! ( self . logger, "Handling DisconnectPeer HandleError event in peer_handler for node {} with no message" , log_pubkey!( node_id) ) ;
18931905 }
1906+ self . do_disconnect ( descriptor, & * peer, "DisconnectPeer HandleError" ) ;
18941907 }
1895- descriptor. disconnect_socket ( ) ;
1896- self . message_handler . chan_handler . peer_disconnected ( & node_id, false ) ;
1897- self . message_handler . onion_message_handler . peer_disconnected ( & node_id, false ) ;
18981908 }
18991909 }
19001910 }
@@ -1905,6 +1915,22 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
19051915 self . disconnect_event_internal ( descriptor, false ) ;
19061916 }
19071917
1918+ fn do_disconnect ( & self , mut descriptor : Descriptor , peer : & Peer , reason : & ' static str ) {
1919+ if !peer. handshake_complete ( ) {
1920+ log_trace ! ( self . logger, "Disconnecting peer which hasn't completed handshake due to {}" , reason) ;
1921+ descriptor. disconnect_socket ( ) ;
1922+ return ;
1923+ }
1924+
1925+ debug_assert ! ( peer. their_node_id. is_some( ) ) ;
1926+ if let Some ( ( node_id, _) ) = peer. their_node_id {
1927+ log_trace ! ( self . logger, "Disconnecting peer with id {} due to {}" , node_id, reason) ;
1928+ self . message_handler . chan_handler . peer_disconnected ( & node_id, false ) ;
1929+ self . message_handler . onion_message_handler . peer_disconnected ( & node_id, false ) ;
1930+ }
1931+ descriptor. disconnect_socket ( ) ;
1932+ }
1933+
19081934 fn disconnect_event_internal ( & self , descriptor : & Descriptor , no_connection_possible : bool ) {
19091935 let mut peers = self . peers . write ( ) . unwrap ( ) ;
19101936 let peer_option = peers. remove ( descriptor) ;
@@ -1916,6 +1942,8 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
19161942 } ,
19171943 Some ( peer_lock) => {
19181944 let peer = peer_lock. lock ( ) . unwrap ( ) ;
1945+ if !peer. handshake_complete ( ) { return ; }
1946+ debug_assert ! ( peer. their_node_id. is_some( ) ) ;
19191947 if let Some ( ( node_id, _) ) = peer. their_node_id {
19201948 log_trace ! ( self . logger,
19211949 "Handling disconnection of peer {}, with {}future connection to the peer possible." ,
@@ -1937,14 +1965,13 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
19371965 /// peer. Thus, be very careful about reentrancy issues.
19381966 ///
19391967 /// [`disconnect_socket`]: SocketDescriptor::disconnect_socket
1940- pub fn disconnect_by_node_id ( & self , node_id : PublicKey , no_connection_possible : bool ) {
1968+ pub fn disconnect_by_node_id ( & self , node_id : PublicKey , _no_connection_possible : bool ) {
19411969 let mut peers_lock = self . peers . write ( ) . unwrap ( ) ;
1942- if let Some ( mut descriptor) = self . node_id_to_descriptor . lock ( ) . unwrap ( ) . remove ( & node_id) {
1943- log_trace ! ( self . logger, "Disconnecting peer with id {} due to client request" , node_id) ;
1944- peers_lock. remove ( & descriptor) ;
1945- self . message_handler . chan_handler . peer_disconnected ( & node_id, no_connection_possible) ;
1946- self . message_handler . onion_message_handler . peer_disconnected ( & node_id, no_connection_possible) ;
1947- descriptor. disconnect_socket ( ) ;
1970+ if let Some ( descriptor) = self . node_id_to_descriptor . lock ( ) . unwrap ( ) . remove ( & node_id) {
1971+ let peer_opt = peers_lock. remove ( & descriptor) ;
1972+ if let Some ( peer_mutex) = peer_opt {
1973+ self . do_disconnect ( descriptor, & * peer_mutex. lock ( ) . unwrap ( ) , "client request" ) ;
1974+ } else { debug_assert ! ( false , "node_id_to_descriptor thought we had a peer" ) ; }
19481975 }
19491976 }
19501977
@@ -1955,13 +1982,8 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
19551982 let mut peers_lock = self . peers . write ( ) . unwrap ( ) ;
19561983 self . node_id_to_descriptor . lock ( ) . unwrap ( ) . clear ( ) ;
19571984 let peers = & mut * peers_lock;
1958- for ( mut descriptor, peer) in peers. drain ( ) {
1959- if let Some ( ( node_id, _) ) = peer. lock ( ) . unwrap ( ) . their_node_id {
1960- log_trace ! ( self . logger, "Disconnecting peer with id {} due to client request to disconnect all peers" , node_id) ;
1961- self . message_handler . chan_handler . peer_disconnected ( & node_id, false ) ;
1962- self . message_handler . onion_message_handler . peer_disconnected ( & node_id, false ) ;
1963- }
1964- descriptor. disconnect_socket ( ) ;
1985+ for ( descriptor, peer_mutex) in peers. drain ( ) {
1986+ self . do_disconnect ( descriptor, & * peer_mutex. lock ( ) . unwrap ( ) , "client request to disconnect all peers" ) ;
19651987 }
19661988 }
19671989
@@ -2052,21 +2074,16 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
20522074 if !descriptors_needing_disconnect. is_empty ( ) {
20532075 {
20542076 let mut peers_lock = self . peers . write ( ) . unwrap ( ) ;
2055- for descriptor in descriptors_needing_disconnect. iter ( ) {
2056- if let Some ( peer ) = peers_lock. remove ( descriptor) {
2057- if let Some ( ( node_id , _ ) ) = peer . lock ( ) . unwrap ( ) . their_node_id {
2058- log_trace ! ( self . logger , "Disconnecting peer with id {} due to ping timeout" , node_id ) ;
2077+ for descriptor in descriptors_needing_disconnect {
2078+ if let Some ( peer_mutex ) = peers_lock. remove ( & descriptor) {
2079+ let peer = peer_mutex . lock ( ) . unwrap ( ) ;
2080+ if let Some ( ( node_id , _ ) ) = peer. their_node_id {
20592081 self . node_id_to_descriptor . lock ( ) . unwrap ( ) . remove ( & node_id) ;
2060- self . message_handler . chan_handler . peer_disconnected ( & node_id, false ) ;
2061- self . message_handler . onion_message_handler . peer_disconnected ( & node_id, false ) ;
20622082 }
2083+ self . do_disconnect ( descriptor, & * peer, "ping timeout" ) ;
20632084 }
20642085 }
20652086 }
2066-
2067- for mut descriptor in descriptors_needing_disconnect. drain ( ..) {
2068- descriptor. disconnect_socket ( ) ;
2069- }
20702087 }
20712088 }
20722089
0 commit comments