@@ -154,8 +154,11 @@ def __call__(self, features, return_tensors=None):
         sequence processing capabilities. When pad_to_multiple_of is used, an additional
         mock sequence is appended to reach the desired total length.
         """
+        if return_tensors is not None and return_tensors != "pt":
+            raise NotImplementedError(f"Only return_tensors='pt' is supported, got '{return_tensors}'")
+
         # Perform the masking with the BSHD collator.
-        bshd_batch = self.collator(features)
+        bshd_batch = self.collator(features, return_tensors=return_tensors)
 
         # Create the flattened batch to get the cu_seq_lens_q and cu_seq_lens_k values.
         packed_batch = _pt_flatten_collate(features, return_position_ids=self.return_position_ids)
@@ -247,29 +250,68 @@ def __iter__(self):
         samples = []
         current_length = 0
         for sample in iter(self.dataset):
-            current_length += len(sample["input_ids"])
+            sample_length = len(sample["input_ids"])
+            current_length += sample_length
+
             if current_length == self.max_tokens_per_batch:
                 yield [*samples, sample]
                 samples = []
                 current_length = 0
 
             elif current_length > self.max_tokens_per_batch:
-                if not self.split_samples:
-                    # If we are not splitting samples, we can just yield the current batch (before this sample) and
-                    # start a new one.
-                    yield samples
-                    samples = [sample]
+                tokens_available = self.max_tokens_per_batch - (current_length - sample_length)
+
+                if tokens_available <= 0:
+                    # Current batch is already full (or over); yield it first, then handle this sample.
+                    if samples:
+                        yield samples
+                    samples = []
+                    current_length = sample_length
+                    tokens_available = self.max_tokens_per_batch
+
+                    # Now handle the incoming sample with a fresh batch.
+                    if sample_length == self.max_tokens_per_batch:
+                        yield [sample]
+                        samples = []
+                        current_length = 0
+                        continue
+                    elif sample_length < self.max_tokens_per_batch:
+                        samples = [sample]
+                        continue
+                    # sample_length > max_tokens_per_batch: fall through to split logic below
 
+                if not self.split_samples:
+                    # Yield the current batch (before this sample) and start a new one with this sample.
+                    if samples:
+                        yield samples
+                    # The sample itself may exceed max_tokens_per_batch; yield it as its own batch.
+                    if sample_length > self.max_tokens_per_batch:
+                        yield [sample]
+                        samples = []
+                        current_length = 0
+                    else:
+                        samples = [sample]
+                        current_length = sample_length
                 else:
-                    # Calculate how many tokens are already in the batch
-                    tokens_in_batch = current_length - len(sample["input_ids"])
-                    # Calculate how many tokens we can fit from this sample
-                    tokens_available = self.max_tokens_per_batch - tokens_in_batch
-                    first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available)
-                    yield [*samples, first_part]
-                    samples = [remaining_part]
-
-                    current_length = len(samples[0]["input_ids"])
+                    # Split mode: fill the current batch, then split remaining into chunks.
+                    if tokens_available > 0 and tokens_available < sample_length:
+                        first_part, remaining = _split_sample_by_num_tokens(sample, tokens_available)
+                        yield [*samples, first_part]
+                    else:
+                        # tokens_available >= sample_length shouldn't happen here, but guard anyway
+                        if samples:
+                            yield samples
+                        remaining = sample
+
+                    # Now split the remaining part into chunks of max_tokens_per_batch.
+                    while len(remaining["input_ids"]) > self.max_tokens_per_batch:
+                        chunk, remaining = _split_sample_by_num_tokens(remaining, self.max_tokens_per_batch)
+                        yield [chunk]
+
+                    samples = [remaining]
+                    current_length = len(remaining["input_ids"])
+                    continue
+
             else:
                 samples.append(sample)
 
@@ -345,7 +387,8 @@ def __call__(self, features) -> list[dict[str, Any]]:
             else:
                 raise ValueError(f"Unsupported qvk_format: {self.qkv_format}!")
 
-            batch_shard["max_length_k"] = batch_shard["max_length_q"] = max_length * round(max_length / 64)
+            padded_max = ((max_length + 63) // 64) * 64
+            batch_shard["max_length_k"] = batch_shard["max_length_q"] = padded_max
             combined_batch.append(batch_shard)
 
         return combined_batch