Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 93 additions & 6 deletions docs/training/torch.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@ The `Table` class in LanceDB implements a contract for a PyTorch
import lancedb
import torch
import pyarrow as pa
from lancedb.util import tbl_to_tensor

mem_db = lancedb.connect("memory://")
table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))

# Any LanceDB table can be used as a PyTorch Dataset
dataloader = torch.utils.data.DataLoader(
table, batch_size=1024, shuffle=True
table, batch_size=1024, shuffle=True, collate_fn=tbl_to_tensor
)

for batch in dataloader:
Expand All @@ -42,12 +43,17 @@ dataloader = torch.utils.data.DataLoader(permutation)

## Output Formats

By default, a `Table` data loader will emit a `pyarrow.RecordBatch`. To convert to a different format (such as a
`pytorch.Tensor`), you will need to provide a custom collate function.
By default, a `Table` data loader will emit Arrow data. `collate_fn` is PyTorch's batching hook: PyTorch calls it to
turn the fetched items into one batch. PyTorch's default collate function only knows how to combine tensors, NumPy
arrays, numbers, dicts, and lists, so it does not accept Arrow data directly. When using a `Table` directly, pass
LanceDB's `lancedb.util.tbl_to_tensor` helper as PyTorch's `collate_fn`; it converts numeric Arrow columns into a
column-major `torch.Tensor` with shape `(columns, rows)`.

The `Permutation` class is more flexible. By default, the output will be a list of dicts. This is the default output
format of standard data loaders and usually more convenient when you are getting started. However, there is a
significant performance penalty converting from Arrow, Lance's internal representation, to this default format.
`Permutation` works differently: its default output is a list of Python dicts, which PyTorch's default collate function
can batch into a dict of tensors. This is usually more convenient when you are getting started. However, there is a
significant performance penalty converting from Arrow, Lance's internal representation, to this default format. Use a
direct `Table` with `collate_fn` when you want Arrow-to-tensor conversion, or a `Permutation` when you want the default
PyTorch dict-of-tensors behavior.
Comment on lines +54 to +56

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, Permutation can also support direct Arrow-to-tensor conversion. It just isn't the default. This makes it sound like you'd have to use a Table.


To address this, the `Permutation` class provides a set of builtin transform functions that can be applied to map
the Arrow data in different ways. The `arrow` and `polars` formats will always avoid data copies. However, `numpy`,
Expand Down Expand Up @@ -96,3 +102,84 @@ dataloader = torch.utils.data.DataLoader(
for batch in dataloader:
print(batch.schema)
```

## Using multiple DataLoader workers

Set `num_workers > 0` to read from LanceDB in multiple PyTorch worker processes. LanceDB tables and `Permutation` objects are picklable, so each worker reopens the table after it starts.

Prefer the `spawn` start method when using multiple workers; LanceDB uses internal threads. See [the performance guide](/performance) for more multiprocessing guidance.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually forkserver is probably better than spawn (and will be the new python default). I'd say that should be our preference.

Once the streaming dataset is available my guidance would be to use forkserver and use num_workers=1 unless you can prove you have GIL contention in your trasform function.


```py Python icon=Python
import torch
from lancedb.permutation import Permutation

permutation = Permutation.identity(table)
dataloader = torch.utils.data.DataLoader(
permutation,
batch_size=1024,
shuffle=True,
num_workers=4,
multiprocessing_context="spawn",
persistent_workers=True,
)
```

### Remote tables in DataLoader workers

Remote LanceDB Enterprise tables (`db://...`) work the same way: workers reopen the table from the pickled connection state.

```py Python icon=Python
import lancedb
import torch
from lancedb.util import tbl_to_tensor

db = lancedb.connect(
"db://my-database",
api_key="sk-...",
region="us-east-1",
)
table = db.open_table("my_table")

dataloader = torch.utils.data.DataLoader(
table,
batch_size=512,
num_workers=4,
multiprocessing_context="spawn",
collate_fn=tbl_to_tensor,
)
```

<Note>
This sends the connection state, including the API key, to each worker. Use a connection factory if credentials should be loaded inside the worker or your `client_config` contains a non-serializable `header_provider`.
</Note>

### Providing a custom connection factory

`Permutation.with_connection_factory` lets each worker reopen the base table with custom logic. The factory takes the table name, returns a LanceDB table, and must be picklable.

```py Python icon=Python
import os
import lancedb
import torch
from lancedb.permutation import Permutation

def open_table(name: str):
db = lancedb.connect(
"db://my-database",
api_key=os.environ["LANCEDB_API_KEY"],
region="us-east-1",
)
return db.open_table(name)

table = open_table("my_table")
permutation = (
Permutation.identity(table)
.with_connection_factory(open_table)
)
dataloader = torch.utils.data.DataLoader(
permutation,
batch_size=512,
num_workers=4,
multiprocessing_context="spawn",
)
```
Loading