Skip to content

Commit e22b2c5

Browse files
committed
claude code review and fixes
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 6a60786 commit e22b2c5

31 files changed

Lines changed: 1481 additions & 806 deletions

File tree

CLAUDE.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ pre-commit run --all-files
4242
pre-commit run
4343
```
4444

45+
Do not copy license headers from other files; instead, allow the license-check.py script to add the license header during
46+
pre-commit to ensure the proper year is used.
47+
4548
Pre-commit includes:
4649

4750
- Ruff linting/formatting (line-length: 119, Google-style docstrings)

bionemo-recipes/models/amplify/src/amplify/state.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,21 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
"""State dict conversion utilities adapted from nemo.lightning.io.state."""
16+
"""State dict conversion utilities adapted from nemo.lightning.io.state.
17+
18+
This module provides the transform system used by convert.py to map state dicts between model formats:
19+
20+
- ``mapping``: A dict of simple key renames (source_key -> target_key). Each source key is copied directly
21+
to the corresponding target key with no modification to the tensor values.
22+
23+
- ``transforms``: A list of ``StateDictTransform`` objects for multi-key merges and splits. These handle
24+
cases where multiple source keys must be combined into one target key (e.g., merging Q/K/V into fused QKV),
25+
or one source key must be split into multiple target keys.
26+
27+
Important: When ``source_key`` is a tuple (many-to-one merge), the transform function's parameter names
28+
are used to map each source key to a function argument. This means ``*args`` style parameters do not work;
29+
each parameter must be explicitly named (e.g., ``def fn(q, k, v)`` not ``def fn(*args)``).
30+
"""
1731

1832
import inspect
1933
import logging

bionemo-recipes/models/esm2/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
115115

116116
To validate the converted models, refer to the commands in [Inference Examples](#inference-examples) above to load and test both the original and converted
117117
models to ensure loss and logit values are similar. Additionally, refer to the golden value tests in
118-
[test_modeling_esm_te.py](tests/test_modeling_esm_te.py) and [test_convert.py](tests/test_convert.py).
118+
[test_modeling_esm_te.py](tests/test_modeling_esm_te.py) and [test_export.py](tests/test_export.py).
119119

120120
## Developer Guide
121121

bionemo-recipes/models/esm2/collator.py

Lines changed: 149 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,144 @@ def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_id
730730
return batch
731731

732732

733+
def _find_seq_dim(tensor: torch.Tensor, seq_len: int) -> int:
734+
"""Find which dimension of tensor matches the expected sequence length.
735+
736+
Args:
737+
tensor: The tensor to inspect.
738+
seq_len: The expected sequence length to match against tensor dimensions.
739+
740+
Returns:
741+
The dimension index that matches the sequence length.
742+
743+
Raises:
744+
ValueError: If no dimension matches the expected sequence length.
745+
"""
746+
if tensor.ndim == 1:
747+
if tensor.shape[0] == seq_len:
748+
return 0
749+
raise ValueError(f"1D tensor shape {tensor.shape} doesn't match sequence length {seq_len}")
750+
elif tensor.ndim >= 2:
751+
if tensor.shape[1] == seq_len:
752+
return 1
753+
elif tensor.shape[0] == seq_len:
754+
return 0
755+
raise ValueError(f"Tensor shape {tensor.shape} doesn't match sequence length {seq_len} in dim 0 or 1")
756+
raise ValueError(f"Unexpected tensor ndim={tensor.ndim}")
757+
758+
759+
def _process_tensor_thd(
    val: torch.Tensor | None,
    seq_len: int,
    slice_sizes: torch.Tensor,
    cu_seqlens_padded: torch.Tensor,
    cp_rank: int,
    total_slices: int,
) -> torch.Tensor | None:
    """Select this CP rank's THD shard of a packed-sequence tensor.

    Implements the zigzag context-parallel pattern: for every packed sequence,
    rank ``cp_rank`` takes slice ``cp_rank`` from the front half and the
    mirrored slice ``total_slices - cp_rank - 1`` from the back half.

    Args:
        val: Tensor to shard; ``None`` is passed through unchanged.
        seq_len: Total packed sequence length (``cu_seqlens_padded[-1]``).
        slice_sizes: Per-sequence slice size (sequence length // total_slices).
        cu_seqlens_padded: Cumulative padded sequence lengths delimiting each
            sequence inside the packed tensor.
        cp_rank: This context-parallelism rank.
        total_slices: Number of slices per sequence (2 * cp_world_size).

    Returns:
        The shard belonging to ``cp_rank``, or ``None`` if ``val`` is ``None``.
    """
    if val is None:
        return val

    seq_dim = _find_seq_dim(val, seq_len)

    mirror = total_slices - cp_rank - 1
    segments = []
    for size, start in zip(slice_sizes, cu_seqlens_padded[:-1]):
        # Front slice owned by this rank.
        segments.append(
            torch.arange(start + cp_rank * size, start + (cp_rank + 1) * size, device=val.device)
        )
        # Mirrored back slice (zigzag pairing).
        segments.append(
            torch.arange(start + mirror * size, start + (mirror + 1) * size, device=val.device)
        )

    return val.index_select(seq_dim, torch.cat(segments))
809+
810+
811+
def _process_tensor_bshd(
812+
val: torch.Tensor | None,
813+
cp_rank: int,
814+
cp_world_size: int,
815+
) -> torch.Tensor | None:
816+
"""Extract the BSHD context-parallel shard for a single tensor.
817+
818+
Splits a BSHD-format tensor along the sequence dimension (dim=1) into 2*cp_world_size chunks,
819+
then selects the two chunks corresponding to the given CP rank (zigzag pattern).
820+
821+
Args:
822+
val: The tensor to shard, or None (returned as-is).
823+
cp_rank: The context parallelism rank index.
824+
cp_world_size: Total number of context parallelism ranks.
825+
826+
Returns:
827+
The sharded tensor for the given CP rank, or None if val is None.
828+
829+
Raises:
830+
ValueError: If the tensor has fewer than 2 dimensions or its sequence length
831+
is not divisible by 2 * cp_world_size.
832+
"""
833+
if val is None:
834+
return val
835+
836+
if val.ndim < 2:
837+
raise ValueError(f"BSHD format requires at least 2D tensors, got {val.ndim}D")
838+
839+
seq_len = val.shape[1]
840+
841+
# Calculate chunk size
842+
total_chunks = 2 * cp_world_size
843+
chunk_size = seq_len // total_chunks
844+
845+
if seq_len % total_chunks != 0:
846+
raise ValueError(
847+
f"Sequence length {seq_len} must be divisible by {total_chunks} "
848+
f"(2 * cp_world_size) for BSHD context parallelism"
849+
)
850+
851+
# Determine which chunks this rank should get
852+
# Rank 0 gets chunks [0, total_chunks-1]
853+
# Rank 1 gets chunks [1, total_chunks-2]
854+
# Rank k gets chunks [k, total_chunks-k-1]
855+
chunk_indices = [cp_rank, total_chunks - cp_rank - 1]
856+
857+
# Collect slices for this rank
858+
rank_slices = []
859+
for chunk_idx in chunk_indices:
860+
start_idx = chunk_idx * chunk_size
861+
end_idx = start_idx + chunk_size
862+
rank_slices.append(torch.arange(start_idx, end_idx, device=val.device))
863+
864+
# Concatenate indices for all chunks this rank should get
865+
indices = torch.cat(rank_slices)
866+
867+
# Select along sequence dimension (dim=1)
868+
return val.index_select(1, indices)
869+
870+
733871
def _pt_pad_to_multiple_of(batch: dict[str, Any], pad_to_multiple_of: int, token_pad: int, label_pad: int):
734872
"""Pad a batch to a multiple of pad_to_multiple_of.
735873
@@ -837,110 +975,20 @@ def _split_batch_by_cp_rank(
837975
total_slices_of_any_sequence = 2 * cp_world_size
838976
slice_sizes = (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]) // total_slices_of_any_sequence
839977

840-
# Process each tensor directly instead of using keys_to_change loop
841-
def process_tensor(val):
842-
if val is None:
843-
return val
844-
# Determine which dimension is the sequence dimension
845-
# Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor
846-
if isinstance(cu_seqlens_padded[-1], torch.Tensor):
847-
seq_len_val = cu_seqlens_padded[-1].item()
848-
else:
849-
seq_len_val = cu_seqlens_padded[-1]
850-
851-
# Handle 1D tensors (like position_ids that don't have batch dimension)
852-
if val.ndim == 1:
853-
if val.shape[0] == seq_len_val:
854-
current_seq_dim = 0
855-
else:
856-
raise ValueError(
857-
"1D tensor shape doesn't match expected sequence length. Make sure the"
858-
" inputs are in THD format and padded correctly."
859-
)
860-
elif val.ndim >= 2:
861-
if val.shape[1] == seq_len_val:
862-
current_seq_dim = 1
863-
elif val.shape[0] == seq_len_val:
864-
current_seq_dim = 0
865-
else:
866-
raise ValueError("Make sure the inputs are in THD format and padded correctly.")
867-
else:
868-
raise ValueError("Tensor must be at least 1D")
869-
870-
# On this particular rank, for each sequence, get two slices, one from the beginning
871-
# and one from the end.
872-
cp_rank_slices = []
873-
for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]):
874-
# 1st segment
875-
cp_rank_slices.append(
876-
torch.arange(
877-
seq_start + (cp_rank * slice_size),
878-
seq_start + ((cp_rank + 1) * slice_size),
879-
device=val.device,
880-
)
881-
)
978+
# Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor
979+
last_elem = cu_seqlens_padded[-1]
980+
seq_len_val = last_elem.item() if isinstance(last_elem, torch.Tensor) else last_elem
882981

883-
# 2nd segment
884-
cp_rank_slices.append(
885-
torch.arange(
886-
seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
887-
seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
888-
device=val.device,
889-
)
890-
)
891-
892-
return val.index_select(current_seq_dim, torch.cat(cp_rank_slices))
893-
894-
# Process each tensor directly
895-
input_ids_padded = process_tensor(input_ids_padded)
896-
labels_padded = process_tensor(labels_padded)
982+
input_ids_padded = _process_tensor_thd(
983+
input_ids_padded, seq_len_val, slice_sizes, cu_seqlens_padded, cp_rank, total_slices_of_any_sequence
984+
)
985+
labels_padded = _process_tensor_thd(
986+
labels_padded, seq_len_val, slice_sizes, cu_seqlens_padded, cp_rank, total_slices_of_any_sequence
987+
)
897988

898989
elif qvk_format == "bshd":
899-
# BSHD format: [batch, seq_len, ...]
900-
# Split along sequence dimension (dim=1)
901-
# Each sequence is split into 2*cp_world_size chunks
902-
# Each rank gets chunks at positions: [cp_rank, 2*cp_world_size - cp_rank - 1]
903-
904-
def process_tensor_bshd(val):
905-
if val is None:
906-
return val
907-
908-
if val.ndim < 2:
909-
raise ValueError(f"BSHD format requires at least 2D tensors, got {val.ndim}D")
910-
911-
seq_len = val.shape[1]
912-
913-
# Calculate chunk size
914-
total_chunks = 2 * cp_world_size
915-
chunk_size = seq_len // total_chunks
916-
917-
if chunk_size == 0:
918-
raise ValueError(
919-
f"Sequence length {seq_len} must be divisible by {total_chunks} "
920-
f"(2 * cp_world_size) for BSHD context parallelism"
921-
)
922-
923-
# Determine which chunks this rank should get
924-
# Rank 0 gets chunks [0, total_chunks-1]
925-
# Rank 1 gets chunks [1, total_chunks-2]
926-
# Rank k gets chunks [k, total_chunks-k-1]
927-
chunk_indices = [cp_rank, total_chunks - cp_rank - 1]
928-
929-
# Collect slices for this rank
930-
rank_slices = []
931-
for chunk_idx in chunk_indices:
932-
start_idx = chunk_idx * chunk_size
933-
end_idx = start_idx + chunk_size
934-
rank_slices.append(torch.arange(start_idx, end_idx, device=val.device))
935-
936-
# Concatenate indices for all chunks this rank should get
937-
indices = torch.cat(rank_slices)
938-
939-
# Select along sequence dimension (dim=1)
940-
return val.index_select(1, indices)
941-
942-
input_ids_padded = process_tensor_bshd(input_ids_padded)
943-
labels_padded = process_tensor_bshd(labels_padded)
990+
input_ids_padded = _process_tensor_bshd(input_ids_padded, cp_rank, cp_world_size)
991+
labels_padded = _process_tensor_bshd(labels_padded, cp_rank, cp_world_size)
944992

945993
else:
946994
raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!")

0 commit comments

Comments
 (0)