
Commit 9b86370

Merge branch 'hillst/block-sampling3' of /home/skothenhill/bionemo-framework-fresh/bionemo-framework/. into hillst/block-sampling3
Signed-off-by: Steven <skothenhill@nvidia.com>
2 parents 2eacb12 + 2fce592 commit 9b86370

5 files changed

Lines changed: 185 additions & 25 deletions


sub-packages/bionemo-core/src/bionemo/core/data/multi_epoch_dataset.py

Lines changed: 9 additions & 2 deletions
@@ -17,7 +17,7 @@
 import math
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Generic, NamedTuple, Protocol, Sequence, TypeVar
+from typing import Any, Generic, NamedTuple, Protocol, Sequence, TypeVar

 import numpy as np
 from torch.utils.data import Dataset
@@ -130,10 +130,17 @@ def __post_init__(self):

     def __getitem__(self, index: int) -> T_co:
         """Get the sample at the given index."""
-        if index not in range(len(self)):
+        if index < 0 or index >= len(self):
             raise IndexError(f"Index {index} out of bounds for dataset of length {len(self)}.")
         return self.dataset[self._global_index_to_permuted_local_index(index)]

+    def __getitems__(self, indices: list[int]) -> Any:
+        """Get the samples at the given indices."""
+        if hasattr(self.dataset, '__getitems__'):
+            return self.dataset.__getitems__([self[i] for i in indices])
+        else:
+            return [self[i] for i in indices]
+
     def __len__(self) -> int:
         """Return the length of the resampled dataset."""
         return self.num_samples  # type: ignore
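
The __getitems__ addition mirrors the optional batched-fetch protocol that recent PyTorch DataLoader fetchers probe for with hasattr(dataset, "__getitems__"): when the hook exists, a worker can request a whole list of indices in one call instead of looping over __getitem__. A minimal sketch of that protocol with a toy map-style dataset (the class, sizes, and values below are illustrative, not part of this commit):

from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    """Toy map-style dataset supporting both scalar and batched lookups."""

    def __init__(self, n: int = 32):
        self.data = list(range(n))

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> int:
        return self.data[index]

    def __getitems__(self, indices: list[int]) -> list[int]:
        # One call per batch instead of one Python call per sample.
        return [self.data[i] for i in indices]

loader = DataLoader(ToyDataset(), batch_size=4)
print(next(iter(loader)))  # tensor([0, 1, 2, 3])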

sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/block_sampling.py

Lines changed: 6 additions & 6 deletions
@@ -145,8 +145,11 @@ def __getitems__(self, indices: List[int]) -> Any:
         _sorted_order = np.argsort(shuffled_ids)
         _sorted_idxs = np.sort(shuffled_ids)

-        # Sort for I/O locality as we use blocked fetches.
-        sorted_data = self.dataset[_sorted_idxs]
+        # Turn it back into a list so torch does the right things.
+        if hasattr(self.dataset, '__getitems__'):
+            sorted_data = self.dataset.__getitems__(_sorted_idxs.tolist())
+        else:
+            sorted_data = [self.dataset[idx] for idx in _sorted_idxs.tolist()]

         # Reverse the sorting to return the args in the original state.
         data = np.array(sorted_data)[np.argsort(_sorted_order)]
@@ -361,7 +364,6 @@ def __iter__(self):
                 # Other workers get the base number of fetches
                 start = worker_info.id * per_worker + remainder
                 end = start + per_worker
-
             fetches = fetches[start:end]

         if self.sort_before_fetch:
@@ -374,15 +376,14 @@ def __iter__(self):
             if self.fetch_callback is not None:
                 data = self.fetch_callback(self.collection, fetch_ids)
             else:
-                data = self.collection[fetch_ids]
+                data = list(self.collection[i] for i in fetch_ids)

             if not isinstance(data, np.ndarray):
                 data = np.array(data)

             # Call fetch transform if provided
             if self.fetch_transform is not None:
                 data = self.fetch_transform(data)
-
             if self.shuffle_before_yield:
                 # Shuffle the indices
                 if bionemo_permute:
@@ -408,7 +409,6 @@ def __iter__(self):
                     # Call batch transform if provided
                     if self.batch_transform is not None:
                         batch_data = self.batch_transform(batch_data)
-
                     yield batch_data

             else:  # Not shuffling indices before fetching
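
The rewritten __getitems__ keeps the existing sort-then-unsort pattern: requested ids are sorted so the underlying storage sees a mostly sequential read, and np.argsort(_sorted_order), the inverse permutation, restores the caller's original order afterwards. A small self-contained numpy check of that round trip, with names matching the diff and a made-up fetch:

import numpy as np

shuffled_ids = np.array([7, 2, 9, 4, 0])     # ids in the order the caller asked for
_sorted_order = np.argsort(shuffled_ids)     # permutation that sorts the ids
_sorted_idxs = np.sort(shuffled_ids)         # ids in storage order, for sequential reads

# Pretend fetch: the "sample" for id k is just 10 * k.
sorted_data = [10 * idx for idx in _sorted_idxs.tolist()]

# argsort of a permutation is its inverse, so this restores the caller's order.
data = np.array(sorted_data)[np.argsort(_sorted_order)]
assert (data == 10 * shuffled_ids).all()
print(data)  # [70 20 90 40  0]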

sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/datamodule.py

Lines changed: 34 additions & 6 deletions
@@ -19,6 +19,7 @@
 from typing import List, Literal, Optional, Sequence

 import numpy as np
+from bionemo.geneformer.data.block_sampling import MapStyleScDataset
 from nemo.lightning.data import WrappedDataLoader
 from nemo.lightning.pytorch.plugins import MegatronDataSampler
 from nemo.utils import logging
@@ -84,6 +85,8 @@ def __init__(  # noqa: D107
         persistent_workers: bool = True,
         pin_memory: bool = True,
         include_unrecognized_vocab_in_dataset: bool = False,
+        block_size: Optional[int] = None,
+        fetch_factor: Optional[int] = None,
     ) -> None:
         super().__init__()
         if predict_dataset_path is None:
@@ -111,10 +114,16 @@ def __init__(  # noqa: D107
         self.num_workers = num_workers
         self.persistent_workers = persistent_workers
         self.pin_memory = pin_memory
+        self.global_batch_size = global_batch_size
+        # Block sampling parameters
+        self.block_size = block_size
+        self.fetch_factor = fetch_factor
+        self.block_sampling = block_size and fetch_factor

         rng = np.random.default_rng(seed)
         if self.data_path_train is not None:
             assert self.data_path_val is not None and self.data_path_test is not None
+
             self._train_dataset_ori = SingleCellDataset(
                 self.data_path_train,
                 self.tokenizer,
@@ -201,12 +210,31 @@ def setup(self, stage: str = "") -> None:  # noqa: D102
         num_train_samples = int(max_train_steps * self.data_sampler.global_batch_size)

         # This happens exactly once during setup.
-        self._train_ds = MultiEpochDatasetResampler(
-            self._train_dataset_ori,
-            num_samples=num_train_samples,
-            shuffle=True,
-            seed=self.seed,
-        )
+        if self.block_sampling:
+            # We also need associated block sampling parameters.
+
+            # dataset size must divide block size * batch_size
+            if num_train_samples % (self.global_batch_size * self.block_size) != 0:
+                # Warning
+                num_train_samples -= num_train_samples % (self.global_batch_size * self.block_size)
+
+
+            from bionemo.geneformer.data.block_sampling import MapStyleScDataset
+            from bionemo.core.data.multi_epoch_dataset import MultiEpochDatasetResampler
+
+            self._train_ds = MultiEpochDatasetResampler(
+                self._train_dataset_ori,
+                num_samples=num_train_samples,
+                shuffle=False,
+                seed=self.seed,
+            )
+            self._train_ds = MapStyleScDataset(
+                self._train_ds,
+                block_size=self.block_size,
+                batch_size=self.block_sampling,
+                fetch_factor=self.fetch_factor,
+                seed=self.seed * 2,
+            )
         if self.trainer.limit_val_batches == 0:  # disable validation
             logging.info("Skip creating validation dataset because trainer.limit_val_batches=0.")
         else:
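
When block sampling is enabled, setup() trims num_train_samples down to a multiple of global_batch_size * block_size so that every block-sized fetch group is full. A quick worked example of that arithmetic (the values are illustrative, not taken from any real configuration):

global_batch_size = 1024   # illustrative values, not from the commit
block_size = 64
num_train_samples = 10_000_000

group = global_batch_size * block_size              # 65_536 samples per full group
if num_train_samples % group != 0:
    num_train_samples -= num_train_samples % group  # drop the partial remainder

assert num_train_samples % group == 0
print(num_train_samples)  # 9_961_472 = 152 * 65_536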

sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py

Lines changed: 8 additions & 11 deletions
@@ -119,6 +119,8 @@ def __len__(self):  # noqa: D105

     def __getitem__(self, index: EpochIndex) -> types.BertSample:
         """Performs a lookup and the required transformation for the model."""
+        if not isinstance(index, EpochIndex):
+            index = EpochIndex(idx=index, epoch=0)
         rng = np.random.default_rng([self._seed, index.epoch, index.idx])
         values, feature_ids = self.scdl.get_row(index.idx, return_features=True, feature_vars=["feature_id"])
         assert (
@@ -145,7 +147,6 @@ def __getitem__(self, index: EpochIndex) -> types.BertSample:
             include_unrecognized_vocab_in_dataset=self.include_unrecognized_vocab_in_dataset,
         )

-
 def _gather_medians(
     gene_names: np.ndarray,
     gene_data: np.ndarray,
@@ -155,16 +156,12 @@ def _gather_medians(
     include_unrecognized_vocab_in_dataset: bool = False,
 ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
     """Filter out genes that are not in the provided tokenizer vocab, and tokenize the gene names."""
-    genes, tokens, medians = [], [], []
-    for tok, gene in zip(gene_names, gene_data):
-        if tok in vocab:
-            tokens.append(vocab[tok])
-            genes.append(gene)
-            if normalize:
-                med = gene_median[tok]  # If not in the dictionary we default to no normalization (1)
-                medians.append(med)
-        elif include_unrecognized_vocab_in_dataset:
-            raise ValueError(f"Provided gene identifier, {str(tok)}, is not in the tokenizer vocab.")
+    tok_genes = filter(
+        lambda x: x[0] is not None,
+        ((vocab.get(tok), gene, gene_median.get(tok, 1.0)) for tok, gene in zip(gene_names, gene_data))
+    )
+
+    tokens, genes, medians = zip(*tok_genes)
     return np.asarray(genes), np.asarray(tokens), np.asarray(medians)
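
The new _gather_medians body builds (token, gene, median) triples in a generator, drops entries whose gene name is missing from the vocab (vocab.get returns None), and unzips the survivors into three arrays; unknown medians default to 1.0, i.e. no normalization. A standalone sketch of the same pattern on toy inputs (the vocab, medians, and gene names below are made up):

import numpy as np

vocab = {"ENSG01": 5, "ENSG02": 6}   # toy tokenizer vocab: gene id -> token id
gene_median = {"ENSG01": 2.0}        # toy medians; missing genes default to 1.0

gene_names = np.array(["ENSG01", "ENSGXX", "ENSG02"])
gene_data = np.array([3.0, 9.0, 4.0])

tok_genes = filter(
    lambda x: x[0] is not None,      # drop genes that are not in the vocab
    ((vocab.get(tok), gene, gene_median.get(tok, 1.0)) for tok, gene in zip(gene_names, gene_data)),
)
tokens, genes, medians = zip(*tok_genes)  # note: raises ValueError if nothing survives the filter

print(np.asarray(genes), np.asarray(tokens), np.asarray(medians))
# [3. 4.] [5 6] [2. 1.]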

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@ (new file)
from pathlib import Path
from bionemo.core.data.load import load
from bionemo.geneformer.data.block_sampling import MapStyleScDataset, scDataset
from bionemo.geneformer.data.singlecell.dataset import SingleCellDataset
from torch.utils.data import DataLoader
from bionemo.core.data.multi_epoch_dataset import MultiEpochDatasetResampler
import time
import tqdm
import functools
from bionemo.llm.data import collate

from bionemo.geneformer.data.singlecell.preprocess import GeneformerPreprocess
from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer


def make_dataset():
    data_path: Path = load("single_cell/testdata-20241203") / "cellxgene_2023-12-15_small_processed_scdl" / "train"

    train_data_path = Path("/home/ubuntu/data/cellxgene_2023-12-15/train")

    preprocessor = GeneformerPreprocess(
        download_directory=train_data_path,
        medians_file_path=train_data_path / "medians.json",
        tokenizer_vocab_path=train_data_path / "geneformer.vocab",
    )
    match preprocessor.preprocess():
        case {"tokenizer": tokenizer, "median_dict": median_dict}:
            tokenizer, median_dict = tokenizer, median_dict
        case _:
            raise ValueError("Preprocessing must have failed.")

    dataset = SingleCellDataset(train_data_path, tokenizer=tokenizer, median_dict=median_dict, max_len=2048)
    print("done loading ds")
    return dataset


def get_configs():
    return [
        {
            "block_size": 64,
            "batch_size": 128 * 8,
            "fetch_factor": 8,
            "seed": 42,
        }
    ]


def mapstyle_throughput():
    dataset = make_dataset()
    tokenizer = dataset.tokenizer

    configs = get_configs()
    for config in configs:
        factor = config["fetch_factor"] * config["batch_size"]
        extra = len(dataset) % factor
        to_add = factor - extra
        num_samples = len(dataset) + to_add

        dataset = MultiEpochDatasetResampler(
            dataset,
            num_samples=num_samples,
            shuffle=False,
        )
        '''
        When we stack the datasets this way, a whole vector is passed into getitem for
        MultiEpochDatasetResampler
        '''
        dataset = MapStyleScDataset(dataset, **config)

        start = time.time()
        dataloader = DataLoader(
            dataset,
            batch_size=config["batch_size"],
            num_workers=16,
            collate_fn=functools.partial(
                collate.bert_padding_collate_fn,
                padding_value=tokenizer.token_to_id(GeneTokenizer.pad_token),
                min_length=2048,
                max_length=2048,
            ),
        )

        for i, batch in enumerate(tqdm.tqdm(dataloader)):
            if i > 100 * config["fetch_factor"]:
                break
            pass

        end = time.time()
        print(f"MapStyleScDataset: {end - start} seconds")
        print(f"MapStyleScDataset: {800 * config['batch_size'] / (end - start)} samples per second")


def iterstyle_throughput():
    dataset = make_dataset()
    tokenizer = dataset.tokenizer

    configs = get_configs()

    for config in configs:
        num_samples = (len(dataset) - (len(dataset) % (config["batch_size"] * config["block_size"]))) * 2
        dataset = MultiEpochDatasetResampler(
            dataset,
            num_samples=num_samples,
            shuffle=False,
        )
        # TODO get some intermediate metrics
        dataset = scDataset(dataset, bionemo_permute=False, **config)

        start = time.time()
        dataloader = DataLoader(
            dataset,
            batch_size=None,
            num_workers=16,
            shuffle=False,
            collate_fn=functools.partial(
                collate.bert_padding_collate_fn,
                padding_value=tokenizer.token_to_id(GeneTokenizer.pad_token),
                min_length=2048,
                max_length=2048,
            ),
        )

        # I think this just happens if its not an even multiple of the batch size
        try:
            for i, batch in enumerate(tqdm.tqdm(dataloader)):
                if i > 100 * config["fetch_factor"]:
                    break
                pass
        except RuntimeError as e:
            print(e)

        end = time.time()
        print(f"IterStyleDataset: {end - start} seconds")
        print(f"IterStyleDataset: {800 * config['batch_size'] / (end - start)} samples per second")


if __name__ == "__main__":
    mapstyle_throughput()
    iterstyle_throughput()
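
Both throughput loops break once i exceeds 100 * fetch_factor, so with fetch_factor=8 the prints assume roughly 800 batches were consumed. A small variant of the timing block that counts batches explicitly, so the samples-per-second figure stays correct if the config changes (timed_throughput is a hypothetical helper, not part of the commit):

import time

def timed_throughput(dataloader, batch_size: int, max_batches: int = 800) -> float:
    """Consume up to max_batches batches and return samples per second."""
    start = time.time()
    n_batches = 0
    for n_batches, _batch in enumerate(dataloader, start=1):
        if n_batches >= max_batches:
            break
    elapsed = time.time() - start
    return n_batches * batch_size / elapsed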
