-
Notifications
You must be signed in to change notification settings - Fork 11
docs: document multi-worker DataLoader support for remote tables #260
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6824cd7
f449aa6
70291f2
55e040f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,13 +17,14 @@ The `Table` class in LanceDB implements a contract for a PyTorch | |
| import lancedb | ||
| import torch | ||
| import pyarrow as pa | ||
| from lancedb.util import tbl_to_tensor | ||
|
|
||
| mem_db = lancedb.connect("memory://") | ||
| table = mem_db.create_table("test_table", pa.table({"a": range(1000)})) | ||
|
|
||
| # Any LanceDB table can be used as a PyTorch Dataset | ||
| dataloader = torch.utils.data.DataLoader( | ||
| table, batch_size=1024, shuffle=True | ||
| table, batch_size=1024, shuffle=True, collate_fn=tbl_to_tensor | ||
| ) | ||
|
|
||
| for batch in dataloader: | ||
|
|
@@ -42,12 +43,17 @@ dataloader = torch.utils.data.DataLoader(permutation) | |
|
|
||
| ## Output Formats | ||
|
|
||
| By default, a `Table` data loader will emit a `pyarrow.RecordBatch`. To convert to a different format (such as a | ||
| `pytorch.Tensor`), you will need to provide a custom collate function. | ||
| By default, a `Table` data loader will emit Arrow data. `collate_fn` is PyTorch's batching hook: PyTorch calls it to | ||
| turn the fetched items into one batch. PyTorch's default collate function only knows how to combine tensors, NumPy | ||
| arrays, numbers, dicts, and lists, so it does not accept Arrow data directly. When using a `Table` directly, pass | ||
| LanceDB's `lancedb.util.tbl_to_tensor` helper as PyTorch's `collate_fn`; it converts numeric Arrow columns into a | ||
| column-major `torch.Tensor` with shape `(columns, rows)`. | ||
|
|
||
| The `Permutation` class is more flexible. By default, the output will be a list of dicts. This is the default output | ||
| format of standard data loaders and usually more convenient when you are getting started. However, there is a | ||
| significant performance penalty converting from Arrow, Lance's internal representation, to this default format. | ||
| `Permutation` works differently: its default output is a list of Python dicts, which PyTorch's default collate function | ||
| can batch into a dict of tensors. This is usually more convenient when you are getting started. However, there is a | ||
| significant performance penalty converting from Arrow, Lance's internal representation, to this default format. Use a | ||
| direct `Table` with `collate_fn` when you want Arrow-to-tensor conversion, or a `Permutation` when you want the default | ||
| PyTorch dict-of-tensors behavior. | ||
|
|
||
| To address this, the `Permutation` class provides a set of builtin transform functions that can be applied to map | ||
| the Arrow data in different ways. The `arrow` and `polars` formats will always avoid data copies. However, `numpy`, | ||
|
|
@@ -96,3 +102,84 @@ dataloader = torch.utils.data.DataLoader( | |
| for batch in dataloader: | ||
| print(batch.schema) | ||
| ``` | ||
|
|
||
| ## Using multiple DataLoader workers | ||
|
|
||
| Set `num_workers > 0` to read from LanceDB in multiple PyTorch worker processes. LanceDB tables and `Permutation` objects are picklable, so each worker reopens the table after it starts. | ||
|
|
||
| Prefer the `spawn` start method when using multiple workers; LanceDB uses internal threads. See [the performance guide](/performance) for more multiprocessing guidance. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually Once the streaming dataset is available my guidance would be to use |
||
|
|
||
| ```py Python icon=Python | ||
| import torch | ||
| from lancedb.permutation import Permutation | ||
|
|
||
| permutation = Permutation.identity(table) | ||
| dataloader = torch.utils.data.DataLoader( | ||
| permutation, | ||
| batch_size=1024, | ||
| shuffle=True, | ||
| num_workers=4, | ||
| multiprocessing_context="spawn", | ||
| persistent_workers=True, | ||
| ) | ||
| ``` | ||
|
|
||
| ### Remote tables in DataLoader workers | ||
|
|
||
| Remote LanceDB Enterprise tables (`db://...`) work the same way: workers reopen the table from the pickled connection state. | ||
|
|
||
| ```py Python icon=Python | ||
| import lancedb | ||
| import torch | ||
| from lancedb.util import tbl_to_tensor | ||
|
|
||
| db = lancedb.connect( | ||
| "db://my-database", | ||
| api_key="sk-...", | ||
| region="us-east-1", | ||
| ) | ||
| table = db.open_table("my_table") | ||
|
|
||
| dataloader = torch.utils.data.DataLoader( | ||
| table, | ||
| batch_size=512, | ||
| num_workers=4, | ||
| multiprocessing_context="spawn", | ||
| collate_fn=tbl_to_tensor, | ||
| ) | ||
| ``` | ||
|
|
||
| <Note> | ||
| This sends the connection state, including the API key, to each worker. Use a connection factory if credentials should be loaded inside the worker or your `client_config` contains a non-serializable `header_provider`. | ||
| </Note> | ||
|
|
||
| ### Providing a custom connection factory | ||
|
|
||
| `Permutation.with_connection_factory` lets each worker reopen the base table with custom logic. The factory takes the table name, returns a LanceDB table, and must be picklable. | ||
|
|
||
| ```py Python icon=Python | ||
| import os | ||
| import lancedb | ||
| import torch | ||
| from lancedb.permutation import Permutation | ||
|
|
||
| def open_table(name: str): | ||
| db = lancedb.connect( | ||
| "db://my-database", | ||
| api_key=os.environ["LANCEDB_API_KEY"], | ||
| region="us-east-1", | ||
| ) | ||
| return db.open_table(name) | ||
|
|
||
| table = open_table("my_table") | ||
| permutation = ( | ||
| Permutation.identity(table) | ||
| .with_connection_factory(open_table) | ||
| ) | ||
| dataloader = torch.utils.data.DataLoader( | ||
| permutation, | ||
| batch_size=512, | ||
| num_workers=4, | ||
| multiprocessing_context="spawn", | ||
| ) | ||
| ``` | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm,
Permutationcan also support direct Arrow-to-tensor conversion. It just isn't the default. This makes it sound like you'd have to use aTable.