Offline Distillation via DistillKit (Part Three - Offline Distillation)#1629

Open
wolfecameron wants to merge 6 commits into allenai:main from wolfecameron:distill-pr03

@wolfecameron
Contributor

Summary

This PR is Part 3 of the offline distillation rollout for OpenInstruct (Part 1: #1525, Part 2: #1534).
It adds the student offline distillation training pipeline and supporting distillation loss/data components.

PR Scope

Adds offline distillation training + distillation-specific data/loss components:

  • open_instruct/offline_distill.py
    • training entrypoint for student distillation from precomputed teacher logits
    • integrates distillation loss computation into the existing training/checkpoint/eval flow
  • open_instruct/dataset_transformation.py
    • adds distill_pretokenized_v1 and distill_pretokenized_filter_v1
    • adds distillation dataset key definitions/target columns
  • open_instruct/distillation_collator.py
    • DistillationDataCollator for batching/padding compressed teacher signal tensors
  • open_instruct/distillkit/distill_loss.py + open_instruct/distillkit/signals.py
    • distillation loss interface
  • open_instruct/distillkit/lossfuncs/*
    • distillation loss stack (KL/JSD/TVD/Hinge/Logistic Ranking + CE integration)
  • open_instruct/distillkit/test_offline_distill_utils.py
    • unittest coverage for transforms, collator behavior, and loss wiring
  • scripts/train/distill/olmo3_7b_student_offline_distill.sh
    • example launch script for student offline distillation

Also updates pyproject.toml to include the new core distillation implementation files in the ty include list.
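
For readers unfamiliar with the divergences named in the loss stack above, the following is a minimal pure-Python sketch of three of them (KL, JSD, TVD) on a single toy next-token distribution. It is not the DistillKit implementation: the function names, the KL direction (teacher first), and the example probabilities are all assumptions made for illustration.

```python
import math


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions; terms with p_i == 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def js_divergence(p, q):
    """Jensen-Shannon divergence: symmetric, bounded above by log(2)."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


def tv_distance(p, q):
    """Total variation distance: half the L1 distance between p and q."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))


# Hypothetical teacher and student distributions over a 3-token vocabulary.
teacher = [0.7, 0.2, 0.1]
student = [0.5, 0.3, 0.2]

print("KL :", kl_divergence(teacher, student))
print("JSD:", js_divergence(teacher, student))
print("TVD:", tv_distance(teacher, student))
```

KL is asymmetric, so the choice of direction matters in practice; JSD and TVD are symmetric and bounded, which is one reason they are sometimes offered as alternatives.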

Attribution

  • Contribution work from Netflix Research: Cameron Wolfe (cameronwolfe@netflix.com)
  • DistillKit adaptation: Portions of the distillation loss/signal implementation were adapted from DistillKit, with attribution headers preserved in relevant files.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces offline knowledge-distillation training, including a new offline_distill.py script, a specialized data collator for compressed teacher logits, and a suite of distillation loss functions. Review feedback identifies a high-severity issue where the learning rate scheduler's total steps are incorrectly calculated when using multiple processes, as well as flawed validation logic for mutually exclusive dataset arguments. Other recommendations include optimizing the O(N^2) re-tokenization in dataset transformations, using device-safe tensor initialization in the collator, and removing redundant gathering operations in the training loop.

Comment on lines +481 to +483
args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes
)
lr_scheduler = _create_scheduler(args, optimizer, num_training_steps_for_scheduler)

high

When the user provides a fixed number of training steps, the calculation of num_training_steps_for_scheduler multiplies max_train_steps by the number of processes. Since completed_steps tracks global optimization steps and the training loop runs for args.max_train_steps iterations, the scheduler should be initialized with args.max_train_steps so that the learning rate schedule aligns with the actual training duration. Multiplying by num_processes causes the learning rate to decay much more slowly than intended.

    num_training_steps_for_scheduler = args.max_train_steps
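
To illustrate the effect, here is a minimal sketch using a plain linear-decay schedule with no warmup. It assumes, as the comment does, that the scheduler is stepped once per global optimization step. The function `linear_decay_lr` and the numbers (1000 steps, 8 processes, base learning rate 1e-4) are hypothetical and not taken from the PR; the real schedule is built by `_create_scheduler`, whose behavior may differ.

```python
def linear_decay_lr(step: int, total_steps: int, base_lr: float) -> float:
    """Linearly decay base_lr to zero over total_steps (no warmup)."""
    remaining = max(0.0, 1.0 - step / total_steps)
    return base_lr * remaining


max_train_steps = 1000  # hypothetical user-provided step budget
num_processes = 8       # hypothetical number of processes
base_lr = 1e-4          # hypothetical peak learning rate

# Scheduler sized to the real training length: LR reaches zero at the last step.
lr_aligned = linear_decay_lr(max_train_steps, max_train_steps, base_lr)

# Scheduler sized to max_train_steps * num_processes: training stops after
# 1000 steps, when only 1/8 of the schedule has elapsed.
lr_inflated = linear_decay_lr(max_train_steps, max_train_steps * num_processes, base_lr)

print("final LR, aligned schedule :", lr_aligned)
print("final LR, inflated schedule:", lr_inflated)
```

Under these assumptions the inflated schedule ends training with 87.5% of the peak learning rate still in place, rather than decaying to zero.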

Comment on lines +166 to +170
(self.dataset_name is not None and (self.dataset_mixer is not None or self.dataset_mixer_list is not None))
or (self.dataset_name is not None)
or (self.dataset_mixer is not None and self.dataset_mixer_list is not None)
):
raise ValueError("Cannot provide two dataset selection mechanisms.")

medium

The logic for validating mutually exclusive dataset arguments is flawed. The condition or (self.dataset_name is not None) on line 168 will trigger a ValueError whenever dataset_name is provided, regardless of other arguments. Additionally, since dataset_mixer_list has a non-None default value (line 89), it will always conflict with dataset_name or dataset_mixer under these checks, making those arguments effectively unusable.
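
One possible shape for the check is to count how many selection mechanisms were actually provided and fail only when more than one is set. The sketch below is hypothetical: `DatasetArgs` is a stand-in for the PR's argument class, and it assumes `dataset_mixer_list` defaults to None so that "not provided" can be detected, which differs from the current default described above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DatasetArgs:  # hypothetical stand-in for the PR's argument class
    dataset_name: Optional[str] = None
    dataset_mixer: Optional[dict] = None
    # Assumption: the default is None so an unset value is distinguishable.
    dataset_mixer_list: Optional[list] = None

    def __post_init__(self) -> None:
        # Collect the names of the mechanisms the caller actually set.
        provided = [
            name
            for name in ("dataset_name", "dataset_mixer", "dataset_mixer_list")
            if getattr(self, name) is not None
        ]
        if len(provided) > 1:
            raise ValueError(f"Cannot provide two dataset selection mechanisms: {provided}")


# A single mechanism is accepted.
ok = DatasetArgs(dataset_name="my/dataset")

# Two mechanisms at once are rejected.
try:
    DatasetArgs(dataset_name="my/dataset", dataset_mixer_list=["other/dataset", "1.0"])
    conflict_raised = False
except ValueError:
    conflict_raised = True
```

If the non-None default for `dataset_mixer_list` must be kept, the check would instead need some other way to tell an explicit value from the default.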

total_tokens_including_padding = accelerator.gather(total_token_including_padding).sum().item()
total_tokens_this_log_period = accelerator.gather(local_total_tokens_this_log_period).sum().item()
local_total_tokens_this_log_period.zero_()
accelerator.gather(local_pred_tokens_this_log_period).sum().item()

medium

This line performs a collective gather operation but discards the result. It appears to be a typo or redundant code, as the gathered value is not assigned to any variable or used in the subsequent logging logic.

tensor = feature[key] if isinstance(feature[key], torch.Tensor) else torch.tensor(feature[key])
pad_len = max_distill_len - tensor.shape[0]
if pad_len > 0:
    tensor = torch.cat([tensor, torch.zeros((pad_len, tensor.shape[1]), dtype=tensor.dtype)])

medium

Using torch.zeros without specifying a device can lead to a RuntimeError if the input tensor is already on a GPU. It is safer and more efficient to use tensor.new_zeros, which automatically matches the device and data type of the existing tensor.

Suggested change
- tensor = torch.cat([tensor, torch.zeros((pad_len, tensor.shape[1]), dtype=tensor.dtype)])
+ tensor = torch.cat([tensor, tensor.new_zeros((pad_len, tensor.shape[1]))])

Comment on lines +1584 to +1591
message_start_idx = len(
    tokenizer.apply_chat_template(
        conversation=messages[:message_idx],
        tokenize=True,
        add_generation_prompt=False,
        return_dict=False,
    )
)

medium

Calling tokenizer.apply_chat_template inside a loop over messages to find start and end indices is inefficient. This results in $O(N^2)$ complexity relative to the number of messages because the prefix is re-tokenized for every turn. For long conversations, this can significantly slow down dataset preparation. Consider tokenizing the entire conversation once and tracking offsets or using special tokens to identify boundaries.
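
The offset-tracking idea can be sketched as follows. This is a toy illustration only: `render_message` and `toy_tokenize` are hypothetical stand-ins for the chat template and tokenizer, and the approach assumes each message tokenizes independently of its neighbors. Real subword tokenizers can merge across message boundaries, and real templates may add BOS tokens or default system prompts, so any production version would need to be checked against a full tokenization of the conversation.

```python
def render_message(message: dict) -> str:
    """Hypothetical toy chat template: one tagged segment per message."""
    return f"<|{message['role']}|> {message['content']} <|end|> "


def toy_tokenize(text: str) -> list:
    """Hypothetical whitespace tokenizer standing in for the real one."""
    return text.split()


def message_spans(messages: list) -> list:
    """Return (start, end) token offsets per message in a single pass."""
    spans = []
    offset = 0
    for message in messages:
        n_tokens = len(toy_tokenize(render_message(message)))
        spans.append((offset, offset + n_tokens))
        offset += n_tokens
    return spans


def message_spans_quadratic(messages: list) -> list:
    """Reference: re-tokenize the whole prefix for every turn."""

    def prefix_len(k: int) -> int:
        prefix = "".join(render_message(m) for m in messages[:k])
        return len(toy_tokenize(prefix))

    return [(prefix_len(i), prefix_len(i + 1)) for i in range(len(messages))]


conversation = [
    {"role": "user", "content": "hi there"},
    {"role": "assistant", "content": "hello"},
]
print(message_spans(conversation))
```

For the toy tokenizer both functions return the same spans, but the single-pass version tokenizes each message once instead of once per later turn.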
