@@ -57,7 +57,7 @@ use crate::prelude::*;
5757use core:: cell:: RefCell ;
5858use core:: ops:: DerefMut ;
5959use core:: time:: Duration ;
60- use crate :: sync:: { Mutex , Arc , RwLock } ;
60+ use crate :: sync:: { Mutex , Arc } ;
6161use core:: sync:: atomic:: { AtomicBool , AtomicUsize , Ordering } ;
6262use core:: mem;
6363use bitcoin:: bech32:: u5;
@@ -318,25 +318,21 @@ impl<Signer: sign::WriteableEcdsaChannelSigner> chainmonitor::Persist<Signer> fo
318318}
319319
320320pub ( crate ) struct TestStore {
321- persisted_bytes : RwLock < HashMap < String , HashMap < String , Arc < RwLock < Vec < u8 > > > > > > ,
321+ persisted_bytes : Mutex < HashMap < String , HashMap < String , Vec < u8 > > > > ,
322322 did_persist : Arc < AtomicBool > ,
323+ read_only : bool ,
323324}
324325
325326impl TestStore {
326- pub fn new ( ) -> Self {
327- let persisted_bytes = RwLock :: new ( HashMap :: new ( ) ) ;
327+ pub fn new ( read_only : bool ) -> Self {
328+ let persisted_bytes = Mutex :: new ( HashMap :: new ( ) ) ;
328329 let did_persist = Arc :: new ( AtomicBool :: new ( false ) ) ;
329- Self { persisted_bytes, did_persist }
330+ Self { persisted_bytes, did_persist, read_only }
330331 }
331332
332333 pub fn get_persisted_bytes ( & self , namespace : & str , key : & str ) -> Option < Vec < u8 > > {
333- if let Some ( outer_ref) = self . persisted_bytes . read ( ) . unwrap ( ) . get ( namespace) {
334- if let Some ( inner_ref) = outer_ref. get ( key) {
335- let locked = inner_ref. read ( ) . unwrap ( ) ;
336- return Some ( ( * locked) . clone ( ) ) ;
337- }
338- }
339- None
334+ let persisted_lock = self . persisted_bytes . lock ( ) . unwrap ( ) ;
335+ persisted_lock. get ( namespace) . and_then ( |e| e. get ( key) . cloned ( ) )
340336 }
341337
342338 pub fn get_and_clear_did_persist ( & self ) -> bool {
@@ -345,12 +341,14 @@ impl TestStore {
345341}
346342
347343impl KVStore for TestStore {
348- type Reader = TestReader ;
344+ type Reader = io :: Cursor < Vec < u8 > > ;
349345
350346 fn read ( & self , namespace : & str , key : & str ) -> io:: Result < Self :: Reader > {
351- if let Some ( outer_ref) = self . persisted_bytes . read ( ) . unwrap ( ) . get ( namespace) {
347+ let persisted_lock = self . persisted_bytes . lock ( ) . unwrap ( ) ;
348+ if let Some ( outer_ref) = persisted_lock. get ( namespace) {
352349 if let Some ( inner_ref) = outer_ref. get ( key) {
353- Ok ( TestReader :: new ( Arc :: clone ( inner_ref) ) )
350+ let bytes = inner_ref. clone ( ) ;
351+ Ok ( io:: Cursor :: new ( bytes) )
354352 } else {
355353 Err ( io:: Error :: new ( io:: ErrorKind :: NotFound , "Key not found" ) )
356354 }
@@ -360,53 +358,47 @@ impl KVStore for TestStore {
360358 }
361359
362360 fn write ( & self , namespace : & str , key : & str , buf : & [ u8 ] ) -> io:: Result < ( ) > {
363- let mut guard = self . persisted_bytes . write ( ) . unwrap ( ) ;
364- let outer_e = guard. entry ( namespace. to_string ( ) ) . or_insert ( HashMap :: new ( ) ) ;
365- let inner_e = outer_e. entry ( key. to_string ( ) ) . or_insert ( Arc :: new ( RwLock :: new ( Vec :: new ( ) ) ) ) ;
366-
367- let mut guard = inner_e. write ( ) . unwrap ( ) ;
368- guard. write_all ( buf) ?;
361+ if self . read_only {
362+ return Err ( io:: Error :: new (
363+ io:: ErrorKind :: PermissionDenied ,
364+ "read only" ,
365+ ) ) ;
366+ }
367+ let mut persisted_lock = self . persisted_bytes . lock ( ) . unwrap ( ) ;
368+ let outer_e = persisted_lock. entry ( namespace. to_string ( ) ) . or_insert ( HashMap :: new ( ) ) ;
369+ let mut bytes = Vec :: new ( ) ;
370+ bytes. write_all ( buf) ?;
371+ outer_e. insert ( key. to_string ( ) , bytes) ;
369372 self . did_persist . store ( true , Ordering :: SeqCst ) ;
370373 Ok ( ( ) )
371374 }
372375
373376 fn remove ( & self , namespace : & str , key : & str ) -> io:: Result < ( ) > {
374- match self . persisted_bytes . write ( ) . unwrap ( ) . entry ( namespace. to_string ( ) ) {
375- hash_map:: Entry :: Occupied ( mut e) => {
377+ if self . read_only {
378+ return Err ( io:: Error :: new (
379+ io:: ErrorKind :: PermissionDenied ,
380+ "read only" ,
381+ ) ) ;
382+ }
383+
384+ let mut persisted_lock = self . persisted_bytes . lock ( ) . unwrap ( ) ;
385+ if let Some ( outer_ref) = persisted_lock. get_mut ( namespace) {
386+ outer_ref. remove ( & key. to_string ( ) ) ;
376387 self . did_persist . store ( true , Ordering :: SeqCst ) ;
377- e. get_mut ( ) . remove ( & key. to_string ( ) ) ;
378- Ok ( ( ) )
379- }
380- hash_map:: Entry :: Vacant ( _) => Ok ( ( ) ) ,
381388 }
389+
390+ Ok ( ( ) )
382391 }
383392
384393 fn list ( & self , namespace : & str ) -> io:: Result < Vec < String > > {
385- match self . persisted_bytes . write ( ) . unwrap ( ) . entry ( namespace. to_string ( ) ) {
394+ let mut persisted_lock = self . persisted_bytes . lock ( ) . unwrap ( ) ;
395+ match persisted_lock. entry ( namespace. to_string ( ) ) {
386396 hash_map:: Entry :: Occupied ( e) => Ok ( e. get ( ) . keys ( ) . cloned ( ) . collect ( ) ) ,
387397 hash_map:: Entry :: Vacant ( _) => Ok ( Vec :: new ( ) ) ,
388398 }
389399 }
390400}
391401
392- pub struct TestReader {
393- entry_ref : Arc < RwLock < Vec < u8 > > > ,
394- }
395-
396- impl TestReader {
397- pub fn new ( entry_ref : Arc < RwLock < Vec < u8 > > > ) -> Self {
398- Self { entry_ref }
399- }
400- }
401-
402- impl io:: Read for TestReader {
403- fn read ( & mut self , buf : & mut [ u8 ] ) -> io:: Result < usize > {
404- let bytes = self . entry_ref . read ( ) . unwrap ( ) . clone ( ) ;
405- let mut reader = io:: Cursor :: new ( bytes) ;
406- reader. read ( buf)
407- }
408- }
409-
410402pub struct TestBroadcaster {
411403 pub txn_broadcasted : Mutex < Vec < Transaction > > ,
412404 pub blocks : Arc < Mutex < Vec < ( Block , u32 ) > > > ,
0 commit comments