diff --git a/docs/training/index.mdx b/docs/training/index.mdx
index 01ac5e6..1ff3581 100644
--- a/docs/training/index.mdx
+++ b/docs/training/index.mdx
@@ -5,11 +5,13 @@ description: "Introduction to loading data, shuffling, and creating permutations
icon: boxes-stacked
---
-LanceDB makes an excellent data backend for training machine learning models. This section will describe the
-`Permutation` API in LanceDB that's designed to facilitate the
-process of training a model and explain how to use LanceDB as a data backend for training.
+LanceDB makes an excellent data backend for training machine learning models. While a `Table` by itself can be treated
+as input to a data loader this is typcially limited. A `Permutation` can be created to control which rows are accessed
+and in what order. For an even more complete solution LanceDB also provides a `StreamingDataset` which adapts the lower
+level `Permutation` API into a simple iterable dataset supporting prefetching, elastic determinism, resumability,
+and multi-threaded transformations.
-## Data loading
+## Basic Data loading
Most model training frameworks iterate through data in batches and feed this data into the model. This process is
often referred to as **data loading**. The simplest way to load data into a model is to iterate a LanceDB table in
@@ -29,116 +31,317 @@ for batch in table:
In practice, this is too simplistic for effective training. We may not want to load all the data, or we may want
to load the data in a different order, or we may need to apply some sort of processing to the data before training.
-To achieve this, we will often want to create a `Permutation` of the table.
+To achieve this, we can use the `StreamingDataset`.
-## Selecting rows
+
+```py Python icon=Python
+from lancedb.streaming import StreamingDataset
+import lancedb
-When training a model, we might not want to load all of the data. For example, we might filter out columns that
-are not needed for training. We might also divide the data into training and validation sets. Or we could divide
-the data into multiple sets for cross-validation.
+db = lancedb.connect("file://some/db/path")
+table = db.open_table("some_table")
-Whenever we create a permutation of the table, we have to first decide which rows we want to include (and in what
-order). This is stored in a **permutation table** which marks out the row ids that make up our data. Other decisions,
-such as which columns to include, and what transformations to apply, can be defined at read time and don't require
-a separate permutation table.
+ds = StreamingDataset(table, shuffle_seed=42)
+for sample in ds:
+ # sample is a plain Python dict, e.g. {"feature": 0.82, "label": "cat"}
+ train_step(sample)
+```
+
-
-Permutation tables are tables, just like any other table in LanceDB. By default, they are
-stored in memory but they can be persisted to storage as well. This is useful when you want to share a permutation
-table across processes or nodes.
-
+## Advanced Data Loading
+
+The `StreamingDataset` wraps a LanceDB `Table` and, by default, simply adds prefetching and transformation from
+Arrow format to Python. However, it can be configured to handle more advanced scenarios. To help understand we
+will consider a model trained with stochastic gradient descent (SGD) and distributed data parallelism (DDP). In
+this example we need to load multiple GPUs, across multiple servers, with batches of data. After each batch is
+processed the GPUs exchange weights and the next batch is loaded. This introduces a number of concepts and we
+will use terms from PyTorch in our examples:
+
+ * World size - The world size is the number of GPUs that we are loading. For example, if we have 2 servers and
+ each server has 4 GPUs then the world size is 8.
+ * Rank - When loading data we will create a process for each GPU. Each process will have its own rank. This is
+ an integer in the range `[0, world_size)`. This is important for data loading because each rank should get its
+ own portion of the data (e.g. rank 3 and rank 4 will see different rows).
+ * Global batch size - The global batch size is the number of rows, across _all_ GPUs, that we process in each step
+ of the SGD algorithm. For example, if we have 8 GPUs and the global batch size is 1024 then we need to load
+ 128 rows onto each GPU for each step. The global batch size must be divisible by the world size.
+ * Batch size - The batch size is the number of rows, for a single GPU, that we process in each step of the SGD
+ algorithm. Once again, if we have 8 GPUs and the global batch size is 1024 then the batch size is 128.
+
+Other concepts (read batch size, num_workers) will be discussed later but are specific to a particular section.
+
+### Prefetching
+
+PyTorch datasets were originally built around in-memory structures like a Pandas DataFrame. When they are iterated
+they yield a single sample at a time. This makes sense for simple in-memory structure but if try and access a
+(potentially remote) database one row at a time the per-call overhead will typically be far too expensive. To work
+around this the `StreamingDataset` fetches data and transforms data in batches. The `read_batch_size` parameter
+controls how many rows we read per call to the underlying `Table`.
+
+In addition to batching up requests to the database the prefetching mechanism will read ahead in the background.
+While the first batch is being transformed and processed by the GPU a `StreamingDataset` will also be reading the
+next batch of data. The `prefetch_batches` parameter controls how many batches of data we will read ahead. This
+should typically be at least 2. A larger value can provide more buffering against jittery workloads but will
+require more RAM.
-## Selecting all rows
+
+```py Python icon=Python
+ds = StreamingDataset(
+ table,
+ shuffle_seed=42,
+ read_batch_size=256, # rows fetched per LanceDB call per split
+ prefetch_batches=8, # batches to keep in flight per split
+)
+```
+
-To select all rows, we can use the `Permutation.identity` method. This gives us a `Permutation` without requiring
-us to create a separate permutation table. This allows us to refine our columns and apply transformations and can
-be useful when the data loader itself is responsible for handling sampling and shuffling.
+### Transformation
-
-```py Python icon=Python
-from lancedb.permutation import Permutation
+Many model training workloads require a transformation step between loading the data and training the model. For
+example, we may need to decode images, tokenize text, or normalize data. A transformation function can be provided
+using the `transform` parameter. Transformation can be expensive and we often want to utilize multiple CPUs to apply
+these transformations. By default transformations will be applied with a `ThreadPoolExecutor` with a number of workers
+equal to the number of CPUs.
-# We can create an identity permutation without needing any separate permutation table.
-permutation = Permutation.identity(table)
+Transformations are applied on batches of data, not individual samples, to allow transformations to amortize per-batch
+overhead. A transformation function will receive an Arrow record batch and should return an iterable of samples (one
+sample per row). The `StreamingDataset` does not care what format these samples take but they should match what your
+data loader expects. For example, the default PyTorch dataloader's collation function can except a variety of different
+sample types, with a python dictionary being one of the most common. The default transformation function converts
+the Arrow record batch into an iterable of python dictionaries without doing any processing of the data itself.
-# This allows us to refine our columns and apply transformations
-permutation = permutation.select_columns(["id", "prompt"])
+
+```py Python icon=Python
+import pyarrow as pa
+
+def normalize(batch: pa.RecordBatch) -> list[dict]:
+ # This pure-Python loop holds the GIL and is shown for illustration only.
+ # In practice, prefer a library like torchvision or numpy that releases the
+ # GIL so the ThreadPoolExecutor can run transforms in parallel.
+ rows = batch.to_pylist()
+ for row in rows:
+ row["image"] = [v / 255.0 for v in row["image"]]
+ return rows
+
+ds = StreamingDataset(table, shuffle_seed=42, transform=normalize)
```
-## Filtering rows
+#### Worker Info
+
+The thread-based transformation model that `StreamingDataset` uses by default is only effective when the transform
+function releases the GIL. This is true for most Python scientific libraries (e.g. numpy, pandas, arrow, torchvision)
+but there are some libraries which may not do this. Because of this limitation PyTorch supports launching multiple
+worker processes per rank (the num_workers variable in the data loader). The `StreamingDataset` can handle this
+scenario and will call `get_worker_info` to determine the worker id and the total number of workers and will adjust
+accordingly. However, we find this multiprocessing to be inefficient (adds pickling and transfer overhead as well
+as significantly increasing the amount of RAM required) and suggest starting with `num_workers=1` and only using
+a higher value when you've confirmed an unavoidable GIL bottleneck.
-If we only want to select a subset of rows, then we can use a filter. This will require us to create a permutation
-table which identifies which rows we want to include.
+### Observability & performance
+
+Optimizing data loader performance is tricky because it can be difficult to locate the bottleneck. What is often
+blamed on I/O ends up being a CPU bottleneck in the transform stage (or vice versa). To assist developers the
+`StreamingDataset` offers a number of observability controls. The `raw_queue_depth` can be polled on a regular
+basis to determine the number of rows that have been loaded (I/O finished) but not trasnformed. The
+`prefetch_queue_depth` can be polled to determine the number of rows that have been transformed and are waiting
+to be consumed by the GPU. As long as these queue sizes are non-empty the GPU should be operating at capacity.
+
+If the `prefetch_queue_depth` is consistently zero but the `raw_queue_depth` is not then you have a CPU
+transformation bottleneck. You should investigate GIL bottlenecks or look for ways to optimize your transformation.
+This can often be done by batching the compute work. If both the `prefetch_queue_depth` and `raw_queue_depth` are
+consistently zero then you are bottlenecked by I/O. A larger read batch size or clumped shuffling could help to
+reduce the I/O bottleneck.
-```py Python icon=Python
-from lancedb.permutation import Permutation, permutation_builder
+```py Python icon=Python
+import threading, time
-# We can create a permutation table which identifies which rows we want to include.
-permutation_tbl = permutation_builder(table).filter("category = 'cat'").execute()
+ds = StreamingDataset(table, shuffle_seed=42)
-# We can then use this permutation table to create a Permutation object
-permutation = Permutation.from_tables(table, permutation_tbl)
+def log_pipeline_health():
+ while True:
+ print(
+ f"unscanned={ds.unscanned_rows} "
+ f"raw={ds.raw_queue_depth} "
+ f"cooked={ds.prefetch_queue_depth} "
+ f"consumed={ds.consumed_rows}"
+ )
+ time.sleep(1.0)
+
+monitor = threading.Thread(target=log_pipeline_health, daemon=True)
+monitor.start()
+
+for sample in ds:
+ train_step(sample)
+
+print(f"fetch time: {ds.fetch_time:.2f}s transform time: {ds.transform_time:.2f}s")
```
-## Creating splits
+### Filtering data
-LanceDB also provides several different methods for creating splits. These allow us to divide our dataset into
-smaller non-overlapping sets. The split can then be specified when creating the `Permutation` object to view
-only a subset of the data.
+By default the streaming dataset will load all rows and all columns. LanceDB is a columnar database that also
+supports efficient random access. Reducing the number of columns you load will have a direct impact on I/O
+performance. Reducing the number of rows you load will also have an impact on I/O performance, especially if
+you have a very selective filter, are loading large values, or the data is local or in the LanceDB enterprise
+cache.
+
+You can use the `columns` parameter to specify which columns to load. You can use the `filter` parameter to
+specify which rows to load. Filtered rows are not loaded from storage. LanceDB will first calculate the row
+ids that match the filter, then divide the matching row ids into splits for loading.
-```py Python icon=Python
-from lancedb.permutation import Permutation, permutation_builder
+```py Python icon=Python
+ds = StreamingDataset(
+ table,
+ shuffle_seed=42,
+ columns=["image", "label"], # skip all other columns
+ filter="category = 'train'", # only training rows
+)
+```
+
+
+### Shuffling rows
-# Here we create two splits, one for training and one for validation. By default, splits have no
-# name and are accessed by index.
-permutation_tbl = permutation_builder(table).split_random(ratios=[0.95, 0.05]).execute()
+By default, the streaming dataset will access the data in the order the data is stored in the table. This can cause our
+model to learn artifacts specific to the order of the data. This is one of many ways we can "overfit" our model
+to our data. To avoid this, we typically want to shuffle the data before training. This is done by setting the
+`shuffle` parameter to `True`. If this is not set then the data will be divided into splits sequentially.
-# Let's create a permutation object which views only the training data.
-permutation = Permutation.from_tables(table, permutation_tbl, split=0)
+By default, the shuffle seed will be a combination of the provided `shuffle_seed` and the provided `epoch` which
+will ensure that each epoch has a different ordering. If you wish for all epochs to have the same ordering then
+you can set the `epoch` parameter to `0` (it is only used to determine the shuffle seed). If you do not want
+to provide a `shuffle_seed` then you can set it to `None` and a random seed will be used instead.
-# Splits can also be given names. The names can then be used later to access the split instead of
-# requiring us to know the index.
-permutation_tbl = permutation_builder(table).split_random(ratios=[0.95, 0.05], split_names=["train", "test"]).execute()
-permutation = Permutation.from_tables(table, permutation_tbl, split="train")
+
+```py Python icon=Python
+# Training loop: each epoch gets a different shuffled ordering
+for epoch in range(num_epochs):
+ ds = StreamingDataset(
+ table,
+ shuffle_seed=42,
+ epoch=epoch, # changes the permutation each epoch
+ )
+ for sample in ds:
+ train_step(sample)
+
+# Evaluation: deterministic sequential order, no shuffle
+eval_ds = StreamingDataset(table, shuffle=False)
+for sample in eval_ds:
+ eval_step(sample)
```
-## Shuffling rows
-
-By default, permutations will access the data in the order the data is stored in the table. This can cause our
-model to learn artifacts specific to the order of the data. This is one of many ways we can "overfit" our model
-to our data. To avoid this, we typically want to shuffle the data before training. Model training frameworks
-(like PyTorch) will often provide a way to shuffle the data. If you are not using one of these frameworks, or if
-you want to shuffle the data with LanceDB, you can shuffle the rows when you create a permutation table.
+
+Shuffling can have significant impacts on I/O performance, especially if you are loading data from cloud storage.
+In many cases the GPU pipeline is slow enough that this penalty will not be noticeable. However,
+you can use the `shuffle_clump_size` parameter to shuffle the data in clumps (small contiguous batches that get
+shuffled together). This will give some penalty to the randomness of the shuffle, but will significantly improve
+I/O performance.
+
-```py Python icon=Python
-from lancedb.permutation import Permutation, permutation_builder
+```py Python icon=Python
+# Clumped shuffle: groups of 16 contiguous rows are shuffled together,
+# preserving read locality while still randomising the global ordering.
+ds = StreamingDataset(
+ table,
+ shuffle_seed=42,
+ shuffle_clump_size=16,
+)
+```
+
-# We can shuffle the rows when we create the permutation table.
-permutation_tbl = permutation_builder(table).shuffle().execute()
+### Data splits and elasticity
+
+The data is divided across a number of processes based on the world size and the number of workers per rank.
+These groups are called "splits" and by default the dataset will create `world_size * num_workers` splits.
+This is simple but can lead to problems if you need to rerun the training with a different number of GPUs (i.e.
+different world size). For example, if we have 8 GPUs and 1 worker per rank then we have 8 splits. If we
+later train with 4 GPUs then we will only have 4 splits and the data will be divided differently. This could
+lead to a different model being trained which can make deterministic training harder to reproduce.
+
+To work around this, you can manually specify the number of splits used. This allows you to select a larger
+number of splits than the default and can provide you with a property called "elastic determinism". If we
+consider our example above, if we are going to train on 4 GPUs we can set `num_splits=8` and we will divide
+the data as if we had 8 GPUs. Each rank will be assigned 2 splits and will pull from those two splits in
+a round-robin fashion. This means each global batch that gets generated will be the same as the global batches
+that were generated when we trained with 8 GPUs.
+
+In order for this determinism to work, the number of splits must be a multiple of the world size. Our simple
+example above works because 8 is a multiple of 4. However, what would happen if we trained with 8 GPUs and then
+wanted to train with 6 GPUs. In that case, `num_splits=8` would not be a multiple of 6 and the data would not
+be divided evenly across the ranks. To make this work we can choose a `num_splits` value that is a multiple of
+both 8 and 6. For example, we could choose `num_splits=24`. Good choices for `num_splits` are highly composite
+numbers like 48 (allows for 1, 2, 3, 4, 6, 8, 12, 16, 24, and 48 GPUs) and 60 (allows for 1, 2, 3, 4, 5, 6, 10,
+12, 15, 20, 30, 40, and 60 GPUs).
-# We can then use this permutation table to create a Permutation object, this will now
-# access the data in a random order.
-permutation = Permutation.from_tables(table, permutation_tbl)
+
+```py Python icon=Python
+import torch.distributed as dist
+
+dist.init_process_group("nccl")
+rank = dist.get_rank()
+world_size = dist.get_world_size()
+
+# num_splits=48 is divisible by 1, 2, 3, 4, 6, 8, 12, 16, 24, 48
+# so this dataset works unchanged as you scale up or down GPUs.
+ds = StreamingDataset(
+ table,
+ num_splits=48,
+ shuffle_seed=42,
+ epoch=current_epoch,
+ rank=rank,
+ world_size=world_size,
+)
+
+for sample in ds:
+ train_step(sample)
```
-## Selecting columns
+### Checkpointing and resumability
-By default, permutations will return all columns in the table. If you only need a subset of the columns, you can
-significantly reduce your I/O requirements by selecting only the columns you need. This can be done on the
-permutation object itself, and does not require us to create a separate permutation table.
+Model training is an expensive process and failures can often occur partway through. Checkpointing allows you to save
+the model state and resume training from where you left off. Most modern deep learning frameworks support checkpointing
+models. However, we must also be able to checkpoint the data loader so that we can resume training from where we left off.
+To support this the streaming dataset provides the `state_dict` and `load_state_dict` methods so that you can save and
+load the state of the data loader. These methods should be called by your training framework when you want to save or
+load a checkpoint. The `state_dict` method returns a simple python dictionary that can easily be persisted.
-```py Python icon=Python
-from lancedb.permutation import Permutation
+```py Python icon=Python
+import torch
+
+ds = StreamingDataset(table, num_splits=48, shuffle_seed=42, rank=rank, world_size=world_size)
+
+for step, sample in enumerate(ds):
+ train_step(sample)
+
+ if step % checkpoint_interval == 0:
+ torch.save(
+ {"model": model.state_dict(), "dataloader": ds.state_dict()},
+ f"checkpoint_{step}.pt",
+ )
-# We can select only the columns we need.
-permutation = Permutation.identity(table).select_columns(["id", "prompt"])
+# --- resuming after a crash ---
+checkpoint = torch.load("checkpoint_100.pt")
+model.load_state_dict(checkpoint["model"])
+
+ds = StreamingDataset(table, num_splits=48, shuffle_seed=42, rank=rank, world_size=world_size)
+ds.load_state_dict(checkpoint["dataloader"])
+
+for sample in ds: # continues from step 100, no repeated or skipped rows
+ train_step(sample)
```
-
\ No newline at end of file
+
+
+## Permutations
+
+In some more complicated scenarios, you may want to flexibility of the `StreamingDataset` to shuffle, split, and select
+data while not buying into the full behavior of the iterable dataset. In these cases you can use the `Permutation`
+class. This is a lower level class which the `StreamingDataset` is built on top of. A `Permutation` can be used to
+define a custom ordering of the data. You can then index into the `Permutation` using the `__getitems__` method to
+access rows by their ordering in the `Permutation`. More details on the `Permutation` class can be found in the API
+reference.