Summary
When StreamingDataset is configured with item_loader=ParquetLoader(low_memory=False) and iterated through StreamingDataLoader(num_workers > 0) on Linux, iteration hangs on the first batch. The DataLoader main thread sits in multiprocessing.connection.poll and the worker subprocesses are stuck inside polars.LazyFrame.collect().
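This looks like the documented incompatibility between os.fork() (PyTorch DataLoader's default multiprocessing_context on Linux) and Polars's Rust/Rayon thread pool. The Polars docs warn that using fork instead of spawn "will cause a dead lock" (full quote under Root cause).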
In low_memory=False mode, ParquetLoader._get_item() calls pl.scan_parquet(chunk).collect() inside each DataLoader worker, which is the call pattern the Polars issue describes. Passing multiprocessing_context=mp.get_context("spawn") to StreamingDataLoader resolves it.
Minimal reproduction
Self-contained, no cloud access needed. Reliably reproduces on the versions listed at the bottom.
#!/usr/bin/env python3
"""Reproduce litdata + Polars + fork() DataLoader hang with low_memory=False."""
import multiprocessing as mp
import os
import shutil
import signal
import sys
import time

DATA_DIR = "/tmp/litdata_polars_fork_repro"
TIMEOUT_SEC = 90


def main() -> None:
    mode = sys.argv[1] if len(sys.argv) > 1 else "fork"
    assert mode in ("fork", "spawn", "forkserver"), f"unknown mode: {mode!r}"

    import numpy as np
    import pyarrow as pa
    import pyarrow.parquet as pq
    import polars as pl
    import litdata as ld
    from litdata.streaming.item_loader import ParquetLoader

    print(f"[env] python={sys.version.split()[0]} "
          f"default_mp_start_method={mp.get_start_method(allow_none=False)!r} "
          f"litdata={ld.__version__} polars={pl.__version__}")

    if os.path.exists(DATA_DIR):
        shutil.rmtree(DATA_DIR)
    os.makedirs(DATA_DIR, exist_ok=True)

    rng = np.random.default_rng(0)
    n_files, rows_per_file, seq_len = 8, 50_000, 1024
    for i in range(n_files):
        tbl = pa.table({
            "ids": rng.integers(0, 32_000, size=(rows_per_file, seq_len), dtype=np.int32).tolist(),
            "labels": rng.integers(0, 2, size=rows_per_file, dtype=np.int64),
            "score": rng.random(rows_per_file, dtype=np.float32),
            "tag": [f"tag_{j}" for j in range(rows_per_file)],
        })
        pq.write_table(tbl, os.path.join(DATA_DIR, f"part-{i}.parquet"))

    ld.index_parquet_dataset(DATA_DIR)

    # Touch Polars in the parent first; this mirrors what happens when a validation
    # sanity check runs before training workers are forked.
    _ = pl.scan_parquet(os.path.join(DATA_DIR, "part-0.parquet")).collect()

    dataset = ld.StreamingDataset(
        DATA_DIR,
        item_loader=ParquetLoader(low_memory=False),
    )

    loader_kwargs: dict = {"batch_size": 8, "num_workers": 4}
    if mode != "fork":
        loader_kwargs["multiprocessing_context"] = mp.get_context(mode)
    loader = ld.StreamingDataLoader(dataset, **loader_kwargs)

    t0 = time.time()
    for i, batch in enumerate(loader):
        print(f"  got batch {i} ({time.time() - t0:.1f}s)")
        if i >= 5:
            break
    print(f"[OK] completed normally in {time.time() - t0:.1f}s")


if __name__ == "__main__":
    def _on_timeout(signum, frame):
        sys.stderr.write(f"\n[!] TIMED OUT after {TIMEOUT_SEC}s -- this is the hang.\n")
        os._exit(124)

    signal.signal(signal.SIGALRM, _on_timeout)
    signal.alarm(TIMEOUT_SEC)
    main()
Observed
Default (fork) — hangs, script's 90 s timeout fires:
$ python /tmp/repro_litdata_polars_fork.py
[env] python=3.11.11 default_mp_start_method='fork' litdata=0.2.61 polars=1.33.1
[+] created 8 parquet files (775.3 MB total) under /tmp/litdata_polars_fork_repro
Indexing progress: 100%|██████████████████████| 8/8 [00:00<00:00, 2637.72step/s]
[+] indexed parquet dataset
[+] warmed up Polars in parent process (initializes Rayon pool)
You have set low_memory=False in ParquetLoader. This may result in high memory usage when processing large Parquet chunk files. Consider setting low_memory=True to reduce memory consumption.
[+] using PyTorch DataLoader default mp context (== 'fork' on this OS)
[+] iterating loader (timeout in 90s)...
[!] TIMED OUT after 90s -- this is the hang.
With multiprocessing_context="spawn" — completes in ~4 s:
$ python /tmp/repro_litdata_polars_fork.py spawn
[env] python=3.11.11 default_mp_start_method='fork' litdata=0.2.61 polars=1.33.1
[+] created 8 parquet files (775.3 MB total) under /tmp/litdata_polars_fork_repro
Indexing progress: 100%|██████████████████████| 8/8 [00:00<00:00, 2686.50step/s]
[+] indexed parquet dataset
[+] warmed up Polars in parent process (initializes Rayon pool)
You have set low_memory=False in ParquetLoader. This may result in high memory usage when processing large Parquet chunk files. Consider setting low_memory=True to reduce memory consumption.
[+] using multiprocessing_context='spawn'
[+] iterating loader (timeout in 90s)...
got batch 0 (2.2s)
got batch 1 (3.1s)
got batch 2 (3.1s)
got batch 3 (3.1s)
got batch 4 (3.1s)
got batch 5 (3.8s)
[OK] completed normally in 4.2s
Stack traces from the original hang
From an 8-rank DDP training job. After ~480 s the PyTorch NCCL watchdog fires on ranks 1–7 (waiting at a gradient AllReduce that rank 0 never reached); rank 0's watchdog is silent because it has no pending NCCL op.
DataLoader main thread on rank 0 (via py-spy dump):
select (selectors.py:415)
wait (multiprocessing/connection.py:948)
poll (multiprocessing/connection.py:257)
get (multiprocessing/queues.py:113)
_try_get_data (torch/utils/data/dataloader.py:1275)
_get_data (torch/utils/data/dataloader.py:1444)
_next_data (torch/utils/data/dataloader.py:1482)
__next__ (torch/utils/data/dataloader.py:732)
__iter__ (litdata/streaming/dataloader.py:675)
__next__ (lightning/pytorch/utilities/combined_loader.py:341)
__next__ (lightning/pytorch/loops/fetchers.py:134)
advance (lightning/pytorch/loops/training_epoch_loop.py:311)
...
fit (lightning/pytorch/trainer/trainer.py:584)
DataLoader worker subprocess (all four workers identical):
collect (polars/lazyframe/frame.py:2407) ← stuck inside Polars
wrapper (polars/lazyframe/opt_flags.py:330)
wrapper (polars/_utils/deprecation.py:97)
_get_item (litdata/streaming/item_loader.py:776) ← pl.scan_parquet(...).collect()
load_item_from_chunk (litdata/streaming/item_loader.py:697)
read (litdata/streaming/reader.py:460)
__getitem__ (litdata/streaming/cache.py:155)
__getitem__ (litdata/streaming/dataset.py:494)
__next__ (litdata/streaming/dataset.py:556)
fetch (torch/utils/data/_utils/fetch.py:33)
_worker_loop (torch/utils/data/_utils/worker.py:349)
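py-spy attaches to the process from outside. As a complement, Python's standard faulthandler module can dump the Python stack of every thread from inside a process once a deadline passes. The sketch below is a minimal illustration, not something used in the original job: the helper names arm_hang_dump and disarm_hang_dump are hypothetical, and calling them inside DataLoader workers (for example from a worker_init_fn) is an assumption that was not tested here. Only Python frames are shown, so the native Polars/Rayon frames would not appear.

```python
import faulthandler
import sys


def arm_hang_dump(timeout_sec: float, file=None) -> None:
    """Dump all Python thread stacks to `file` if still running after timeout_sec.

    Hypothetical helper name. `file` must be a real file object with a
    file descriptor; it defaults to sys.stderr.
    """
    target = file if file is not None else sys.stderr
    # exit=False: only print the stacks, leave the process running.
    faulthandler.dump_traceback_later(timeout_sec, repeat=False, file=target, exit=False)


def disarm_hang_dump() -> None:
    """Cancel a pending dump, e.g. once the first batch has arrived."""
    faulthandler.cancel_dump_traceback_later()
```

A typical pattern would be to arm the dump before iterating the loader and disarm it after the first batch, so output appears only when iteration stalls.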
Root cause (our reading)
- ParquetLoader._get_item() (the low_memory=False path) calls into Polars in each DataLoader worker; see src/litdata/streaming/item_loader.py:
  self._df[chunk_index] = pl.scan_parquet(chunk_filepath, low_memory=True).collect()
- StreamingDataLoader is a thin subclass of torch.utils.data.DataLoader. DataLoader leaves multiprocessing_context as None by default, which resolves to the platform's default start method (fork on this system).
- Polars's Rayon thread pool is initialized in the parent process the first time a Polars API is touched. After fork(), the worker inherits a pool struct that refers to threads which do not exist in the child. The first LazyFrame.collect() in the worker dispatches to that pool and hangs.
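The last point can be seen without Polars. The following is a minimal stdlib-only sketch, assuming a POSIX system where os.fork() is available: the parent starts a few idle threads (a stand-in for a thread pool, not Rayon itself), forks, and the child reports how many threads it can see. The helper name count_threads_across_fork is hypothetical.

```python
import os
import threading


def count_threads_across_fork(n_threads: int = 4) -> tuple[int, int]:
    """Return (thread count in parent, thread count seen by a forked child)."""
    stop = threading.Event()
    # Idle background threads standing in for a thread pool.
    workers = [threading.Thread(target=stop.wait, daemon=True) for _ in range(n_threads)]
    for w in workers:
        w.start()
    parent_count = threading.active_count()

    read_fd, write_fd = os.pipe()
    # Python 3.12+ may emit a DeprecationWarning for forking a multi-threaded process.
    pid = os.fork()
    if pid == 0:
        # Child: only the thread that called fork() exists here.
        os.close(read_fd)
        os.write(write_fd, str(threading.active_count()).encode())
        os._exit(0)

    # Parent: read the child's report, then clean up.
    os.close(write_fd)
    child_count = int(os.read(read_fd, 16).decode())
    os.close(read_fd)
    os.waitpid(pid, 0)
    stop.set()
    for w in workers:
        w.join()
    return parent_count, child_count


if __name__ == "__main__":
    parent, child = count_threads_across_fork()
    print(f"threads in parent: {parent}, threads in forked child: {child}")
```

Only the thread that called fork() is copied into the child, so any inherited state that assumes the pool's threads are still running is stale there. This is an analogy for the behavior described above, not a trace of what Polars does internally.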
From the maintainer reply on pola-rs/polars#24162:
"We don't support using Polars in combination with multiprocessing using fork: https://docs.pola.rs/user-guide/misc/multiprocessing/"
And the Polars docs:
"Polars is multithreaded as to provide strong performance out-of-the-box. Thus, it cannot be combined with fork. … One should use spawn, or forkserver, instead. … Using fork as the method, instead of spawn, will cause a dead lock."
In our testing, the hang did not reproduce with low_memory=True (the current default) on the minimal script above. That path uses pyarrow.parquet.read_row_group() + pl.from_arrow() and dispatches less Rayon work, so a similar hang may still be possible under heavier load.
Workaround
import multiprocessing as mp
from litdata import StreamingDataLoader

loader = StreamingDataLoader(
    dataset,
    batch_size=...,
    num_workers=...,
    multiprocessing_context=mp.get_context("spawn"),
)
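If the same loader code also runs with num_workers=0 (for debugging, say), the context argument needs to be conditional: as far as we know, torch's DataLoader rejects multiprocessing_context when no worker processes are used. A small sketch of one way to build the keyword arguments, using only the standard library; loader_mp_kwargs is a hypothetical helper, not part of litdata or torch.

```python
import multiprocessing as mp
from typing import Any


def loader_mp_kwargs(num_workers: int, method: str = "spawn") -> dict[str, Any]:
    """Build num_workers / multiprocessing_context kwargs that avoid fork.

    Hypothetical helper. Returns no context when num_workers is 0, because
    a multiprocessing context only makes sense with worker processes.
    """
    if num_workers <= 0:
        return {"num_workers": 0}
    if method == "fork":
        raise ValueError("fork is the start method this workaround is meant to avoid")
    if method not in mp.get_all_start_methods():
        raise ValueError(f"start method {method!r} is not available on this platform")
    return {
        "num_workers": num_workers,
        "multiprocessing_context": mp.get_context(method),
    }
```

The result would be splatted into the constructor, e.g. StreamingDataLoader(dataset, batch_size=8, **loader_mp_kwargs(4)). Only "spawn" was verified in this report; "forkserver" is named by the Polars docs as an alternative.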
Environment
- litdata: 0.2.61
- polars: 1.33.1
- torch: 2.9.1+cu129
- pytorch-lightning: 2.6.1
- Python: 3.11.11
- OS: Ubuntu 22.04.5 LTS, Linux 6.8.0-1030-gcp x86_64
- Default multiprocessing start method: fork
- Hardware: 8-GPU Linux x86_64 node (the minimal repro reproduces with no GPU involvement)