Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/litdata/streaming/item_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,10 +532,10 @@ def load_item_from_chunk(

if self._serializer_name == "no_header_tensor":
# count: number of tokens to read from buffer => `self._block_size`
data = torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
data = torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset).clone()
else:
# count: number of tokens to read from buffer => `self._block_size`
data = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset) # type: ignore
data = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset).copy() # type: ignore

return data

Expand Down
5 changes: 0 additions & 5 deletions src/litdata/streaming/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import glob
import logging
import os
import warnings
from contextlib import suppress
from datetime import datetime
from queue import Empty, Queue
Expand All @@ -33,9 +32,6 @@
from litdata.utilities.encryption import Encryption
from litdata.utilities.env import _DistributedEnv, _WorkerEnv

warnings.filterwarnings("ignore", message=".*The given buffer is not writable.*")


logger = logging.getLogger("litdata.streaming.reader")


Expand Down Expand Up @@ -332,7 +328,6 @@ def __init__(

"""
super().__init__()
warnings.filterwarnings("ignore", message=".*The given buffer is not writable.*")

self._cache_dir = cache_dir
self._remote_input_dir = remote_input_dir
Expand Down
14 changes: 9 additions & 5 deletions src/litdata/streaming/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def serialize(self, item: Any) -> tuple[bytes, str | None]:
def deserialize(self, data: bytes) -> torch.Tensor:
from torchvision.io import decode_image, decode_jpeg

array = torch.frombuffer(data, dtype=torch.uint8)
array = torch.frombuffer(bytearray(data), dtype=torch.uint8)
# Try decoding as JPEG. Some datasets (e.g., ImageNet) may have PNG images with a JPEG extension,
# which will cause decode_jpeg to fail. In that case, fall back to a generic image decoder.
with suppress(RuntimeError):
Expand Down Expand Up @@ -266,7 +266,7 @@ def deserialize(self, data: bytes) -> torch.Tensor:
shape = struct.unpack_from(f">{rank}I", buffer_view, header_size)
data_start_offset = header_size + (rank * 4)
if data_start_offset < len(buffer_view):
tensor_1d = torch.frombuffer(buffer_view[data_start_offset:], dtype=dtype)
tensor_1d = torch.frombuffer(bytearray(buffer_view[data_start_offset:]), dtype=dtype)
return tensor_1d.reshape(shape)
return torch.empty(shape, dtype=dtype)

Expand Down Expand Up @@ -300,7 +300,11 @@ def serialize(self, item: torch.Tensor) -> tuple[bytes, str | None]:

def deserialize(self, data: bytes) -> torch.Tensor:
assert self._dtype
return torch.frombuffer(data, dtype=self._dtype) if len(data) > 0 else torch.empty((0,), dtype=self._dtype)
return (
torch.frombuffer(bytearray(data), dtype=self._dtype)
if len(data) > 0
else torch.empty((0,), dtype=self._dtype)
)

def can_serialize(self, item: torch.Tensor) -> bool:
return isinstance(item, torch.Tensor) and len(item.shape) == 1
Expand Down Expand Up @@ -333,7 +337,7 @@ def deserialize(self, data: bytes) -> np.ndarray:
shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item())

# deserialize the numpy array bytes
tensor = np.frombuffer(data[8 + 4 * shape_size : len(data)], dtype=dtype)
tensor = np.frombuffer(data[8 + 4 * shape_size : len(data)], dtype=dtype).copy()
if tensor.shape == shape:
return tensor
return np.reshape(tensor, shape)
Expand All @@ -359,7 +363,7 @@ def serialize(self, item: np.ndarray) -> tuple[bytes, str | None]:

def deserialize(self, data: bytes) -> np.ndarray:
assert self._dtype
return np.frombuffer(data, dtype=self._dtype)
return np.frombuffer(data, dtype=self._dtype).copy()

def can_serialize(self, item: np.ndarray) -> bool:
return isinstance(item, np.ndarray) and len(item.shape) == 1
Expand Down
99 changes: 99 additions & 0 deletions tests/streaming/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,3 +410,102 @@ def test_boolean_serializer():
assert not serializer.can_serialize(1)
assert not serializer.can_serialize("True")
assert not serializer.can_serialize(None)


class TestWritableDeserializedArrays:
"""Tests that deserialized numpy arrays and tensors are writable.

This addresses issue #818: np.frombuffer() and torch.frombuffer() with
bytes objects produce non-writable arrays/tensors, which triggers
UserWarning from PyTorch.
"""

def test_numpy_serializer_deserialize_returns_writable_array(self):
"""NumpySerializer.deserialize should return a writable numpy array."""
serializer = NumpySerializer()
arr = np.ones((3, 4), dtype=np.float32)
data, _ = serializer.serialize(arr)
result = serializer.deserialize(data)
assert result.flags.writeable is True, "Deserialized numpy array should be writable"

def test_no_header_numpy_serializer_deserialize_returns_writable_array(self):
"""NoHeaderNumpySerializer.deserialize should return a writable numpy array."""
serializer = NoHeaderNumpySerializer()
arr = np.ones((10,), dtype=np.float64)
data, name = serializer.serialize(arr)
serializer.setup(name)
result = serializer.deserialize(data)
assert result.flags.writeable is True, "Deserialized numpy array should be writable"

def test_tensor_serializer_deserialize_no_non_writable_warning(self):
"""TensorSerializer.deserialize should not emit UserWarning about non-writable tensors."""
import warnings

serializer = TensorSerializer()
tensor = torch.ones((3, 4), dtype=torch.float32)
data, _ = serializer.serialize(tensor)
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
result = serializer.deserialize(data)
writable_warnings = [
w
for w in caught
if "non-writable" in str(w.message).lower() or "not writable" in str(w.message).lower()
]
assert len(writable_warnings) == 0, (
f"Should not emit non-writable warnings, got: {[str(w.message) for w in writable_warnings]}"
)

def test_no_header_tensor_serializer_deserialize_no_non_writable_warning(self):
"""NoHeaderTensorSerializer.deserialize should not emit UserWarning about non-writable tensors."""
import warnings

serializer = NoHeaderTensorSerializer()
tensor = torch.ones((10,), dtype=torch.float32)
data, name = serializer.serialize(tensor)
serializer.setup(name)
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
result = serializer.deserialize(data)
writable_warnings = [
w
for w in caught
if "non-writable" in str(w.message).lower() or "not writable" in str(w.message).lower()
]
assert len(writable_warnings) == 0, (
f"Should not emit non-writable warnings, got: {[str(w.message) for w in writable_warnings]}"
)

def test_numpy_serializer_deserialize_correctness(self):
"""Deserialized numpy array values should match the original."""
serializer = NumpySerializer()
arr = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
data, _ = serializer.serialize(arr)
result = serializer.deserialize(data)
np.testing.assert_array_equal(result, arr)

def test_tensor_serializer_deserialize_correctness(self):
"""Deserialized tensor values should match the original."""
serializer = TensorSerializer()
tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
data, _ = serializer.serialize(tensor)
result = serializer.deserialize(data)
assert torch.equal(result, tensor)

def test_no_header_numpy_serializer_deserialize_correctness(self):
"""Deserialized no-header numpy array values should match the original."""
serializer = NoHeaderNumpySerializer()
arr = np.array([1.0, 2.0, 3.0], dtype=np.float64)
data, name = serializer.serialize(arr)
serializer.setup(name)
result = serializer.deserialize(data)
np.testing.assert_array_equal(result, arr)

def test_no_header_tensor_serializer_deserialize_correctness(self):
"""Deserialized no-header tensor values should match the original."""
serializer = NoHeaderTensorSerializer()
tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
data, name = serializer.serialize(tensor)
serializer.setup(name)
result = serializer.deserialize(data)
assert torch.equal(result, tensor)