Commit 6574f52

addressing coderabbit review
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 1d35234 commit 6574f52

25 files changed: 383 additions & 216 deletions

bionemo-recipes/models/amplify/src/amplify/state.py

Lines changed: 8 additions & 3 deletions
@@ -67,8 +67,8 @@ def apply_transforms(
     source: Union[nn.Module, _ModelState],
     target: TargetModuleT,
     mapping: Dict[str, str],
-    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = [],
-    state_dict_ignored_entries: List = [],
+    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = None,
+    state_dict_ignored_entries: Optional[List] = None,
     cast_dtype: Optional[torch.dtype] = None,
 ) -> TargetModuleT:
     """Transform the state dictionary of a source module to match the structure of a target module's state dictionary.
@@ -126,6 +126,11 @@ def scale_weights(ctx):
     This function is particularly useful when adapting models from different frameworks or
     when consolidating models with different architectural changes.
     """
+    if transforms is None:
+        transforms = []
+    if state_dict_ignored_entries is None:
+        state_dict_ignored_entries = []
+
     # Track dtypes to make sure they weren't modified during conversion.
     target_orig_dtypes = extract_dtypes(target.named_parameters())

@@ -318,7 +323,7 @@ def __call__(self, ctx: TransformCTX) -> TransformCTX:
         try:
             source_match = source_matches[target_index]
         except IndexError as e:
-            logger.error(f"Enountered IndexError during transform.\n{source_matches=}\n{target_matches=}")
+            logger.error(f"Encountered IndexError during transform.\n{source_matches=}\n{target_matches=}")
             raise e
         if accepts_var_args:
             source_values = [source_dict[k] for k in source_match]
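
The signature change above replaces mutable list defaults with `None` plus in-function initialization. A minimal sketch of the pitfall being avoided, using illustrative names rather than the repository's API:

```python
# Minimal sketch of the mutable-default pitfall fixed above; `add_transform` and
# its arguments are illustrative names, not the repository's API.
def add_transform_buggy(transform, transforms=[]):
    # The default list is created once at function definition time, so every call
    # without an explicit `transforms` argument mutates the same shared object.
    transforms.append(transform)
    return transforms


def add_transform_fixed(transform, transforms=None):
    # `None` sentinel: a fresh list is created on each call unless one is passed in.
    if transforms is None:
        transforms = []
    transforms.append(transform)
    return transforms


assert add_transform_buggy("a") == ["a"]
assert add_transform_buggy("b") == ["a", "b"]  # state leaked from the first call
assert add_transform_fixed("a") == ["a"]
assert add_transform_fixed("b") == ["b"]
```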

bionemo-recipes/models/esm2/src/esm/collator.py

Lines changed: 60 additions & 17 deletions
@@ -154,8 +154,11 @@ def __call__(self, features, return_tensors=None):
         sequence processing capabilities. When pad_to_multiple_of is used, an additional
         mock sequence is appended to reach the desired total length.
         """
+        if return_tensors is not None and return_tensors != "pt":
+            raise NotImplementedError(f"Only return_tensors='pt' is supported, got '{return_tensors}'")
+
         # Perform the masking with the BSHD collator.
-        bshd_batch = self.collator(features)
+        bshd_batch = self.collator(features, return_tensors=return_tensors)

         # Create the flattened batch to get the cu_seq_lens_q and cu_seq_lens_k values.
         packed_batch = _pt_flatten_collate(features, return_position_ids=self.return_position_ids)
@@ -247,29 +250,68 @@ def __iter__(self):
         samples = []
         current_length = 0
         for sample in iter(self.dataset):
-            current_length += len(sample["input_ids"])
+            sample_length = len(sample["input_ids"])
+            current_length += sample_length
+
             if current_length == self.max_tokens_per_batch:
                 yield [*samples, sample]
                 samples = []
                 current_length = 0

             elif current_length > self.max_tokens_per_batch:
-                if not self.split_samples:
-                    # If we are not splitting samples, we can just yield the current batch (before this sample) and
-                    # start a new one.
-                    yield samples
-                    samples = [sample]
+                tokens_available = self.max_tokens_per_batch - (current_length - sample_length)
+
+                if tokens_available <= 0:
+                    # Current batch is already full (or over); yield it first, then handle this sample.
+                    if samples:
+                        yield samples
+                    samples = []
+                    current_length = sample_length
+                    tokens_available = self.max_tokens_per_batch
+
+                    # Now handle the incoming sample with a fresh batch.
+                    if sample_length == self.max_tokens_per_batch:
+                        yield [sample]
+                        samples = []
+                        current_length = 0
+                        continue
+                    elif sample_length < self.max_tokens_per_batch:
+                        samples = [sample]
+                        continue
+                    # sample_length > max_tokens_per_batch: fall through to split logic below

+                if not self.split_samples:
+                    # Yield the current batch (before this sample) and start a new one with this sample.
+                    if samples:
+                        yield samples
+                    # The sample itself may exceed max_tokens_per_batch; yield it as its own batch.
+                    if sample_length > self.max_tokens_per_batch:
+                        yield [sample]
+                        samples = []
+                        current_length = 0
+                    else:
+                        samples = [sample]
+                        current_length = sample_length
                 else:
-                    # Calculate how many tokens are already in the batch
-                    tokens_in_batch = current_length - len(sample["input_ids"])
-                    # Calculate how many tokens we can fit from this sample
-                    tokens_available = self.max_tokens_per_batch - tokens_in_batch
-                    first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available)
-                    yield [*samples, first_part]
-                    samples = [remaining_part]
-
-                    current_length = len(samples[0]["input_ids"])
+                    # Split mode: fill the current batch, then split remaining into chunks.
+                    if tokens_available > 0 and tokens_available < sample_length:
+                        first_part, remaining = _split_sample_by_num_tokens(sample, tokens_available)
+                        yield [*samples, first_part]
+                    else:
+                        # tokens_available >= sample_length shouldn't happen here, but guard anyway
+                        if samples:
+                            yield samples
+                        remaining = sample
+
+                    # Now split the remaining part into chunks of max_tokens_per_batch.
+                    while len(remaining["input_ids"]) > self.max_tokens_per_batch:
+                        chunk, remaining = _split_sample_by_num_tokens(remaining, self.max_tokens_per_batch)
+                        yield [chunk]
+
+                    samples = [remaining]
+                    current_length = len(remaining["input_ids"])
+                    continue
+
             else:
                 samples.append(sample)

@@ -345,7 +387,8 @@ def __call__(self, features) -> list[dict[str, Any]]:
         else:
             raise ValueError(f"Unsupported qvk_format: {self.qkv_format}!")

-        batch_shard["max_length_k"] = batch_shard["max_length_q"] = max_length * round(max_length / 64)
+        padded_max = ((max_length + 63) // 64) * 64
+        batch_shard["max_length_k"] = batch_shard["max_length_q"] = padded_max
        combined_batch.append(batch_shard)

        return combined_batch
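
The last hunk replaces `max_length * round(max_length / 64)`, which does not produce a multiple of 64, with the usual ceil-to-multiple formula. A small sketch of that arithmetic; `pad_to_multiple` is an illustrative helper name, not part of the collator:

```python
# Sketch of the corrected rounding: pad a sequence length up to the next multiple of 64.
def pad_to_multiple(length: int, multiple: int = 64) -> int:
    return ((length + multiple - 1) // multiple) * multiple


assert pad_to_multiple(64) == 64
assert pad_to_multiple(65) == 128
assert pad_to_multiple(130) == 192
# The old expression, max_length * round(max_length / 64), returns 65 for
# max_length=65 (65 * 1), which is not a multiple of 64.
```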

bionemo-recipes/models/esm2/src/esm/state.py

Lines changed: 8 additions & 3 deletions
@@ -67,8 +67,8 @@ def apply_transforms(
     source: Union[nn.Module, _ModelState],
     target: TargetModuleT,
     mapping: Dict[str, str],
-    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = [],
-    state_dict_ignored_entries: List = [],
+    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = None,
+    state_dict_ignored_entries: Optional[List] = None,
     cast_dtype: Optional[torch.dtype] = None,
 ) -> TargetModuleT:
     """Transform the state dictionary of a source module to match the structure of a target module's state dictionary.
@@ -126,6 +126,11 @@ def scale_weights(ctx):
     This function is particularly useful when adapting models from different frameworks or
     when consolidating models with different architectural changes.
     """
+    if transforms is None:
+        transforms = []
+    if state_dict_ignored_entries is None:
+        state_dict_ignored_entries = []
+
     # Track dtypes to make sure they weren't modified during conversion.
     target_orig_dtypes = extract_dtypes(target.named_parameters())

@@ -318,7 +323,7 @@ def __call__(self, ctx: TransformCTX) -> TransformCTX:
         try:
             source_match = source_matches[target_index]
         except IndexError as e:
-            logger.error(f"Enountered IndexError during transform.\n{source_matches=}\n{target_matches=}")
+            logger.error(f"Encountered IndexError during transform.\n{source_matches=}\n{target_matches=}")
             raise e
         if accepts_var_args:
             source_values = [source_dict[k] for k in source_match]

bionemo-recipes/models/esm2/tests/common/README.md

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ Shared test infrastructure for BioNeMo models. One base class, **BaseModelTest**
 
 ## Structure
 
-```
+```text
 tests/common/
 ├── __init__.py               # Public API exports
 ├── test_modeling_common.py   # BaseModelTest, TestTolerances

bionemo-recipes/models/esm2/tests/common/__init__.py

Lines changed: 0 additions & 15 deletions
@@ -13,21 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
 """Common test utilities for BioNeMo models.
 
 This package provides reusable test infrastructure following HuggingFace

bionemo-recipes/models/esm2/tests/common/fixtures.py

Lines changed: 3 additions & 18 deletions
@@ -13,21 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
 """Shared test fixtures for BioNeMo models."""
 
 import os
@@ -63,7 +48,7 @@ def use_te_debug():
 
     os.environ["NVTE_DEBUG"] = "1"
     yield
-    del os.environ["NVTE_DEBUG"]
+    os.environ.pop("NVTE_DEBUG", None)
 
 
 ALL_RECIPES = [
@@ -138,6 +123,6 @@ def te_attn_backend(request):
 
     yield request.param
 
-    del os.environ["NVTE_FUSED_ATTN"]
-    del os.environ["NVTE_FLASH_ATTN"]
+    os.environ.pop("NVTE_FUSED_ATTN", None)
+    os.environ.pop("NVTE_FLASH_ATTN", None)
     _attention_backends["backend_selection_requires_update"] = True
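
The fixture teardowns now use `os.environ.pop(key, None)` instead of `del`, so cleanup cannot raise `KeyError` if the variable was already unset. A minimal illustrative fixture (not the repo's exact code) showing the pattern:

```python
import os

import pytest


@pytest.fixture
def nvte_debug_env():
    """Illustrative fixture with idempotent teardown."""
    os.environ["NVTE_DEBUG"] = "1"
    yield
    # `del os.environ["NVTE_DEBUG"]` raises KeyError if the variable was already
    # removed (e.g. by the test itself); pop() with a default never does.
    os.environ.pop("NVTE_DEBUG", None)
```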

bionemo-recipes/models/esm2/tests/common/test_modeling_common.py

Lines changed: 8 additions & 4 deletions
@@ -31,9 +31,12 @@
 from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, set_seed
 
 
-HAS_DATA_CENTER_GPU = any(
-    gpu_name in torch.cuda.get_device_name(0).upper() for gpu_name in ["H100", "H200", "B100", "B200", "B300"]
-)
+try:
+    HAS_DATA_CENTER_GPU = torch.cuda.is_available() and any(
+        gpu_name in torch.cuda.get_device_name(0).upper() for gpu_name in ["H100", "H200", "B100", "B200", "B300"]
+    )
+except (RuntimeError, AssertionError):
+    HAS_DATA_CENTER_GPU = False
 
 
 @dataclass
@@ -343,13 +346,14 @@ def get_reference_model(
         model.to("cuda")
         return model
 
-    def get_reference_model_no_weights(self) -> PreTrainedModel:
+    def get_reference_model_no_weights(self, **kwargs) -> PreTrainedModel:
         """Load the reference HuggingFace model with random weights."""
         return self.get_upstream_model_class()(
             AutoConfig.from_pretrained(
                 self.get_upstream_model_id(),
                 dtype=torch.float32,
                 revision=self.get_upstream_model_revision(),
+                **kwargs,
             )
         )
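
Two hardening changes here: GPU detection is short-circuited with `torch.cuda.is_available()` and wrapped in `try/except` so importing the module on a CPU-only machine cannot fail, and `get_reference_model_no_weights` now forwards `**kwargs` to `AutoConfig.from_pretrained` as config overrides. The detection logic as it reads after the change, with comments on the failure mode it guards against:

```python
import torch

# Data-center parts the tests key off of, as listed in the diff above.
DATA_CENTER_GPU_NAMES = ["H100", "H200", "B100", "B200", "B300"]

try:
    # get_device_name(0) raises on machines without a usable CUDA device, so the
    # availability check short-circuits and the try/except catches the rest.
    HAS_DATA_CENTER_GPU = torch.cuda.is_available() and any(
        name in torch.cuda.get_device_name(0).upper() for name in DATA_CENTER_GPU_NAMES
    )
except (RuntimeError, AssertionError):
    HAS_DATA_CENTER_GPU = False
```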

bionemo-recipes/models/llama3/collator.py

Lines changed: 60 additions & 17 deletions
@@ -154,8 +154,11 @@ def __call__(self, features, return_tensors=None):
         sequence processing capabilities. When pad_to_multiple_of is used, an additional
         mock sequence is appended to reach the desired total length.
         """
+        if return_tensors is not None and return_tensors != "pt":
+            raise NotImplementedError(f"Only return_tensors='pt' is supported, got '{return_tensors}'")
+
         # Perform the masking with the BSHD collator.
-        bshd_batch = self.collator(features)
+        bshd_batch = self.collator(features, return_tensors=return_tensors)

         # Create the flattened batch to get the cu_seq_lens_q and cu_seq_lens_k values.
         packed_batch = _pt_flatten_collate(features, return_position_ids=self.return_position_ids)
@@ -247,29 +250,68 @@ def __iter__(self):
         samples = []
         current_length = 0
         for sample in iter(self.dataset):
-            current_length += len(sample["input_ids"])
+            sample_length = len(sample["input_ids"])
+            current_length += sample_length
+
             if current_length == self.max_tokens_per_batch:
                 yield [*samples, sample]
                 samples = []
                 current_length = 0

             elif current_length > self.max_tokens_per_batch:
-                if not self.split_samples:
-                    # If we are not splitting samples, we can just yield the current batch (before this sample) and
-                    # start a new one.
-                    yield samples
-                    samples = [sample]
+                tokens_available = self.max_tokens_per_batch - (current_length - sample_length)
+
+                if tokens_available <= 0:
+                    # Current batch is already full (or over); yield it first, then handle this sample.
+                    if samples:
+                        yield samples
+                    samples = []
+                    current_length = sample_length
+                    tokens_available = self.max_tokens_per_batch
+
+                    # Now handle the incoming sample with a fresh batch.
+                    if sample_length == self.max_tokens_per_batch:
+                        yield [sample]
+                        samples = []
+                        current_length = 0
+                        continue
+                    elif sample_length < self.max_tokens_per_batch:
+                        samples = [sample]
+                        continue
+                    # sample_length > max_tokens_per_batch: fall through to split logic below

+                if not self.split_samples:
+                    # Yield the current batch (before this sample) and start a new one with this sample.
+                    if samples:
+                        yield samples
+                    # The sample itself may exceed max_tokens_per_batch; yield it as its own batch.
+                    if sample_length > self.max_tokens_per_batch:
+                        yield [sample]
+                        samples = []
+                        current_length = 0
+                    else:
+                        samples = [sample]
+                        current_length = sample_length
                 else:
-                    # Calculate how many tokens are already in the batch
-                    tokens_in_batch = current_length - len(sample["input_ids"])
-                    # Calculate how many tokens we can fit from this sample
-                    tokens_available = self.max_tokens_per_batch - tokens_in_batch
-                    first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available)
-                    yield [*samples, first_part]
-                    samples = [remaining_part]
-
-                    current_length = len(samples[0]["input_ids"])
+                    # Split mode: fill the current batch, then split remaining into chunks.
+                    if tokens_available > 0 and tokens_available < sample_length:
+                        first_part, remaining = _split_sample_by_num_tokens(sample, tokens_available)
+                        yield [*samples, first_part]
+                    else:
+                        # tokens_available >= sample_length shouldn't happen here, but guard anyway
+                        if samples:
+                            yield samples
+                        remaining = sample
+
+                    # Now split the remaining part into chunks of max_tokens_per_batch.
+                    while len(remaining["input_ids"]) > self.max_tokens_per_batch:
+                        chunk, remaining = _split_sample_by_num_tokens(remaining, self.max_tokens_per_batch)
+                        yield [chunk]
+
+                    samples = [remaining]
+                    current_length = len(remaining["input_ids"])
+                    continue
+
             else:
                 samples.append(sample)

@@ -345,7 +387,8 @@ def __call__(self, features) -> list[dict[str, Any]]:
         else:
             raise ValueError(f"Unsupported qvk_format: {self.qkv_format}!")

-        batch_shard["max_length_k"] = batch_shard["max_length_q"] = max_length * round(max_length / 64)
+        padded_max = ((max_length + 63) // 64) * 64
+        batch_shard["max_length_k"] = batch_shard["max_length_q"] = padded_max
        combined_batch.append(batch_shard)

        return combined_batch
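
This file receives the same collator and batch-sampler changes as `esm2/src/esm/collator.py` above. As a behavioral summary of the reworked `__iter__` when `split_samples=False`, here is a simplified, self-contained sketch operating on token counts only (not the recipe's class): batches never exceed `max_tokens_per_batch`, a partially filled batch is flushed before an oversized sample, and a sample longer than the limit is emitted as its own batch.

```python
def batch_by_tokens(lengths, max_tokens_per_batch):
    """Group token counts into batches whose totals never exceed the budget."""
    batches, current, current_length = [], [], 0
    for length in lengths:
        if current_length + length <= max_tokens_per_batch:
            current.append(length)
            current_length += length
            continue
        # The incoming sample does not fit: flush the partial batch first.
        if current:
            batches.append(current)
        if length >= max_tokens_per_batch:
            # An oversized (or exactly full) sample becomes its own batch.
            batches.append([length])
            current, current_length = [], 0
        else:
            current, current_length = [length], length
    if current:
        batches.append(current)
    return batches


assert batch_by_tokens([4, 4, 10, 3], max_tokens_per_batch=8) == [[4, 4], [10], [3]]
```

In split mode the real sampler instead splits the oversized sample with `_split_sample_by_num_tokens` and carries the remainder into the next batch, as shown in the diff.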
