Offline Distillation via DistillKit (Part Three - Offline Distillation) #1629
wolfecameron wants to merge 6 commits into
Code Review
This pull request introduces offline knowledge-distillation training, including a new offline_distill.py script, a specialized data collator for compressed teacher logits, and a suite of distillation loss functions. Review feedback identifies a high-severity issue where the learning rate scheduler's total steps are incorrectly calculated when using multiple processes, as well as flawed validation logic for mutually exclusive dataset arguments. Other recommendations include optimizing the O(N^2) re-tokenization in dataset transformations, using device-safe tensor initialization in the collator, and removing redundant gathering operations in the training loop.
    args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes
)
lr_scheduler = _create_scheduler(args, optimizer, num_training_steps_for_scheduler)
The calculation of num_training_steps_for_scheduler incorrectly multiplies max_train_steps by the number of processes when the user provides a fixed number of training steps. Since completed_steps tracks global optimization steps and the training loop runs for args.max_train_steps iterations, the scheduler should be initialized with args.max_train_steps to ensure the learning rate schedule aligns with the actual training duration. Multiplying by num_processes will cause the learning rate to decay much slower than intended.
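The effect described in this comment can be illustrated with a small, dependency-free sketch. It assumes a plain linear decay to zero and a scheduler that advances once per global optimization step, which is the premise of the comment; `linear_decay_lr` and the numbers below are illustrative and are not taken from the PR. If the scheduler is wrapped so that it advances several times per optimizer step, the arithmetic differs.

```python
def linear_decay_lr(step: int, total_steps: int, base_lr: float = 1e-4) -> float:
    """Learning rate after `step` scheduler steps under linear decay to zero."""
    remaining = max(0, total_steps - step)
    return base_lr * remaining / total_steps


max_train_steps = 1000  # illustrative cap on global optimizer steps
num_processes = 8       # illustrative number of processes

# Horizon sized as in the flagged branch: num_processes times the real run length.
lr_inflated = linear_decay_lr(max_train_steps, max_train_steps * num_processes)
# Horizon sized to the actual number of optimizer steps.
lr_aligned = linear_decay_lr(max_train_steps, max_train_steps)

print(f"final LR, inflated horizon: {lr_inflated:.3e}")
print(f"final LR, aligned horizon:  {lr_aligned:.3e}")
```

Under these assumptions the inflated horizon ends training with 7/8 of the base learning rate still in place, while the aligned horizon reaches zero on the final step.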
Suggested change: num_training_steps_for_scheduler = args.max_train_steps
    (self.dataset_name is not None and (self.dataset_mixer is not None or self.dataset_mixer_list is not None))
    or (self.dataset_name is not None)
    or (self.dataset_mixer is not None and self.dataset_mixer_list is not None)
):
    raise ValueError("Cannot provide two dataset selection mechanisms.")
The logic for validating mutually exclusive dataset arguments is flawed. The condition or (self.dataset_name is not None) on line 168 will trigger a ValueError whenever dataset_name is provided, regardless of other arguments. Additionally, since dataset_mixer_list has a non-None default value (line 89), it will always conflict with dataset_name or dataset_mixer under these checks, making those arguments effectively unusable.
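One way to express the intended check is to count how many selection mechanisms were supplied and reject more than one. The sketch below is only an illustration: `DatasetArgs` is a hypothetical stand-in for the PR's argument class, and it assumes the default of `dataset_mixer_list` is changed to `None` so that an untouched default no longer counts as a user-provided value.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DatasetArgs:  # hypothetical stand-in, not the PR's actual class
    dataset_name: Optional[str] = None
    dataset_mixer: Optional[dict] = None
    dataset_mixer_list: Optional[list] = None  # assumption: default changed to None

    def __post_init__(self) -> None:
        # Collect the names of every mechanism the caller actually set.
        provided = [
            name
            for name in ("dataset_name", "dataset_mixer", "dataset_mixer_list")
            if getattr(self, name) is not None
        ]
        if len(provided) > 1:
            raise ValueError(f"Cannot provide two dataset selection mechanisms: {provided}")


# A single mechanism is accepted; combining two raises ValueError.
ok = DatasetArgs(dataset_name="some/dataset")
```

Counting the provided fields avoids enumerating every pairwise combination, which is where the stray `or (self.dataset_name is not None)` clause crept in. Whether at least one mechanism should be required is a separate decision that this sketch leaves open.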
total_tokens_including_padding = accelerator.gather(total_token_including_padding).sum().item()
total_tokens_this_log_period = accelerator.gather(local_total_tokens_this_log_period).sum().item()
local_total_tokens_this_log_period.zero_()
accelerator.gather(local_pred_tokens_this_log_period).sum().item()
tensor = feature[key] if isinstance(feature[key], torch.Tensor) else torch.tensor(feature[key])
pad_len = max_distill_len - tensor.shape[0]
if pad_len > 0:
    tensor = torch.cat([tensor, torch.zeros((pad_len, tensor.shape[1]), dtype=tensor.dtype)])
Using torch.zeros without specifying a device can lead to a RuntimeError if the input tensor is already on a GPU. It is safer and more efficient to use tensor.new_zeros, which automatically matches the device and data type of the existing tensor.
Suggested change:
-     tensor = torch.cat([tensor, torch.zeros((pad_len, tensor.shape[1]), dtype=tensor.dtype)])
+     tensor = torch.cat([tensor, tensor.new_zeros((pad_len, tensor.shape[1]))])
message_start_idx = len(
    tokenizer.apply_chat_template(
        conversation=messages[:message_idx],
        tokenize=True,
        add_generation_prompt=False,
        return_dict=False,
    )
)
Calling tokenizer.apply_chat_template inside a loop over messages to find start and end indices is inefficient. This results in
Summary
This PR is Part 3 of the offline distillation rollout for OpenInstruct (Part 1: #1525, Part 2: #1534).
It adds the student offline distillation training pipeline and supporting distillation loss/data components.
PR Scope
Adds offline distillation training + distillation-specific data/loss components:
Also updates pyproject.toml to include the new core distillation implementation files in the ty include list.
Attribution