diff --git a/docs/training/index.mdx b/docs/training/index.mdx index 01ac5e6..1ff3581 100644 --- a/docs/training/index.mdx +++ b/docs/training/index.mdx @@ -5,11 +5,13 @@ description: "Introduction to loading data, shuffling, and creating permutations icon: boxes-stacked --- -LanceDB makes an excellent data backend for training machine learning models. This section will describe the -`Permutation` API in LanceDB that's designed to facilitate the -process of training a model and explain how to use LanceDB as a data backend for training. +LanceDB makes an excellent data backend for training machine learning models. While a `Table` by itself can be treated +as input to a data loader this is typcially limited. A `Permutation` can be created to control which rows are accessed +and in what order. For an even more complete solution LanceDB also provides a `StreamingDataset` which adapts the lower +level `Permutation` API into a simple iterable dataset supporting prefetching, elastic determinism, resumability, +and multi-threaded transformations. -## Data loading +## Basic Data loading Most model training frameworks iterate through data in batches and feed this data into the model. This process is often referred to as **data loading**. The simplest way to load data into a model is to iterate a LanceDB table in @@ -29,116 +31,317 @@ for batch in table: In practice, this is too simplistic for effective training. We may not want to load all the data, or we may want to load the data in a different order, or we may need to apply some sort of processing to the data before training. -To achieve this, we will often want to create a `Permutation` of the table. +To achieve this, we can use the `StreamingDataset`. -## Selecting rows + +```py Python icon=Python +from lancedb.streaming import StreamingDataset +import lancedb -When training a model, we might not want to load all of the data. For example, we might filter out columns that -are not needed for training. We might also divide the data into training and validation sets. Or we could divide -the data into multiple sets for cross-validation. +db = lancedb.connect("file://some/db/path") +table = db.open_table("some_table") -Whenever we create a permutation of the table, we have to first decide which rows we want to include (and in what -order). This is stored in a **permutation table** which marks out the row ids that make up our data. Other decisions, -such as which columns to include, and what transformations to apply, can be defined at read time and don't require -a separate permutation table. +ds = StreamingDataset(table, shuffle_seed=42) +for sample in ds: + # sample is a plain Python dict, e.g. {"feature": 0.82, "label": "cat"} + train_step(sample) +``` + - -Permutation tables are tables, just like any other table in LanceDB. By default, they are -stored in memory but they can be persisted to storage as well. This is useful when you want to share a permutation -table across processes or nodes. - +## Advanced Data Loading + +The `StreamingDataset` wraps a LanceDB `Table` and, by default, simply adds prefetching and transformation from +Arrow format to Python. However, it can be configured to handle more advanced scenarios. To help understand we +will consider a model trained with stochastic gradient descent (SGD) and distributed data parallelism (DDP). In +this example we need to load multiple GPUs, across multiple servers, with batches of data. After each batch is +processed the GPUs exchange weights and the next batch is loaded. This introduces a number of concepts and we +will use terms from PyTorch in our examples: + + * World size - The world size is the number of GPUs that we are loading. For example, if we have 2 servers and + each server has 4 GPUs then the world size is 8. + * Rank - When loading data we will create a process for each GPU. Each process will have its own rank. This is + an integer in the range `[0, world_size)`. This is important for data loading because each rank should get its + own portion of the data (e.g. rank 3 and rank 4 will see different rows). + * Global batch size - The global batch size is the number of rows, across _all_ GPUs, that we process in each step + of the SGD algorithm. For example, if we have 8 GPUs and the global batch size is 1024 then we need to load + 128 rows onto each GPU for each step. The global batch size must be divisible by the world size. + * Batch size - The batch size is the number of rows, for a single GPU, that we process in each step of the SGD + algorithm. Once again, if we have 8 GPUs and the global batch size is 1024 then the batch size is 128. + +Other concepts (read batch size, num_workers) will be discussed later but are specific to a particular section. + +### Prefetching + +PyTorch datasets were originally built around in-memory structures like a Pandas DataFrame. When they are iterated +they yield a single sample at a time. This makes sense for simple in-memory structure but if try and access a +(potentially remote) database one row at a time the per-call overhead will typically be far too expensive. To work +around this the `StreamingDataset` fetches data and transforms data in batches. The `read_batch_size` parameter +controls how many rows we read per call to the underlying `Table`. + +In addition to batching up requests to the database the prefetching mechanism will read ahead in the background. +While the first batch is being transformed and processed by the GPU a `StreamingDataset` will also be reading the +next batch of data. The `prefetch_batches` parameter controls how many batches of data we will read ahead. This +should typically be at least 2. A larger value can provide more buffering against jittery workloads but will +require more RAM. -## Selecting all rows + +```py Python icon=Python +ds = StreamingDataset( + table, + shuffle_seed=42, + read_batch_size=256, # rows fetched per LanceDB call per split + prefetch_batches=8, # batches to keep in flight per split +) +``` + -To select all rows, we can use the `Permutation.identity` method. This gives us a `Permutation` without requiring -us to create a separate permutation table. This allows us to refine our columns and apply transformations and can -be useful when the data loader itself is responsible for handling sampling and shuffling. +### Transformation - -```py Python icon=Python -from lancedb.permutation import Permutation +Many model training workloads require a transformation step between loading the data and training the model. For +example, we may need to decode images, tokenize text, or normalize data. A transformation function can be provided +using the `transform` parameter. Transformation can be expensive and we often want to utilize multiple CPUs to apply +these transformations. By default transformations will be applied with a `ThreadPoolExecutor` with a number of workers +equal to the number of CPUs. -# We can create an identity permutation without needing any separate permutation table. -permutation = Permutation.identity(table) +Transformations are applied on batches of data, not individual samples, to allow transformations to amortize per-batch +overhead. A transformation function will receive an Arrow record batch and should return an iterable of samples (one +sample per row). The `StreamingDataset` does not care what format these samples take but they should match what your +data loader expects. For example, the default PyTorch dataloader's collation function can except a variety of different +sample types, with a python dictionary being one of the most common. The default transformation function converts +the Arrow record batch into an iterable of python dictionaries without doing any processing of the data itself. -# This allows us to refine our columns and apply transformations -permutation = permutation.select_columns(["id", "prompt"]) + +```py Python icon=Python +import pyarrow as pa + +def normalize(batch: pa.RecordBatch) -> list[dict]: + # This pure-Python loop holds the GIL and is shown for illustration only. + # In practice, prefer a library like torchvision or numpy that releases the + # GIL so the ThreadPoolExecutor can run transforms in parallel. + rows = batch.to_pylist() + for row in rows: + row["image"] = [v / 255.0 for v in row["image"]] + return rows + +ds = StreamingDataset(table, shuffle_seed=42, transform=normalize) ``` -## Filtering rows +#### Worker Info + +The thread-based transformation model that `StreamingDataset` uses by default is only effective when the transform +function releases the GIL. This is true for most Python scientific libraries (e.g. numpy, pandas, arrow, torchvision) +but there are some libraries which may not do this. Because of this limitation PyTorch supports launching multiple +worker processes per rank (the num_workers variable in the data loader). The `StreamingDataset` can handle this +scenario and will call `get_worker_info` to determine the worker id and the total number of workers and will adjust +accordingly. However, we find this multiprocessing to be inefficient (adds pickling and transfer overhead as well +as significantly increasing the amount of RAM required) and suggest starting with `num_workers=1` and only using +a higher value when you've confirmed an unavoidable GIL bottleneck. -If we only want to select a subset of rows, then we can use a filter. This will require us to create a permutation -table which identifies which rows we want to include. +### Observability & performance + +Optimizing data loader performance is tricky because it can be difficult to locate the bottleneck. What is often +blamed on I/O ends up being a CPU bottleneck in the transform stage (or vice versa). To assist developers the +`StreamingDataset` offers a number of observability controls. The `raw_queue_depth` can be polled on a regular +basis to determine the number of rows that have been loaded (I/O finished) but not trasnformed. The +`prefetch_queue_depth` can be polled to determine the number of rows that have been transformed and are waiting +to be consumed by the GPU. As long as these queue sizes are non-empty the GPU should be operating at capacity. + +If the `prefetch_queue_depth` is consistently zero but the `raw_queue_depth` is not then you have a CPU +transformation bottleneck. You should investigate GIL bottlenecks or look for ways to optimize your transformation. +This can often be done by batching the compute work. If both the `prefetch_queue_depth` and `raw_queue_depth` are +consistently zero then you are bottlenecked by I/O. A larger read batch size or clumped shuffling could help to +reduce the I/O bottleneck. -```py Python icon=Python -from lancedb.permutation import Permutation, permutation_builder +```py Python icon=Python +import threading, time -# We can create a permutation table which identifies which rows we want to include. -permutation_tbl = permutation_builder(table).filter("category = 'cat'").execute() +ds = StreamingDataset(table, shuffle_seed=42) -# We can then use this permutation table to create a Permutation object -permutation = Permutation.from_tables(table, permutation_tbl) +def log_pipeline_health(): + while True: + print( + f"unscanned={ds.unscanned_rows} " + f"raw={ds.raw_queue_depth} " + f"cooked={ds.prefetch_queue_depth} " + f"consumed={ds.consumed_rows}" + ) + time.sleep(1.0) + +monitor = threading.Thread(target=log_pipeline_health, daemon=True) +monitor.start() + +for sample in ds: + train_step(sample) + +print(f"fetch time: {ds.fetch_time:.2f}s transform time: {ds.transform_time:.2f}s") ``` -## Creating splits +### Filtering data -LanceDB also provides several different methods for creating splits. These allow us to divide our dataset into -smaller non-overlapping sets. The split can then be specified when creating the `Permutation` object to view -only a subset of the data. +By default the streaming dataset will load all rows and all columns. LanceDB is a columnar database that also +supports efficient random access. Reducing the number of columns you load will have a direct impact on I/O +performance. Reducing the number of rows you load will also have an impact on I/O performance, especially if +you have a very selective filter, are loading large values, or the data is local or in the LanceDB enterprise +cache. + +You can use the `columns` parameter to specify which columns to load. You can use the `filter` parameter to +specify which rows to load. Filtered rows are not loaded from storage. LanceDB will first calculate the row +ids that match the filter, then divide the matching row ids into splits for loading. -```py Python icon=Python -from lancedb.permutation import Permutation, permutation_builder +```py Python icon=Python +ds = StreamingDataset( + table, + shuffle_seed=42, + columns=["image", "label"], # skip all other columns + filter="category = 'train'", # only training rows +) +``` + + +### Shuffling rows -# Here we create two splits, one for training and one for validation. By default, splits have no -# name and are accessed by index. -permutation_tbl = permutation_builder(table).split_random(ratios=[0.95, 0.05]).execute() +By default, the streaming dataset will access the data in the order the data is stored in the table. This can cause our +model to learn artifacts specific to the order of the data. This is one of many ways we can "overfit" our model +to our data. To avoid this, we typically want to shuffle the data before training. This is done by setting the +`shuffle` parameter to `True`. If this is not set then the data will be divided into splits sequentially. -# Let's create a permutation object which views only the training data. -permutation = Permutation.from_tables(table, permutation_tbl, split=0) +By default, the shuffle seed will be a combination of the provided `shuffle_seed` and the provided `epoch` which +will ensure that each epoch has a different ordering. If you wish for all epochs to have the same ordering then +you can set the `epoch` parameter to `0` (it is only used to determine the shuffle seed). If you do not want +to provide a `shuffle_seed` then you can set it to `None` and a random seed will be used instead. -# Splits can also be given names. The names can then be used later to access the split instead of -# requiring us to know the index. -permutation_tbl = permutation_builder(table).split_random(ratios=[0.95, 0.05], split_names=["train", "test"]).execute() -permutation = Permutation.from_tables(table, permutation_tbl, split="train") + +```py Python icon=Python +# Training loop: each epoch gets a different shuffled ordering +for epoch in range(num_epochs): + ds = StreamingDataset( + table, + shuffle_seed=42, + epoch=epoch, # changes the permutation each epoch + ) + for sample in ds: + train_step(sample) + +# Evaluation: deterministic sequential order, no shuffle +eval_ds = StreamingDataset(table, shuffle=False) +for sample in eval_ds: + eval_step(sample) ``` -## Shuffling rows - -By default, permutations will access the data in the order the data is stored in the table. This can cause our -model to learn artifacts specific to the order of the data. This is one of many ways we can "overfit" our model -to our data. To avoid this, we typically want to shuffle the data before training. Model training frameworks -(like PyTorch) will often provide a way to shuffle the data. If you are not using one of these frameworks, or if -you want to shuffle the data with LanceDB, you can shuffle the rows when you create a permutation table. + +Shuffling can have significant impacts on I/O performance, especially if you are loading data from cloud storage. +In many cases the GPU pipeline is slow enough that this penalty will not be noticeable. However, +you can use the `shuffle_clump_size` parameter to shuffle the data in clumps (small contiguous batches that get +shuffled together). This will give some penalty to the randomness of the shuffle, but will significantly improve +I/O performance. + -```py Python icon=Python -from lancedb.permutation import Permutation, permutation_builder +```py Python icon=Python +# Clumped shuffle: groups of 16 contiguous rows are shuffled together, +# preserving read locality while still randomising the global ordering. +ds = StreamingDataset( + table, + shuffle_seed=42, + shuffle_clump_size=16, +) +``` + -# We can shuffle the rows when we create the permutation table. -permutation_tbl = permutation_builder(table).shuffle().execute() +### Data splits and elasticity + +The data is divided across a number of processes based on the world size and the number of workers per rank. +These groups are called "splits" and by default the dataset will create `world_size * num_workers` splits. +This is simple but can lead to problems if you need to rerun the training with a different number of GPUs (i.e. +different world size). For example, if we have 8 GPUs and 1 worker per rank then we have 8 splits. If we +later train with 4 GPUs then we will only have 4 splits and the data will be divided differently. This could +lead to a different model being trained which can make deterministic training harder to reproduce. + +To work around this, you can manually specify the number of splits used. This allows you to select a larger +number of splits than the default and can provide you with a property called "elastic determinism". If we +consider our example above, if we are going to train on 4 GPUs we can set `num_splits=8` and we will divide +the data as if we had 8 GPUs. Each rank will be assigned 2 splits and will pull from those two splits in +a round-robin fashion. This means each global batch that gets generated will be the same as the global batches +that were generated when we trained with 8 GPUs. + +In order for this determinism to work, the number of splits must be a multiple of the world size. Our simple +example above works because 8 is a multiple of 4. However, what would happen if we trained with 8 GPUs and then +wanted to train with 6 GPUs. In that case, `num_splits=8` would not be a multiple of 6 and the data would not +be divided evenly across the ranks. To make this work we can choose a `num_splits` value that is a multiple of +both 8 and 6. For example, we could choose `num_splits=24`. Good choices for `num_splits` are highly composite +numbers like 48 (allows for 1, 2, 3, 4, 6, 8, 12, 16, 24, and 48 GPUs) and 60 (allows for 1, 2, 3, 4, 5, 6, 10, +12, 15, 20, 30, 40, and 60 GPUs). -# We can then use this permutation table to create a Permutation object, this will now -# access the data in a random order. -permutation = Permutation.from_tables(table, permutation_tbl) + +```py Python icon=Python +import torch.distributed as dist + +dist.init_process_group("nccl") +rank = dist.get_rank() +world_size = dist.get_world_size() + +# num_splits=48 is divisible by 1, 2, 3, 4, 6, 8, 12, 16, 24, 48 +# so this dataset works unchanged as you scale up or down GPUs. +ds = StreamingDataset( + table, + num_splits=48, + shuffle_seed=42, + epoch=current_epoch, + rank=rank, + world_size=world_size, +) + +for sample in ds: + train_step(sample) ``` -## Selecting columns +### Checkpointing and resumability -By default, permutations will return all columns in the table. If you only need a subset of the columns, you can -significantly reduce your I/O requirements by selecting only the columns you need. This can be done on the -permutation object itself, and does not require us to create a separate permutation table. +Model training is an expensive process and failures can often occur partway through. Checkpointing allows you to save +the model state and resume training from where you left off. Most modern deep learning frameworks support checkpointing +models. However, we must also be able to checkpoint the data loader so that we can resume training from where we left off. +To support this the streaming dataset provides the `state_dict` and `load_state_dict` methods so that you can save and +load the state of the data loader. These methods should be called by your training framework when you want to save or +load a checkpoint. The `state_dict` method returns a simple python dictionary that can easily be persisted. -```py Python icon=Python -from lancedb.permutation import Permutation +```py Python icon=Python +import torch + +ds = StreamingDataset(table, num_splits=48, shuffle_seed=42, rank=rank, world_size=world_size) + +for step, sample in enumerate(ds): + train_step(sample) + + if step % checkpoint_interval == 0: + torch.save( + {"model": model.state_dict(), "dataloader": ds.state_dict()}, + f"checkpoint_{step}.pt", + ) -# We can select only the columns we need. -permutation = Permutation.identity(table).select_columns(["id", "prompt"]) +# --- resuming after a crash --- +checkpoint = torch.load("checkpoint_100.pt") +model.load_state_dict(checkpoint["model"]) + +ds = StreamingDataset(table, num_splits=48, shuffle_seed=42, rank=rank, world_size=world_size) +ds.load_state_dict(checkpoint["dataloader"]) + +for sample in ds: # continues from step 100, no repeated or skipped rows + train_step(sample) ``` - \ No newline at end of file + + +## Permutations + +In some more complicated scenarios, you may want to flexibility of the `StreamingDataset` to shuffle, split, and select +data while not buying into the full behavior of the iterable dataset. In these cases you can use the `Permutation` +class. This is a lower level class which the `StreamingDataset` is built on top of. A `Permutation` can be used to +define a custom ordering of the data. You can then index into the `Permutation` using the `__getitems__` method to +access rows by their ordering in the `Permutation`. More details on the `Permutation` class can be found in the API +reference.