diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index 30d0bc71..abd25592 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -532,10 +532,10 @@ def load_item_from_chunk( if self._serializer_name == "no_header_tensor": # count: number of tokens to read from buffer => `self._block_size` - data = torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset) + data = torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset).clone() else: # count: number of tokens to read from buffer => `self._block_size` - data = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset) # type: ignore + data = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset).copy() # type: ignore return data diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index ab2ea856..4818baee 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -14,7 +14,6 @@ import glob import logging import os -import warnings from contextlib import suppress from datetime import datetime from queue import Empty, Queue @@ -33,9 +32,6 @@ from litdata.utilities.encryption import Encryption from litdata.utilities.env import _DistributedEnv, _WorkerEnv -warnings.filterwarnings("ignore", message=".*The given buffer is not writable.*") - - logger = logging.getLogger("litdata.streaming.reader") @@ -332,7 +328,6 @@ def __init__( """ super().__init__() - warnings.filterwarnings("ignore", message=".*The given buffer is not writable.*") self._cache_dir = cache_dir self._remote_input_dir = remote_input_dir diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py index 12e587b2..3953f24d 100644 --- a/src/litdata/streaming/serializers.py +++ b/src/litdata/streaming/serializers.py @@ -134,7 +134,7 @@ def serialize(self, item: Any) -> tuple[bytes, str | None]: def deserialize(self, data: bytes) -> torch.Tensor: from torchvision.io import decode_image, decode_jpeg - array = torch.frombuffer(data, dtype=torch.uint8) + array = torch.frombuffer(bytearray(data), dtype=torch.uint8) # Try decoding as JPEG. Some datasets (e.g., ImageNet) may have PNG images with a JPEG extension, # which will cause decode_jpeg to fail. In that case, fall back to a generic image decoder. with suppress(RuntimeError): @@ -266,7 +266,7 @@ def deserialize(self, data: bytes) -> torch.Tensor: shape = struct.unpack_from(f">{rank}I", buffer_view, header_size) data_start_offset = header_size + (rank * 4) if data_start_offset < len(buffer_view): - tensor_1d = torch.frombuffer(buffer_view[data_start_offset:], dtype=dtype) + tensor_1d = torch.frombuffer(bytearray(buffer_view[data_start_offset:]), dtype=dtype) return tensor_1d.reshape(shape) return torch.empty(shape, dtype=dtype) @@ -300,7 +300,11 @@ def serialize(self, item: torch.Tensor) -> tuple[bytes, str | None]: def deserialize(self, data: bytes) -> torch.Tensor: assert self._dtype - return torch.frombuffer(data, dtype=self._dtype) if len(data) > 0 else torch.empty((0,), dtype=self._dtype) + return ( + torch.frombuffer(bytearray(data), dtype=self._dtype) + if len(data) > 0 + else torch.empty((0,), dtype=self._dtype) + ) def can_serialize(self, item: torch.Tensor) -> bool: return isinstance(item, torch.Tensor) and len(item.shape) == 1 @@ -333,7 +337,7 @@ def deserialize(self, data: bytes) -> np.ndarray: shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item()) # deserialize the numpy array bytes - tensor = np.frombuffer(data[8 + 4 * shape_size : len(data)], dtype=dtype) + tensor = np.frombuffer(data[8 + 4 * shape_size : len(data)], dtype=dtype).copy() if tensor.shape == shape: return tensor return np.reshape(tensor, shape) @@ -359,7 +363,7 @@ def serialize(self, item: np.ndarray) -> tuple[bytes, str | None]: def deserialize(self, data: bytes) -> np.ndarray: assert self._dtype - return np.frombuffer(data, dtype=self._dtype) + return np.frombuffer(data, dtype=self._dtype).copy() def can_serialize(self, item: np.ndarray) -> bool: return isinstance(item, np.ndarray) and len(item.shape) == 1 diff --git a/tests/streaming/test_serializer.py b/tests/streaming/test_serializer.py index cec28af1..9810484e 100644 --- a/tests/streaming/test_serializer.py +++ b/tests/streaming/test_serializer.py @@ -410,3 +410,102 @@ def test_boolean_serializer(): assert not serializer.can_serialize(1) assert not serializer.can_serialize("True") assert not serializer.can_serialize(None) + + +class TestWritableDeserializedArrays: + """Tests that deserialized numpy arrays and tensors are writable. + + This addresses issue #818: np.frombuffer() and torch.frombuffer() with + bytes objects produce non-writable arrays/tensors, which triggers + UserWarning from PyTorch. + """ + + def test_numpy_serializer_deserialize_returns_writable_array(self): + """NumpySerializer.deserialize should return a writable numpy array.""" + serializer = NumpySerializer() + arr = np.ones((3, 4), dtype=np.float32) + data, _ = serializer.serialize(arr) + result = serializer.deserialize(data) + assert result.flags.writeable is True, "Deserialized numpy array should be writable" + + def test_no_header_numpy_serializer_deserialize_returns_writable_array(self): + """NoHeaderNumpySerializer.deserialize should return a writable numpy array.""" + serializer = NoHeaderNumpySerializer() + arr = np.ones((10,), dtype=np.float64) + data, name = serializer.serialize(arr) + serializer.setup(name) + result = serializer.deserialize(data) + assert result.flags.writeable is True, "Deserialized numpy array should be writable" + + def test_tensor_serializer_deserialize_no_non_writable_warning(self): + """TensorSerializer.deserialize should not emit UserWarning about non-writable tensors.""" + import warnings + + serializer = TensorSerializer() + tensor = torch.ones((3, 4), dtype=torch.float32) + data, _ = serializer.serialize(tensor) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + result = serializer.deserialize(data) + writable_warnings = [ + w + for w in caught + if "non-writable" in str(w.message).lower() or "not writable" in str(w.message).lower() + ] + assert len(writable_warnings) == 0, ( + f"Should not emit non-writable warnings, got: {[str(w.message) for w in writable_warnings]}" + ) + + def test_no_header_tensor_serializer_deserialize_no_non_writable_warning(self): + """NoHeaderTensorSerializer.deserialize should not emit UserWarning about non-writable tensors.""" + import warnings + + serializer = NoHeaderTensorSerializer() + tensor = torch.ones((10,), dtype=torch.float32) + data, name = serializer.serialize(tensor) + serializer.setup(name) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + result = serializer.deserialize(data) + writable_warnings = [ + w + for w in caught + if "non-writable" in str(w.message).lower() or "not writable" in str(w.message).lower() + ] + assert len(writable_warnings) == 0, ( + f"Should not emit non-writable warnings, got: {[str(w.message) for w in writable_warnings]}" + ) + + def test_numpy_serializer_deserialize_correctness(self): + """Deserialized numpy array values should match the original.""" + serializer = NumpySerializer() + arr = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) + data, _ = serializer.serialize(arr) + result = serializer.deserialize(data) + np.testing.assert_array_equal(result, arr) + + def test_tensor_serializer_deserialize_correctness(self): + """Deserialized tensor values should match the original.""" + serializer = TensorSerializer() + tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + data, _ = serializer.serialize(tensor) + result = serializer.deserialize(data) + assert torch.equal(result, tensor) + + def test_no_header_numpy_serializer_deserialize_correctness(self): + """Deserialized no-header numpy array values should match the original.""" + serializer = NoHeaderNumpySerializer() + arr = np.array([1.0, 2.0, 3.0], dtype=np.float64) + data, name = serializer.serialize(arr) + serializer.setup(name) + result = serializer.deserialize(data) + np.testing.assert_array_equal(result, arr) + + def test_no_header_tensor_serializer_deserialize_correctness(self): + """Deserialized no-header tensor values should match the original.""" + serializer = NoHeaderTensorSerializer() + tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + data, name = serializer.serialize(tensor) + serializer.setup(name) + result = serializer.deserialize(data) + assert torch.equal(result, tensor)