@@ -23,6 +23,7 @@ use std::{
2323 sync:: { Arc , Mutex , RwLock } ,
2424} ;
2525
26+ use fragile:: Fragile ;
2627use ndarray:: { Array , ArrayView , ArrayViewMut , CowArray , IxDyn } ;
2728use numcodecs:: {
2829 AnyArray , AnyArrayAssignError , AnyArrayDType , AnyArrayView , AnyArrayViewMut , AnyCowArray ,
@@ -67,7 +68,7 @@ impl Clone for PressioCompressor {
6768}
6869
6970struct PressioCompressorInner {
70- compressor : Mutex < libpressio :: PressioCompressor > ,
71+ compressor : Mutex < PressioCompressorSendable > ,
7172 compressor_id : String ,
7273 early_config : BTreeMap < String , PressioOption > ,
7374}
@@ -76,7 +77,8 @@ impl Clone for PressioCompressorInner {
7677 #[ expect( clippy:: unwrap_used) ]
7778 fn clone ( & self ) -> Self {
7879 let mut pressio = libpressio:: Pressio :: new ( ) . unwrap ( ) ;
79- let compressor = self . compressor . lock ( ) . unwrap ( ) ;
80+ let compressor_guard = self . compressor . lock ( ) . unwrap ( ) ;
81+ let compressor = compressor_guard. try_get ( ) . unwrap ( ) ;
8082
8183 let mut compressor_clone = pressio. get_compressor ( self . compressor_id . as_str ( ) ) . unwrap ( ) ;
8284 compressor_clone
@@ -86,7 +88,12 @@ impl Clone for PressioCompressorInner {
8688 . set_options ( & compressor. get_options ( ) . unwrap ( ) )
8789 . unwrap ( ) ;
8890
89- std:: mem:: drop ( compressor) ;
91+ std:: mem:: drop ( compressor_guard) ;
92+
93+ let compressor_clone = match compressor_clone. try_into_sendable ( ) {
94+ Ok ( compressor) => PressioCompressorSendable :: Sendable ( compressor) ,
95+ Err ( ( compressor, _err) ) => PressioCompressorSendable :: Fragile ( Fragile :: new ( compressor) ) ,
96+ } ;
9097
9198 Self {
9299 compressor : Mutex :: new ( compressor_clone) ,
@@ -96,6 +103,31 @@ impl Clone for PressioCompressorInner {
96103 }
97104}
98105
106+ enum PressioCompressorSendable {
107+ Sendable ( libpressio:: PressioSendableCompressor ) ,
108+ Fragile ( Fragile < libpressio:: PressioCompressor > ) ,
109+ }
110+
111+ impl PressioCompressorSendable {
112+ fn try_get ( & self ) -> Result < & libpressio:: PressioCompressor , PressioCodecError > {
113+ match self {
114+ Self :: Sendable ( compressor) => Ok ( compressor) ,
115+ Self :: Fragile ( compressor) => compressor
116+ . try_get ( )
117+ . map_err ( |_| PressioCodecError :: PressioNonThreadsafeSend ) ,
118+ }
119+ }
120+
121+ fn try_get_mut ( & mut self ) -> Result < & mut libpressio:: PressioCompressor , PressioCodecError > {
122+ match self {
123+ Self :: Sendable ( compressor) => Ok ( compressor) ,
124+ Self :: Fragile ( compressor) => compressor
125+ . try_get_mut ( )
126+ . map_err ( |_| PressioCodecError :: PressioNonThreadsafeSend ) ,
127+ }
128+ }
129+ }
130+
99131impl Serialize for PressioCompressor {
100132 #[ expect( clippy:: too_many_lines) ]
101133 fn serialize < S : Serializer > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error > {
@@ -196,7 +228,10 @@ impl Serialize for PressioCompressor {
196228 }
197229
198230 let inner = self . inner . read ( ) . map_err ( serde:: ser:: Error :: custom) ?;
199- let compressor = inner. compressor . lock ( ) . map_err ( serde:: ser:: Error :: custom) ?;
231+ let compressor_guard = inner. compressor . lock ( ) . map_err ( serde:: ser:: Error :: custom) ?;
232+ let compressor = compressor_guard
233+ . try_get ( )
234+ . map_err ( serde:: ser:: Error :: custom) ?;
200235 let options = compressor
201236 . get_options ( )
202237 . map_err ( serde:: ser:: Error :: custom) ?;
@@ -216,7 +251,7 @@ impl Serialize for PressioCompressor {
216251 } ,
217252 }
218253 . serialize ( serializer) ;
219- std:: mem:: drop ( compressor ) ;
254+ std:: mem:: drop ( compressor_guard ) ;
220255 std:: mem:: drop ( inner) ;
221256 result
222257 }
@@ -423,6 +458,11 @@ impl<'de> Deserialize<'de> for PressioCompressor {
423458 . set_options ( & options)
424459 . map_err ( serde:: de:: Error :: custom) ?;
425460
461+ let compressor = match compressor. try_into_sendable ( ) {
462+ Ok ( compressor) => PressioCompressorSendable :: Sendable ( compressor) ,
463+ Err ( ( compressor, _err) ) => PressioCompressorSendable :: Fragile ( Fragile :: new ( compressor) ) ,
464+ } ;
465+
426466 Ok ( Self {
427467 inner : RwLock :: new ( Arc :: new ( PressioCompressorInner {
428468 compressor : Mutex :: new ( compressor) ,
@@ -665,6 +705,7 @@ impl Codec for PressioCodec {
665705 let Ok ( compressor) = Arc :: make_mut ( & mut inner) . compressor . get_mut ( ) else {
666706 return Err ( PressioCodecError :: PressioPoisonedLock ) ;
667707 } ;
708+ let compressor = compressor. try_get_mut ( ) ?;
668709
669710 match data {
670711 AnyCowArray :: U8 ( data) => encode_typed ( compressor, data) ,
@@ -729,6 +770,7 @@ impl Codec for PressioCodec {
729770 let Ok ( compressor) = Arc :: make_mut ( & mut inner) . compressor . get_mut ( ) else {
730771 return Err ( PressioCodecError :: PressioPoisonedLock ) ;
731772 } ;
773+ let compressor = compressor. try_get_mut ( ) ?;
732774
733775 match encoded {
734776 AnyCowArray :: U8 ( encoded) => decode_typed ( compressor, encoded) ,
@@ -844,6 +886,7 @@ impl Codec for PressioCodec {
844886 let Ok ( compressor) = Arc :: make_mut ( & mut inner) . compressor . get_mut ( ) else {
845887 return Err ( PressioCodecError :: PressioPoisonedLock ) ;
846888 } ;
889+ let compressor = compressor. try_get_mut ( ) ?;
847890
848891 let decoded_dtype = match decoded. dtype ( ) {
849892 AnyArrayDType :: U8 => libpressio:: PressioDtype :: U8 ,
@@ -933,6 +976,9 @@ pub enum PressioCodecError {
933976 /// [`PressioCodec`] lock was poisoned
934977 #[ error( "Pressio lock was poisoned" ) ]
935978 PressioPoisonedLock ,
979+ /// [`PressioCodec`] was used on a different thread with a non-threadsafe compressor
980+ #[ error( "Pressio was used on a different thread with a non-threadsafe compressor" ) ]
981+ PressioNonThreadsafeSend ,
936982 /// [`PressioCodec`] failed to encode the data
937983 #[ error( "Pressio failed to encode the data" ) ]
938984 PressioEncodeFailed {
0 commit comments