@@ -125,8 +125,8 @@ impl<PH: DNSResolverMessageHandler> DNSResolverMessageHandler for OMDomainResolv
125125 let contents = DNSResolverMessage :: DNSSECProof ( DNSSECProof { name : q. 0 , proof } ) ;
126126 let instructions = responder. respond ( ) . into_instructions ( ) ;
127127 us. pending_replies . lock ( ) . unwrap ( ) . push ( ( contents, instructions) ) ;
128- us. pending_query_count . fetch_sub ( 1 , Ordering :: Relaxed ) ;
129128 }
129+ us. pending_query_count . fetch_sub ( 1 , Ordering :: Relaxed ) ;
130130 } ) ;
131131 None
132132 }
@@ -337,4 +337,95 @@ mod test {
337337 assert_eq ! ( resolution. 1 , payment_id) ;
338338 assert ! ( resolution. 2 [ .."bitcoin:" . len( ) ] . eq_ignore_ascii_case( "bitcoin:" ) ) ;
339339 }
340+
341+ #[ tokio:: test]
342+ async fn failed_query_does_not_leak_pending_counter ( ) {
343+ use std:: sync:: atomic:: Ordering ;
344+
345+ let secp_ctx = Secp256k1 :: new ( ) ;
346+
347+ // Resolver points at a port that should refuse TCP, so build_txt_proof_async
348+ // returns Err quickly.
349+ let resolver_keys = Arc :: new ( KeysManager :: new ( & [ 99 ; 32 ] , 42 , 43 , true ) ) ;
350+ let resolver_logger = TestLogger { node : "resolver" } ;
351+ let resolver =
352+ Arc :: new ( OMDomainResolver :: < IgnoringMessageHandler > :: ignoring_incoming_proofs (
353+ "127.0.0.1:1" . parse ( ) . unwrap ( ) ,
354+ ) ) ;
355+ let resolver_state = Arc :: clone ( & resolver. state ) ;
356+ let resolver_messenger = OnionMessenger :: new (
357+ Arc :: clone ( & resolver_keys) ,
358+ Arc :: clone ( & resolver_keys) ,
359+ resolver_logger,
360+ DummyNodeLookup { } ,
361+ DirectlyConnectedRouter { } ,
362+ IgnoringMessageHandler { } ,
363+ IgnoringMessageHandler { } ,
364+ Arc :: clone ( & resolver) ,
365+ IgnoringMessageHandler { } ,
366+ ) ;
367+ let resolver_id = resolver_keys. get_node_id ( Recipient :: Node ) . unwrap ( ) ;
368+
369+ let resolver_dest = Destination :: Node ( resolver_id) ;
370+ let now = SystemTime :: now ( ) . duration_since ( SystemTime :: UNIX_EPOCH ) . unwrap ( ) . as_secs ( ) ;
371+
372+ let payment_id = PaymentId ( [ 42 ; 32 ] ) ;
373+ let name = HumanReadableName :: from_encoded ( "matt@mattcorallo.com" ) . unwrap ( ) ;
374+
375+ let payer_keys = Arc :: new ( KeysManager :: new ( & [ 2 ; 32 ] , 42 , 43 , true ) ) ;
376+ let payer_logger = TestLogger { node : "payer" } ;
377+ let payer_id = payer_keys. get_node_id ( Recipient :: Node ) . unwrap ( ) ;
378+ let payer = Arc :: new ( URIResolver {
379+ resolved_uri : Mutex :: new ( None ) ,
380+ resolver : OMNameResolver :: new ( now as u32 , 1 ) ,
381+ pending_messages : Mutex :: new ( Vec :: new ( ) ) ,
382+ } ) ;
383+ let payer_messenger = Arc :: new ( OnionMessenger :: new (
384+ Arc :: clone ( & payer_keys) ,
385+ Arc :: clone ( & payer_keys) ,
386+ payer_logger,
387+ DummyNodeLookup { } ,
388+ DirectlyConnectedRouter { } ,
389+ IgnoringMessageHandler { } ,
390+ IgnoringMessageHandler { } ,
391+ Arc :: clone ( & payer) ,
392+ IgnoringMessageHandler { } ,
393+ ) ) ;
394+
395+ let init_msg = get_om_init ( ) ;
396+ payer_messenger. peer_connected ( resolver_id, & init_msg, true ) . unwrap ( ) ;
397+ resolver_messenger. peer_connected ( payer_id, & init_msg, false ) . unwrap ( ) ;
398+
399+ let ( msg, context) =
400+ payer. resolver . resolve_name ( payment_id, name. clone ( ) , & * payer_keys) . unwrap ( ) ;
401+ let query_context = MessageContext :: DNSResolver ( context) ;
402+ let receive_key = payer_keys. get_receive_auth_key ( ) ;
403+ let reply_path = BlindedMessagePath :: one_hop (
404+ payer_id,
405+ receive_key,
406+ query_context,
407+ false ,
408+ & * payer_keys,
409+ & secp_ctx,
410+ ) ;
411+ payer. pending_messages . lock ( ) . unwrap ( ) . push ( (
412+ DNSResolverMessage :: DNSSECQuery ( msg) ,
413+ MessageSendInstructions :: WithSpecifiedReplyPath {
414+ destination : resolver_dest,
415+ reply_path,
416+ } ,
417+ ) ) ;
418+
419+ let query = payer_messenger. next_onion_message_for_peer ( resolver_id) . unwrap ( ) ;
420+ resolver_messenger. handle_onion_message ( payer_id, & query) ;
421+
422+ let start = Instant :: now ( ) ;
423+ while resolver_state. pending_query_count . load ( Ordering :: Relaxed ) != 0 {
424+ tokio:: time:: sleep ( Duration :: from_millis ( 50 ) ) . await ;
425+ assert ! (
426+ start. elapsed( ) < Duration :: from_secs( 10 ) ,
427+ "pending_query_count not decremented after failed proof: counter leaks"
428+ ) ;
429+ }
430+ }
340431}
0 commit comments