From 2176a90ee17f3ab9a0579376fa81b11debb02c71 Mon Sep 17 00:00:00 2001 From: westonpace Date: Wed, 1 Jul 2026 15:11:53 -0700 Subject: [PATCH 1/4] Document the streaming dataset --- docs/training/index.mdx | 256 ++++++++++++++++++++++------------------ 1 file changed, 144 insertions(+), 112 deletions(-) diff --git a/docs/training/index.mdx b/docs/training/index.mdx index 01ac5e6e..ef68f328 100644 --- a/docs/training/index.mdx +++ b/docs/training/index.mdx @@ -5,11 +5,13 @@ description: "Introduction to loading data, shuffling, and creating permutations icon: boxes-stacked --- -LanceDB makes an excellent data backend for training machine learning models. This section will describe the -`Permutation` API in LanceDB that's designed to facilitate the -process of training a model and explain how to use LanceDB as a data backend for training. +LanceDB makes an excellent data backend for training machine learning models. While a `Table` by itself can be treated +as input to a data loader, this is typically limiting. A `Permutation` can be created to control which rows are accessed +and in what order. For an even more complete solution LanceDB also provides a `StreamingDataset` which adapts the lower +level `Permutation` API into a simple iterable dataset supporting prefetching, elastic determinism, resumability, +and multi-threaded transformations. -## Data loading +## Basic Data loading Most model training frameworks iterate through data in batches and feed this data into the model. This process is often referred to as **data loading**. The simplest way to load data into a model is to iterate a LanceDB table in @@ -29,116 +31,146 @@ for batch in table: In practice, this is too simplistic for effective training. We may not want to load all the data, or we may want to load the data in a different order, or we may need to apply some sort of processing to the data before training. -To achieve this, we will often want to create a `Permutation` of the table. 
- -## Selecting rows - -When training a model, we might not want to load all of the data. For example, we might filter out columns that -are not needed for training. We might also divide the data into training and validation sets. Or we could divide -the data into multiple sets for cross-validation. +To achieve this, we can use the `StreamingDataset`. + +## Advanced Data Loading + +The `StreamingDataset` wraps a LanceDB `Table` and, by default, simply adds prefetching and transformation from +Arrow format to Python. However, it can be configured to handle more advanced scenarios. To help understand we +will consider a model trained with stochastic gradient descent (SGD) and distributed data parallelism (DDP). In +this example we need to load multiple GPUs, across multiple servers, with batches of data. After each batch is +processed the GPUs exchange weights and the next batch is loaded. This introduces a number of concepts and we +will use terms from PyTorch in our examples: + + * World size - The world size is the number of GPUs that we are loading. For example, if we have 2 servers and + each server has 4 GPUs then the world size is 8. + * Rank - When loading data we will create a process for each GPU. Each process will have its own rank. This is + an integer in the range `[0, world_size]`. This is important for data loading because each rank should get its + own portion of the data (e.g. rank 3 and rank 4 will see different rows). + * Global batch size - The global batch size is the number of rows, across _all_ GPUs, that we process in each step + of the SGD algorithm. For example, if we have 8 GPUs and the global batch size is 1024 then we need to load + 128 rows onto each GPU for each step. The global batch size must be divisible by the world size. + * Batch size - The batch size is the number of rows, for a single GPU, that we process in each step of the SGD + algorithm. 
Once again, if we have 8 GPUs and the global batch size is 1024 then the batch size is 128. + +Other concepts (read batch size, num_workers) will be discussed later but are specific to a particular section. + +### Prefetching + +PyTorch datasets were originally built around in-memory structures like a Pandas DataFrame. When they are iterated +they yield a single sample at a time. This makes sense for simple in-memory structures but if we try to access a +(potentially remote) database one row at a time the per-call overhead will typically be far too expensive. To work +around this the `StreamingDataset` fetches and transforms data in batches. The `read_batch_size` parameter +controls how many rows we read per call to the underlying `Table`. + +In addition to batching up requests to the database the prefetching mechanism will read ahead in the background. +While the first batch is being transformed and processed by the GPU a `StreamingDataset` will also be reading the +next batch of data. The `prefetch_batches` parameter controls how many batches of data we will read ahead. This +should typically be at least 2. A larger value can provide more buffering against jittery workloads but will +require more RAM. + +### Transformation + +Many model training workloads require a transformation step between loading the data and training the model. For +example, we may need to decode images, tokenize text, or normalize data. A transformation function can be provided +using the `transform` parameter. Transformation can be expensive and we often want to utilize multiple CPUs to apply +these transformations. By default transformations will be applied with a `ThreadPoolExecutor` with a number of workers +equal to the number of CPUs. + +Transformations are applied on batches of data, not individual samples, to allow transformations to amortize per-batch +overhead. 
A transformation function will receive an Arrow record batch and should return an iterable of samples (one +sample per row). The `StreamingDataset` does not care what format these samples take but they should match what your +data loader expects. For example, the default PyTorch dataloader's collation function can accept a variety of different +sample types, with a python dictionary being one of the most common. The default transformation function converts +the Arrow record batch into an iterable of python dictionaries without doing any processing of the data itself. + +#### Worker Info + +The thread-based transformation model that `StreamingDataset` uses by default is only effective when the transform +function releases the GIL. This is true for most Python scientific libraries (e.g. numpy, pandas, arrow, torchvision) +but there are some libraries which may not do this. Because of this limitation PyTorch supports launching multiple +worker processes per rank (the num_workers variable in the data loader). The `StreamingDataset` can handle this +scenario and will call `get_worker_info` to determine the worker id and the total number of workers and will adjust +accordingly. However, we find this multiprocessing to be inefficient (adds pickling and transfer overhead as well +as significantly increasing the amount of RAM required) and suggest starting with `num_workers=1` and only using +a higher value when you've confirmed an unavoidable GIL bottleneck. + +### Observability & performance + +Optimizing data loader performance is tricky because it can be difficult to locate the bottleneck. What is often +blamed on I/O ends up being a CPU bottleneck in the transform stage (or vice versa). To assist developers the +`StreamingDataset` offers a number of observability controls. The `raw_queue_depth` can be polled on a regular +basis to determine the number of rows that have been loaded (I/O finished) but not transformed. 
The +`prefetch_queue_depth` can be polled to determine the number of rows that have been transformed and are waiting +to be consumed by the GPU. As long as these queue sizes are non-empty the GPU should be operating at capacity. + +If the `prefetch_queue_depth` is consistently zero but the `raw_queue_depth` is not then you have a CPU +transformation bottleneck. You should investigate GIL bottlenecks or look for ways to optimize your transformation. +This can often be done by batching the compute work. If both the `prefetch_queue_depth` and `raw_queue_depth` are +consistently zero then you are bottlenecked by I/O. A larger read batch size or clumped shuffling could help to +reduce the I/O bottleneck. + +### Filtering data + +By default the streaming dataset will load all rows and all columns. LanceDB is a columnar database that also +supports efficient random access. Reducing the number of columns you load will have a direct impact on I/O +performance. Reducing the number of rows you load will also have an impact on I/O performance, especially if +you have a very selective filter, are loading large values, or the data is local or in the LanceDB enterprise +cache. + +You can use the `columns` parameter to specify which columns to load. You can use the `filter` parameter to +specify which rows to load. Filtered rows are not loaded from storage. LanceDB will first calculate the row +ids that match the filter, then divide the matching row ids into splits for loading. + +### Shuffling rows + +By default, the streaming dataset will access the data in the order the data is stored in the table. This can cause our +model to learn artifacts specific to the order of the data. This is one of many ways we can "overfit" our model +to our data. To avoid this, we typically want to shuffle the data before training. This is done by setting the +`shuffle` parameter to `True`. If this is not set then the data will be divided into splits sequentially. 
-Whenever we create a permutation of the table, we have to first decide which rows we want to include (and in what -order). This is stored in a **permutation table** which marks out the row ids that make up our data. Other decisions, -such as which columns to include, and what transformations to apply, can be defined at read time and don't require -a separate permutation table. +By default, the shuffle seed will be a combination of the provided `shuffle_seed` and the provided `epoch` which +will ensure that each epoch has a different ordering. If you wish for all epochs to have the same ordering then +you can set the `epoch` parameter to `0` (it is only used to determine the shuffle seed). If you do not want +to provide a `shuffle_seed` then you can set it to `None` and a random seed will be used instead. -Permutation tables are tables, just like any other table in LanceDB. By default, they are -stored in memory but they can be persisted to storage as well. This is useful when you want to share a permutation -table across processes or nodes. +Shuffling can have significant impacts on I/O performance, especially if you are loading data from cloud storage. +In many cases the GPU pipeline is slow enough that this penalty will not be noticeable. However, +you can use the `shuffle_clump_size` parameter to shuffle the data in clumps (small contiguous batches that get +shuffled together). This will give some penalty to the randomness of the shuffle, but will significantly improve +I/O performance. -## Selecting all rows - -To select all rows, we can use the `Permutation.identity` method. This gives us a `Permutation` without requiring -us to create a separate permutation table. This allows us to refine our columns and apply transformations and can -be useful when the data loader itself is responsible for handling sampling and shuffling. 
- - -```py Python icon=Python -from lancedb.permutation import Permutation - -# We can create an identity permutation without needing any separate permutation table. -permutation = Permutation.identity(table) - -# This allows us to refine our columns and apply transformations -permutation = permutation.select_columns(["id", "prompt"]) -``` - - -## Filtering rows - -If we only want to select a subset of rows, then we can use a filter. This will require us to create a permutation -table which identifies which rows we want to include. - - -```py Python icon=Python -from lancedb.permutation import Permutation, permutation_builder - -# We can create a permutation table which identifies which rows we want to include. -permutation_tbl = permutation_builder(table).filter("category = 'cat'").execute() - -# We can then use this permutation table to create a Permutation object -permutation = Permutation.from_tables(table, permutation_tbl) -``` - - -## Creating splits - -LanceDB also provides several different methods for creating splits. These allow us to divide our dataset into -smaller non-overlapping sets. The split can then be specified when creating the `Permutation` object to view -only a subset of the data. - - -```py Python icon=Python -from lancedb.permutation import Permutation, permutation_builder - -# Here we create two splits, one for training and one for validation. By default, splits have no -# name and are accessed by index. -permutation_tbl = permutation_builder(table).split_random(ratios=[0.95, 0.05]).execute() - -# Let's create a permutation object which views only the training data. -permutation = Permutation.from_tables(table, permutation_tbl, split=0) - -# Splits can also be given names. The names can then be used later to access the split instead of -# requiring us to know the index. 
-permutation_tbl = permutation_builder(table).split_random(ratios=[0.95, 0.05], split_names=["train", "test"]).execute() -permutation = Permutation.from_tables(table, permutation_tbl, split="train") -``` - - -## Shuffling rows - -By default, permutations will access the data in the order the data is stored in the table. This can cause our -model to learn artifacts specific to the order of the data. This is one of many ways we can "overfit" our model -to our data. To avoid this, we typically want to shuffle the data before training. Model training frameworks -(like PyTorch) will often provide a way to shuffle the data. If you are not using one of these frameworks, or if -you want to shuffle the data with LanceDB, you can shuffle the rows when you create a permutation table. - - -```py Python icon=Python -from lancedb.permutation import Permutation, permutation_builder - -# We can shuffle the rows when we create the permutation table. -permutation_tbl = permutation_builder(table).shuffle().execute() - -# We can then use this permutation table to create a Permutation object, this will now -# access the data in a random order. -permutation = Permutation.from_tables(table, permutation_tbl) -``` - - -## Selecting columns - -By default, permutations will return all columns in the table. If you only need a subset of the columns, you can -significantly reduce your I/O requirements by selecting only the columns you need. This can be done on the -permutation object itself, and does not require us to create a separate permutation table. - - -```py Python icon=Python -from lancedb.permutation import Permutation - -# We can select only the columns we need. -permutation = Permutation.identity(table).select_columns(["id", "prompt"]) -``` - \ No newline at end of file +### Data splits and elasticity + +The data is divided across a number of processes based on the world size and the number of workers per rank. 
+These groups are called "splits" and by default the dataset will create `world_size * num_workers` splits. +This is simple but can lead to problems if you need to rerun the training with a different number of GPUs (i.e. +different world size). For example, if we have 8 GPUs and 1 worker per rank then we have 8 splits. If we +later train with 4 GPUs then we will only have 4 splits and the data will be divided differently. This could +lead to a different model being trained which can make deterministic training harder to reproduce. + +To work around this, you can manually specify the number of splits used. This allows you to select a larger +number of splits than the default and can provide you with a property called "elastic determinism". If we +consider our example above, if we are going to train on 4 GPUs we can set `num_splits=8` and we will divide +the data as if we had 8 GPUs. Each rank will be assigned 2 splits and will pull from those two splits in +a round-robin fashion. This means each global batch that gets generated will be the same as the global batches +that were generated when we trained with 8 GPUs. + +In order for this determinism to work, the number of splits must be a multiple of the world size. Our simple +example above works because 8 is a multiple of 4. However, what would happen if we trained with 8 GPUs and then +wanted to train with 6 GPUs? In that case, `num_splits=8` would not be a multiple of 6 and the data would not +be divided evenly across the ranks. To make this work we can choose a `num_splits` value that is a multiple of +both 8 and 6. For example, we could choose `num_splits=24`. Good choices for `num_splits` are highly composite +numbers like 48 (allows for 1, 2, 3, 4, 6, 8, 12, 16, 24, and 48 GPUs) and 60 (allows for 1, 2, 3, 4, 5, 6, 10, +12, 15, 20, 30, and 60 GPUs). + +### Checkpointing and resumability + +Model training is an expensive process and failures can often occur partway through. 
Checkpointing allows you to save +the model state and resume training from where you left off. Most modern deep learning frameworks support checkpointing +models. However, we must also be able to checkpoint the data loader so that we can resume training from where we left off. +To support this the streaming dataset provides the `state_dict` and `load_state_dict` methods so that you can save and +load the state of the data loader. These methods should be called by your training framework when you want to save or +load a checkpoint. The `state_dict` method returns a simple python dictionary that can easily be persisted. From c78afa8f781eb22b1caf4747bd533ac4799eda5e Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Thu, 2 Jul 2026 07:11:05 -0700 Subject: [PATCH 2/4] Also mention the permutation API for advanced cases --- docs/training/index.mdx | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/training/index.mdx b/docs/training/index.mdx index ef68f328..59158862 100644 --- a/docs/training/index.mdx +++ b/docs/training/index.mdx @@ -174,3 +174,12 @@ models. However, we must also be able to checkpoint the data loader so that we To support this the streaming dataset provides the `state_dict` and `load_state_dict` methods so that you can save and load the state of the data loader. These methods should be called by your training framework when you want to save or load a checkpoint. The `state_dict` method returns a simple python dictionary that can easily be persisted. + +## Permutations + +In some more complicated scenarios, you may want the flexibility of the `StreamingDataset` to shuffle, split, and select +data while not buying into the full behavior of the iterable dataset. In these cases you can use the `Permutation` +class. This is a lower level class which the `StreamingDataset` is built on top of. A `Permutation` can be used to +define a custom ordering of the data. 
You can then index into the `Permutation` using the `__getitems__` method to +access rows by their ordering in the `Permutation`. More details on the `Permutation` class can be found in the API +reference. From f1ec1aacc108d86adaf00b0a514c2d3340f0ab33 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Thu, 2 Jul 2026 07:39:46 -0700 Subject: [PATCH 3/4] Add examples --- docs/training/index.mdx | 162 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 161 insertions(+), 1 deletion(-) diff --git a/docs/training/index.mdx b/docs/training/index.mdx index 59158862..a0ecafe4 100644 --- a/docs/training/index.mdx +++ b/docs/training/index.mdx @@ -33,6 +33,21 @@ In practice, this is too simplistic for effective training. We may not want to l to load the data in a different order, or we may need to apply some sort of processing to the data before training. To achieve this, we can use the `StreamingDataset`. + +```py Python icon=Python +from lancedb.streaming import StreamingDataset +import lancedb + +db = lancedb.connect("file://some/db/path") +table = db.open_table("some_table") + +ds = StreamingDataset(table, shuffle_seed=42) +for sample in ds: + # sample is a plain Python dict, e.g. {"feature": 0.82, "label": "cat"} + train_step(sample) +``` + + ## Advanced Data Loading The `StreamingDataset` wraps a LanceDB `Table` and, by default, simply adds prefetching and transformation from @@ -45,7 +60,7 @@ will use terms from PyTorch in our examples: * World size - The world size is the number of GPUs that we are loading. For example, if we have 2 servers and each server has 4 GPUs then the world size is 8. * Rank - When loading data we will create a process for each GPU. Each process will have its own rank. This is - an integer in the range `[0, world_size]`. This is important for data loading because each rank should get its + an integer in the range `[0, world_size)`. This is important for data loading because each rank should get its own portion of the data (e.g. 
rank 3 and rank 4 will see different rows). * Global batch size - The global batch size is the number of rows, across _all_ GPUs, that we process in each step of the SGD algorithm. For example, if we have 8 GPUs and the global batch size is 1024 then we need to load @@ -69,6 +84,17 @@ next batch of data. The `prefetch_batches` parameter controls how many batches should typically be at least 2. A larger value can provide more buffering against jittery workloads but will require more RAM. + +```py Python icon=Python +ds = StreamingDataset( + table, + shuffle_seed=42, + read_batch_size=256, # rows fetched per LanceDB call per split + prefetch_batches=8, # batches to keep in flight per split +) +``` + + ### Transformation Many model training workloads require a transformation step between loading the data and training the model. For @@ -84,6 +110,21 @@ data loader expects. For example, the default PyTorch dataloader's collation fu sample types, with a python dictionary being one of the most common. The default transformation function converts the Arrow record batch into an iterable of python dictionaries without doing any processing of the data itself. + +```py Python icon=Python +import pyarrow as pa + +def normalize(batch: pa.RecordBatch) -> list[dict]: + """Scale pixel values from [0, 255] to [0.0, 1.0].""" + rows = batch.to_pylist() + for row in rows: + row["image"] = [v / 255.0 for v in row["image"]] + return rows + +ds = StreamingDataset(table, shuffle_seed=42, transform=normalize) +``` + + #### Worker Info The thread-based transformation model that `StreamingDataset` uses by default is only effective when the transform @@ -110,6 +151,32 @@ This can often be done by batching the compute work. If both the `prefetch_queu consistently zero then you are bottlenecked by I/O. A larger read batch size or clumped shuffling could help to reduce the I/O bottleneck. 
+ +```py Python icon=Python +import threading, time + +ds = StreamingDataset(table, shuffle_seed=42) + +def log_pipeline_health(): + while True: + print( + f"unscanned={ds.unscanned_rows} " + f"raw={ds.raw_queue_depth} " + f"cooked={ds.prefetch_queue_depth} " + f"consumed={ds.consumed_rows}" + ) + time.sleep(1.0) + +monitor = threading.Thread(target=log_pipeline_health, daemon=True) +monitor.start() + +for sample in ds: + train_step(sample) + +print(f"fetch time: {ds.fetch_time:.2f}s transform time: {ds.transform_time:.2f}s") +``` + + ### Filtering data By default the streaming dataset will load all rows and all columns. LanceDB is a columnar database that also @@ -122,6 +189,17 @@ You can use the `columns` parameter to specify which columns to load. You can u specify which rows to load. Filtered rows are not loaded from storage. LanceDB will first calculate the row ids that match the filter, then divide the matching row ids into splits for loading. + +```py Python icon=Python +ds = StreamingDataset( + table, + shuffle_seed=42, + columns=["image", "label"], # skip all other columns + filter="split = 'train'", # only training rows +) +``` + + ### Shuffling rows By default, the streaming dataset will access the data in the order the data is stored in the table. This can cause our @@ -134,6 +212,25 @@ will ensure that each epoch has a different ordering. If you wish for all epoch you can set the `epoch` parameter to `0` (it is only used to determine the shuffle seed). If you do not want to provide a `shuffle_seed` then you can set it to `None` and a random seed will be used instead. 
+ +```py Python icon=Python +# Training loop: each epoch gets a different shuffled ordering +for epoch in range(num_epochs): + ds = StreamingDataset( + table, + shuffle_seed=42, + epoch=epoch, # changes the permutation each epoch + ) + for sample in ds: + train_step(sample) + +# Evaluation: deterministic sequential order, no shuffle +eval_ds = StreamingDataset(table, shuffle=False) +for sample in eval_ds: + eval_step(sample) +``` + + Shuffling can have significant impacts on I/O performance, especially if you are loading data from cloud storage. In many cases the GPU pipeline is slow enough that this penalty will not be noticeable. However, @@ -142,6 +239,18 @@ shuffled together). This will give some penalty to the randomness of the shuffl I/O performance. + +```py Python icon=Python +# Clumped shuffle: groups of 16 contiguous rows are shuffled together, +# preserving read locality while still randomising the global ordering. +ds = StreamingDataset( + table, + shuffle_seed=42, + shuffle_clump_size=16, +) +``` + + ### Data splits and elasticity The data is divided across a number of processes based on the world size and the number of workers per rank. @@ -166,6 +275,30 @@ both 8 and 6. For example, we could choose `num_splits=24`. Good choices for ` numbers like 48 (allows for 1, 2, 3, 4, 6, 8, 12, 16, 24, and 48 GPUs) and 60 (allows for 1, 2, 3, 4, 5, 6, 10, 12, 15, 20, 30, and 60 GPUs). + +```py Python icon=Python +import torch.distributed as dist + +dist.init_process_group("nccl") +rank = dist.get_rank() +world_size = dist.get_world_size() + +# num_splits=48 is divisible by 1, 2, 3, 4, 6, 8, 12, 16, 24, 48 +# so this dataset works unchanged as you scale up or down GPUs. 
+ds = StreamingDataset( + table, + num_splits=48, + shuffle_seed=42, + epoch=current_epoch, + rank=rank, + world_size=world_size, +) + +for sample in ds: + train_step(sample) +``` + + ### Checkpointing and resumability Model training is an expensive process and failures can often occur partway through. Checkpointing allows you to save @@ -175,6 +308,33 @@ To support this the streaming dataset provides the `state_dict` and `load_state_ load the state of the data loader. These methods should be called by your training framework when you want to save or load a checkpoint. The `state_dict` method returns a simple python dictionary that can easily be persisted. + +```py Python icon=Python +import torch + +ds = StreamingDataset(table, num_splits=48, shuffle_seed=42, rank=rank, world_size=world_size) + +for step, sample in enumerate(ds): + train_step(sample) + + if step % checkpoint_interval == 0: + torch.save( + {"model": model.state_dict(), "dataloader": ds.state_dict()}, + f"checkpoint_{step}.pt", + ) + +# --- resuming after a crash --- +checkpoint = torch.load("checkpoint_100.pt") +model.load_state_dict(checkpoint["model"]) + +ds = StreamingDataset(table, num_splits=48, shuffle_seed=42, rank=rank, world_size=world_size) +ds.load_state_dict(checkpoint["dataloader"]) + +for sample in ds: # continues from step 100, no repeated or skipped rows + train_step(sample) +``` + + ## Permutations In some more complicated scenarios, you may want the flexibility of the `StreamingDataset` to shuffle, split, and select From 571b5b2a15f395736c76441da04442e8f831c190 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Thu, 2 Jul 2026 08:02:01 -0700 Subject: [PATCH 4/4] Minor example updates --- docs/training/index.mdx | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/training/index.mdx b/docs/training/index.mdx index a0ecafe4..1ff35813 100644 --- a/docs/training/index.mdx +++ b/docs/training/index.mdx @@ -115,7 +115,9 @@ the Arrow record batch into an iterable 
of python dictionaries without doing any import pyarrow as pa def normalize(batch: pa.RecordBatch) -> list[dict]: - """Scale pixel values from [0, 255] to [0.0, 1.0].""" + # This pure-Python loop holds the GIL and is shown for illustration only. + # In practice, prefer a library like torchvision or numpy that releases the + # GIL so the ThreadPoolExecutor can run transforms in parallel. rows = batch.to_pylist() for row in rows: row["image"] = [v / 255.0 for v in row["image"]] @@ -194,8 +196,8 @@ ids that match the filter, then divide the matching row ids into splits for load ds = StreamingDataset( table, shuffle_seed=42, - columns=["image", "label"], # skip all other columns - filter="split = 'train'", # only training rows + columns=["image", "label"], # skip all other columns + filter="category = 'train'", # only training rows ) ```
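
The "Data splits and elasticity" text in this series states that the number of splits must be a multiple of the world size. A minimal sketch for sanity-checking a candidate `num_splits` follows. It uses only the Python standard library, and the helper names `supported_world_sizes` and `smallest_num_splits` are hypothetical illustrations, not part of the LanceDB API.

```python
import math


def supported_world_sizes(num_splits: int) -> list[int]:
    """Return every world size that divides num_splits evenly (hypothetical helper)."""
    if num_splits < 1:
        raise ValueError("num_splits must be a positive integer")
    return [n for n in range(1, num_splits + 1) if num_splits % n == 0]


def smallest_num_splits(*world_sizes: int) -> int:
    """Return the least common multiple of the world sizes we plan to train with."""
    return math.lcm(*world_sizes)


if __name__ == "__main__":
    # Smallest split count that works for both 8 and 6 GPUs.
    print(smallest_num_splits(8, 6))  # 24
    # World sizes that a given num_splits can serve without changing the batches.
    print(supported_world_sizes(48))  # [1, 2, 3, 4, 6, 8, 12, 16, 24, 48]
    print(supported_world_sizes(60))  # [1, 2, 3, 4, 5, 6, 10, 12, 15, 20, 30, 60]
```

Running a check like this before fixing `num_splits` for a long training campaign is cheap, and it makes the set of GPU counts that preserve the global batch contents explicit. `math.lcm` requires Python 3.9 or newer.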