@@ -730,6 +730,144 @@ def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_id
     return batch
 
 
+def _find_seq_dim(tensor: torch.Tensor, seq_len: int) -> int:
+    """Find which dimension of tensor matches the expected sequence length.
+
+    Args:
+        tensor: The tensor to inspect.
+        seq_len: The expected sequence length to match against tensor dimensions.
+
+    Returns:
+        The dimension index that matches the sequence length.
+
+    Raises:
+        ValueError: If no dimension matches the expected sequence length.
+    """
+    if tensor.ndim == 1:
+        if tensor.shape[0] == seq_len:
+            return 0
+        raise ValueError(f"1D tensor shape {tensor.shape} doesn't match sequence length {seq_len}")
+    elif tensor.ndim >= 2:
+        if tensor.shape[1] == seq_len:
+            return 1
+        elif tensor.shape[0] == seq_len:
+            return 0
+        raise ValueError(f"Tensor shape {tensor.shape} doesn't match sequence length {seq_len} in dim 0 or 1")
+    raise ValueError(f"Unexpected tensor ndim={tensor.ndim}")
+
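+# Illustrative sketch (assuming a 16-token THD batch):
+#   _find_seq_dim(torch.zeros(16), seq_len=16)     # 1D, e.g. position_ids -> 0
+#   _find_seq_dim(torch.zeros(1, 16), seq_len=16)  # batched input_ids     -> 1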
+
+def _process_tensor_thd(
+    val: torch.Tensor | None,
+    seq_len: int,
+    slice_sizes: torch.Tensor,
+    cu_seqlens_padded: torch.Tensor,
+    cp_rank: int,
+    total_slices: int,
+) -> torch.Tensor | None:
+    """Extract the THD context-parallel shard for a single tensor.
+
+    For each sequence in the batch, selects two slices (one from the beginning and one
+    from the end) corresponding to the given CP rank, following the zigzag CP sharding
+    pattern.
+
+    Args:
+        val: The tensor to shard, or None (returned as-is).
+        seq_len: Total sequence length (from cu_seqlens_padded[-1]).
+        slice_sizes: Per-sequence slice sizes, computed as sequence_lengths // total_slices.
+        cu_seqlens_padded: Cumulative sequence lengths including padding.
+        cp_rank: The context parallelism rank index.
+        total_slices: Total number of slices per sequence (2 * cp_world_size).
+
+    Returns:
+        The sharded tensor for the given CP rank, or None if val is None.
+    """
+    if val is None:
+        return val
+
+    seq_dim = _find_seq_dim(val, seq_len)
+
+    cp_rank_slices = []
+    for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]):
+        # 1st segment
+        cp_rank_slices.append(
+            torch.arange(
+                seq_start + (cp_rank * slice_size),
+                seq_start + ((cp_rank + 1) * slice_size),
+                device=val.device,
+            )
+        )
+
+        # 2nd segment
+        cp_rank_slices.append(
+            torch.arange(
+                seq_start + ((total_slices - cp_rank - 1) * slice_size),
+                seq_start + ((total_slices - cp_rank) * slice_size),
+                device=val.device,
+            )
+        )
+
+    return val.index_select(seq_dim, torch.cat(cp_rank_slices))
+
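+# Worked example of the zigzag pattern (a minimal sketch, assuming cp_world_size=2,
+# so total_slices=4): one padded sequence of 8 tokens gives slice_size = 8 // 4 = 2,
+#   rank 0 keeps slices 0 and 3 -> token indices [0, 1, 6, 7]
+#   rank 1 keeps slices 1 and 2 -> token indices [2, 3, 4, 5]
+# so _process_tensor_thd(torch.arange(8), 8, torch.tensor([2]),
+#                        torch.tensor([0, 8]), cp_rank=0, total_slices=4)
+# returns tensor([0, 1, 6, 7]).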
+
+def _process_tensor_bshd(
+    val: torch.Tensor | None,
+    cp_rank: int,
+    cp_world_size: int,
+) -> torch.Tensor | None:
+    """Extract the BSHD context-parallel shard for a single tensor.
+
+    Splits a BSHD-format tensor along the sequence dimension (dim=1) into
+    2 * cp_world_size chunks, then selects the two chunks corresponding to the given
+    CP rank (zigzag pattern).
+
+    Args:
+        val: The tensor to shard, or None (returned as-is).
+        cp_rank: The context parallelism rank index.
+        cp_world_size: Total number of context parallelism ranks.
+
+    Returns:
+        The sharded tensor for the given CP rank, or None if val is None.
+
+    Raises:
+        ValueError: If the tensor has fewer than 2 dimensions or its sequence length
+            is not divisible by 2 * cp_world_size.
+    """
+    if val is None:
+        return val
+
+    if val.ndim < 2:
+        raise ValueError(f"BSHD format requires at least 2D tensors, got {val.ndim}D")
+
+    seq_len = val.shape[1]
+
+    # Calculate chunk size
+    total_chunks = 2 * cp_world_size
+    chunk_size = seq_len // total_chunks
+
+    if seq_len % total_chunks != 0:
+        raise ValueError(
+            f"Sequence length {seq_len} must be divisible by {total_chunks} "
+            f"(2 * cp_world_size) for BSHD context parallelism"
+        )
+
+    # Determine which chunks this rank should get
+    # Rank 0 gets chunks [0, total_chunks-1]
+    # Rank 1 gets chunks [1, total_chunks-2]
+    # Rank k gets chunks [k, total_chunks-k-1]
+    chunk_indices = [cp_rank, total_chunks - cp_rank - 1]
+
+    # Collect slices for this rank
+    rank_slices = []
+    for chunk_idx in chunk_indices:
+        start_idx = chunk_idx * chunk_size
+        end_idx = start_idx + chunk_size
+        rank_slices.append(torch.arange(start_idx, end_idx, device=val.device))
+
+    # Concatenate indices for all chunks this rank should get
+    indices = torch.cat(rank_slices)
+
+    # Select along sequence dimension (dim=1)
+    return val.index_select(1, indices)
+
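+# Worked example (a minimal sketch, assuming cp_world_size=2): for val of shape
+# (batch, 8, ...), total_chunks=4 and chunk_size=2, so
+#   rank 0 keeps sequence positions [0, 1, 6, 7]  (chunks 0 and 3)
+#   rank 1 keeps sequence positions [2, 3, 4, 5]  (chunks 1 and 2)
+# e.g. _process_tensor_bshd(torch.arange(8).view(1, 8), cp_rank=1, cp_world_size=2)
+# returns tensor([[2, 3, 4, 5]]).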
+
 def _pt_pad_to_multiple_of(batch: dict[str, Any], pad_to_multiple_of: int, token_pad: int, label_pad: int):
     """Pad a batch to a multiple of pad_to_multiple_of.
 
@@ -837,110 +975,20 @@ def _split_batch_by_cp_rank(
         total_slices_of_any_sequence = 2 * cp_world_size
         slice_sizes = (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]) // total_slices_of_any_sequence
 
-        # Process each tensor directly instead of using keys_to_change loop
-        def process_tensor(val):
-            if val is None:
-                return val
-            # Determine which dimension is the sequence dimension
-            # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor
-            if isinstance(cu_seqlens_padded[-1], torch.Tensor):
-                seq_len_val = cu_seqlens_padded[-1].item()
-            else:
-                seq_len_val = cu_seqlens_padded[-1]
-
-            # Handle 1D tensors (like position_ids that don't have batch dimension)
-            if val.ndim == 1:
-                if val.shape[0] == seq_len_val:
-                    current_seq_dim = 0
-                else:
-                    raise ValueError(
-                        "1D tensor shape doesn't match expected sequence length. Make sure the"
-                        " inputs are in THD format and padded correctly."
-                    )
-            elif val.ndim >= 2:
-                if val.shape[1] == seq_len_val:
-                    current_seq_dim = 1
-                elif val.shape[0] == seq_len_val:
-                    current_seq_dim = 0
-                else:
-                    raise ValueError("Make sure the inputs are in THD format and padded correctly.")
-            else:
-                raise ValueError("Tensor must be at least 1D")
-
-            # On this particular rank, for each sequence, get two slices, one from the beginning
-            # and one from the end.
-            cp_rank_slices = []
-            for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]):
-                # 1st segment
-                cp_rank_slices.append(
-                    torch.arange(
-                        seq_start + (cp_rank * slice_size),
-                        seq_start + ((cp_rank + 1) * slice_size),
-                        device=val.device,
-                    )
-                )
+        # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor
+        last_elem = cu_seqlens_padded[-1]
+        seq_len_val = last_elem.item() if isinstance(last_elem, torch.Tensor) else last_elem
 
-                # 2nd segment
-                cp_rank_slices.append(
-                    torch.arange(
-                        seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
-                        seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
-                        device=val.device,
-                    )
-                )
-
-            return val.index_select(current_seq_dim, torch.cat(cp_rank_slices))
-
-        # Process each tensor directly
-        input_ids_padded = process_tensor(input_ids_padded)
-        labels_padded = process_tensor(labels_padded)
+        input_ids_padded = _process_tensor_thd(
+            input_ids_padded, seq_len_val, slice_sizes, cu_seqlens_padded, cp_rank, total_slices_of_any_sequence
+        )
+        labels_padded = _process_tensor_thd(
+            labels_padded, seq_len_val, slice_sizes, cu_seqlens_padded, cp_rank, total_slices_of_any_sequence
+        )
 
     elif qvk_format == "bshd":
-        # BSHD format: [batch, seq_len, ...]
-        # Split along sequence dimension (dim=1)
-        # Each sequence is split into 2*cp_world_size chunks
-        # Each rank gets chunks at positions: [cp_rank, 2*cp_world_size - cp_rank - 1]
-
-        def process_tensor_bshd(val):
-            if val is None:
-                return val
-
-            if val.ndim < 2:
-                raise ValueError(f"BSHD format requires at least 2D tensors, got {val.ndim}D")
-
-            seq_len = val.shape[1]
-
-            # Calculate chunk size
-            total_chunks = 2 * cp_world_size
-            chunk_size = seq_len // total_chunks
-
-            if chunk_size == 0:
-                raise ValueError(
-                    f"Sequence length {seq_len} must be divisible by {total_chunks} "
-                    f"(2 * cp_world_size) for BSHD context parallelism"
-                )
-
-            # Determine which chunks this rank should get
-            # Rank 0 gets chunks [0, total_chunks-1]
-            # Rank 1 gets chunks [1, total_chunks-2]
-            # Rank k gets chunks [k, total_chunks-k-1]
-            chunk_indices = [cp_rank, total_chunks - cp_rank - 1]
-
-            # Collect slices for this rank
-            rank_slices = []
-            for chunk_idx in chunk_indices:
-                start_idx = chunk_idx * chunk_size
-                end_idx = start_idx + chunk_size
-                rank_slices.append(torch.arange(start_idx, end_idx, device=val.device))
-
-            # Concatenate indices for all chunks this rank should get
-            indices = torch.cat(rank_slices)
-
-            # Select along sequence dimension (dim=1)
-            return val.index_select(1, indices)
-
-        input_ids_padded = process_tensor_bshd(input_ids_padded)
-        labels_padded = process_tensor_bshd(labels_padded)
+        input_ids_padded = _process_tensor_bshd(input_ids_padded, cp_rank, cp_world_size)
+        labels_padded = _process_tensor_bshd(labels_padded, cp_rank, cp_world_size)
 
     else:
         raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!")