diff --git a/docs/training/torch.mdx b/docs/training/torch.mdx index 9a9dbe5b..c04d2760 100644 --- a/docs/training/torch.mdx +++ b/docs/training/torch.mdx @@ -17,13 +17,14 @@ The `Table` class in LanceDB implements a contract for a PyTorch import lancedb import torch import pyarrow as pa +from lancedb.util import tbl_to_tensor mem_db = lancedb.connect("memory://") table = mem_db.create_table("test_table", pa.table({"a": range(1000)})) # Any LanceDB table can be used as a PyTorch Dataset dataloader = torch.utils.data.DataLoader( - table, batch_size=1024, shuffle=True + table, batch_size=1024, shuffle=True, collate_fn=tbl_to_tensor ) for batch in dataloader: @@ -42,12 +43,17 @@ dataloader = torch.utils.data.DataLoader(permutation) ## Output Formats -By default, a `Table` data loader will emit a `pyarrow.RecordBatch`. To convert to a different format (such as a -`pytorch.Tensor`), you will need to provide a custom collate function. +By default, a `Table` data loader will emit Arrow data. `collate_fn` is PyTorch's batching hook: PyTorch calls it to +turn the fetched items into one batch. PyTorch's default collate function only knows how to combine tensors, NumPy +arrays, numbers, dicts, and lists, so it does not accept Arrow data directly. When using a `Table` directly, pass +LanceDB's `lancedb.util.tbl_to_tensor` helper as PyTorch's `collate_fn`; it converts numeric Arrow columns into a +column-major `torch.Tensor` with shape `(columns, rows)`. -The `Permutation` class is more flexible. By default, the output will be a list of dicts. This is the default output -format of standard data loaders and usually more convenient when you are getting started. However, there is a -significant performance penalty converting from Arrow, Lance's internal representation, to this default format. +`Permutation` works differently: its default output is a list of Python dicts, which PyTorch's default collate function +can batch into a dict of tensors. This is usually more convenient when you are getting started. However, there is a +significant performance penalty converting from Arrow, Lance's internal representation, to this default format. Use a +direct `Table` with `collate_fn` when you want Arrow-to-tensor conversion, or a `Permutation` when you want the default +PyTorch dict-of-tensors behavior. To address this, the `Permutation` class provides a set of builtin transform functions that can be applied to map the Arrow data in different ways. The `arrow` and `polars` formats will always avoid data copies. However, `numpy`, @@ -96,3 +102,84 @@ dataloader = torch.utils.data.DataLoader( for batch in dataloader: print(batch.schema) ``` + +## Using multiple DataLoader workers + +Set `num_workers > 0` to read from LanceDB in multiple PyTorch worker processes. LanceDB tables and `Permutation` objects are picklable, so each worker reopens the table after it starts. + +Prefer the `spawn` start method when using multiple workers; LanceDB uses internal threads. See [the performance guide](/performance) for more multiprocessing guidance. + +```py Python icon=Python +import torch +from lancedb.permutation import Permutation + +permutation = Permutation.identity(table) +dataloader = torch.utils.data.DataLoader( + permutation, + batch_size=1024, + shuffle=True, + num_workers=4, + multiprocessing_context="spawn", + persistent_workers=True, +) +``` + +### Remote tables in DataLoader workers + +Remote LanceDB Enterprise tables (`db://...`) work the same way: workers reopen the table from the pickled connection state. + +```py Python icon=Python +import lancedb +import torch +from lancedb.util import tbl_to_tensor + +db = lancedb.connect( + "db://my-database", + api_key="sk-...", + region="us-east-1", +) +table = db.open_table("my_table") + +dataloader = torch.utils.data.DataLoader( + table, + batch_size=512, + num_workers=4, + multiprocessing_context="spawn", + collate_fn=tbl_to_tensor, +) +``` + + +This sends the connection state, including the API key, to each worker. Use a connection factory if credentials should be loaded inside the worker or your `client_config` contains a non-serializable `header_provider`. + + +### Providing a custom connection factory + +`Permutation.with_connection_factory` lets each worker reopen the base table with custom logic. The factory takes the table name, returns a LanceDB table, and must be picklable. + +```py Python icon=Python +import os +import lancedb +import torch +from lancedb.permutation import Permutation + +def open_table(name: str): + db = lancedb.connect( + "db://my-database", + api_key=os.environ["LANCEDB_API_KEY"], + region="us-east-1", + ) + return db.open_table(name) + +table = open_table("my_table") +permutation = ( + Permutation.identity(table) + .with_connection_factory(open_table) +) +dataloader = torch.utils.data.DataLoader( + permutation, + batch_size=512, + num_workers=4, + multiprocessing_context="spawn", +) +```