Skip to content

Commit ff442bd

Browse files
committed
allow using non-sendable pressio compressors using fragile
1 parent d8944df commit ff442bd

3 files changed

Lines changed: 54 additions & 6 deletions

File tree

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,11 @@ clap = { version = "4.5", default-features = false }
9393
convert_case = { version = "0.8", default-features = false }
9494
ebcc = { version = "0.1", default-features = false }
9595
format_serde_error = { version = "0.3", default-features = false }
96+
fragile = { version = "2.0", default-features = false }
9697
indexmap = { version = "2.10", default-features = false }
9798
itertools = { version = "0.14", default-features = false }
9899
lc-framework = { version = "0.1", default-features = false }
99-
libpressio = { version = "0.1", git = "https://github.com/juntyr/libpressio-rs.git", rev = "255ed51", default-features = false }
100+
libpressio = { version = "0.1", git = "https://github.com/juntyr/libpressio-rs.git", rev = "30545dd", default-features = false }
100101
log = { version = "0.4.27", default-features = false }
101102
miniz_oxide = { version = "0.8.5", default-features = false }
102103
ndarray = { version = "0.16.1", default-features = false } # keep in sync with numpy

codecs/pressio/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ keywords = ["libpressio", "numcodecs", "compression", "encoding"]
1515
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
1616

1717
[dependencies]
18+
fragile = { workspace = true }
1819
libpressio = { workspace = true, features = ["bzip2", "lua"] }
1920
ndarray = { workspace = true }
2021
numcodecs = { workspace = true }

codecs/pressio/src/lib.rs

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use std::{
2323
sync::{Arc, Mutex, RwLock},
2424
};
2525

26+
use fragile::Fragile;
2627
use ndarray::{Array, ArrayView, ArrayViewMut, CowArray, IxDyn};
2728
use numcodecs::{
2829
AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
@@ -67,7 +68,7 @@ impl Clone for PressioCompressor {
6768
}
6869

6970
struct PressioCompressorInner {
70-
compressor: Mutex<libpressio::PressioCompressor>,
71+
compressor: Mutex<PressioCompressorSendable>,
7172
compressor_id: String,
7273
early_config: BTreeMap<String, PressioOption>,
7374
}
@@ -76,7 +77,8 @@ impl Clone for PressioCompressorInner {
7677
#[expect(clippy::unwrap_used)]
7778
fn clone(&self) -> Self {
7879
let mut pressio = libpressio::Pressio::new().unwrap();
79-
let compressor = self.compressor.lock().unwrap();
80+
let compressor_guard = self.compressor.lock().unwrap();
81+
let compressor = compressor_guard.try_get().unwrap();
8082

8183
let mut compressor_clone = pressio.get_compressor(self.compressor_id.as_str()).unwrap();
8284
compressor_clone
@@ -86,7 +88,12 @@ impl Clone for PressioCompressorInner {
8688
.set_options(&compressor.get_options().unwrap())
8789
.unwrap();
8890

89-
std::mem::drop(compressor);
91+
std::mem::drop(compressor_guard);
92+
93+
let compressor_clone = match compressor_clone.try_into_sendable() {
94+
Ok(compressor) => PressioCompressorSendable::Sendable(compressor),
95+
Err((compressor, _err)) => PressioCompressorSendable::Fragile(Fragile::new(compressor)),
96+
};
9097

9198
Self {
9299
compressor: Mutex::new(compressor_clone),
@@ -96,6 +103,31 @@ impl Clone for PressioCompressorInner {
96103
}
97104
}
98105

106+
enum PressioCompressorSendable {
107+
Sendable(libpressio::PressioSendableCompressor),
108+
Fragile(Fragile<libpressio::PressioCompressor>),
109+
}
110+
111+
impl PressioCompressorSendable {
112+
fn try_get(&self) -> Result<&libpressio::PressioCompressor, PressioCodecError> {
113+
match self {
114+
Self::Sendable(compressor) => Ok(compressor),
115+
Self::Fragile(compressor) => compressor
116+
.try_get()
117+
.map_err(|_| PressioCodecError::PressioNonThreadsafeSend),
118+
}
119+
}
120+
121+
fn try_get_mut(&mut self) -> Result<&mut libpressio::PressioCompressor, PressioCodecError> {
122+
match self {
123+
Self::Sendable(compressor) => Ok(compressor),
124+
Self::Fragile(compressor) => compressor
125+
.try_get_mut()
126+
.map_err(|_| PressioCodecError::PressioNonThreadsafeSend),
127+
}
128+
}
129+
}
130+
99131
impl Serialize for PressioCompressor {
100132
#[expect(clippy::too_many_lines)]
101133
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
@@ -196,7 +228,10 @@ impl Serialize for PressioCompressor {
196228
}
197229

198230
let inner = self.inner.read().map_err(serde::ser::Error::custom)?;
199-
let compressor = inner.compressor.lock().map_err(serde::ser::Error::custom)?;
231+
let compressor_guard = inner.compressor.lock().map_err(serde::ser::Error::custom)?;
232+
let compressor = compressor_guard
233+
.try_get()
234+
.map_err(serde::ser::Error::custom)?;
200235
let options = compressor
201236
.get_options()
202237
.map_err(serde::ser::Error::custom)?;
@@ -216,7 +251,7 @@ impl Serialize for PressioCompressor {
216251
},
217252
}
218253
.serialize(serializer);
219-
std::mem::drop(compressor);
254+
std::mem::drop(compressor_guard);
220255
std::mem::drop(inner);
221256
result
222257
}
@@ -423,6 +458,11 @@ impl<'de> Deserialize<'de> for PressioCompressor {
423458
.set_options(&options)
424459
.map_err(serde::de::Error::custom)?;
425460

461+
let compressor = match compressor.try_into_sendable() {
462+
Ok(compressor) => PressioCompressorSendable::Sendable(compressor),
463+
Err((compressor, _err)) => PressioCompressorSendable::Fragile(Fragile::new(compressor)),
464+
};
465+
426466
Ok(Self {
427467
inner: RwLock::new(Arc::new(PressioCompressorInner {
428468
compressor: Mutex::new(compressor),
@@ -665,6 +705,7 @@ impl Codec for PressioCodec {
665705
let Ok(compressor) = Arc::make_mut(&mut inner).compressor.get_mut() else {
666706
return Err(PressioCodecError::PressioPoisonedLock);
667707
};
708+
let compressor = compressor.try_get_mut()?;
668709

669710
match data {
670711
AnyCowArray::U8(data) => encode_typed(compressor, data),
@@ -729,6 +770,7 @@ impl Codec for PressioCodec {
729770
let Ok(compressor) = Arc::make_mut(&mut inner).compressor.get_mut() else {
730771
return Err(PressioCodecError::PressioPoisonedLock);
731772
};
773+
let compressor = compressor.try_get_mut()?;
732774

733775
match encoded {
734776
AnyCowArray::U8(encoded) => decode_typed(compressor, encoded),
@@ -844,6 +886,7 @@ impl Codec for PressioCodec {
844886
let Ok(compressor) = Arc::make_mut(&mut inner).compressor.get_mut() else {
845887
return Err(PressioCodecError::PressioPoisonedLock);
846888
};
889+
let compressor = compressor.try_get_mut()?;
847890

848891
let decoded_dtype = match decoded.dtype() {
849892
AnyArrayDType::U8 => libpressio::PressioDtype::U8,
@@ -933,6 +976,9 @@ pub enum PressioCodecError {
933976
/// [`PressioCodec`] lock was poisoned
934977
#[error("Pressio lock was poisoned")]
935978
PressioPoisonedLock,
979+
/// [`PressioCodec`] was used on a different thread with a non-threadsafe compressor
980+
#[error("Pressio was used on a different thread with a non-threadsafe compressor")]
981+
PressioNonThreadsafeSend,
936982
/// [`PressioCodec`] failed to encode the data
937983
#[error("Pressio failed to encode the data")]
938984
PressioEncodeFailed {

0 commit comments

Comments
 (0)