
Support DPO #130

Open
tastelikefeet wants to merge 14 commits into modelscope:main from tastelikefeet:feat/dpo

Conversation

@tastelikefeet
Collaborator

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

Write the detailed information belonging to this PR.

Experiment results

Paste your experiment results here (if needed).

Contributor

@gemini-code-assist bot left a comment

Code Review

This pull request introduces Direct Preference Optimization (DPO) and its variants (SimPO, ORPO, CPO) to the Twinkle framework, adding new loss functions, specialized data preprocessors, and a Ray-based training recipe. The Trajectory data format was updated to include user_data, and the template encoding logic was enhanced with parallel processing. Feedback identifies critical issues such as a type mismatch in the template encoding return value and a hardcoded parameter that breaks reference-free loss modes. Additionally, logical errors in conversation parsing and multiple inconsistencies between documentation and implementation regarding default values, configurable keys, and supported loss types were noted, along with opportunities to improve the robustness of message role parsing and preprocessor outputs.

Comment on lines +329 to +330
trajectory.update(input_feature)
return trajectory
Contributor

critical

The _encode_messages method now updates the input trajectory with input_feature fields and returns the modified trajectory. However, the _invoke_post_pipeline (called by batch_encode) expects a List[InputFeature], not a List[Trajectory]. This type mismatch will cause a runtime error when _invoke_post_pipeline attempts to process Trajectory objects as InputFeature objects, as the internal pipeline functions like _check_max_length are designed for InputFeature.

Suggested change:
  trajectory.update(input_feature)
- return trajectory
+ return input_feature

reference_free = LOSS_TYPE in ['simpo', 'orpo', 'cpo']

# Set up loss function
loss_fn = create_loss(LOSS_TYPE, DPO_BETA, reference_free=False)
Contributor

high

The create_loss function is called with reference_free=False explicitly, even though reference_free is already determined by LOSS_TYPE on line 196. If LOSS_TYPE is one of the reference-free types (simpo, orpo, cpo), passing reference_free=False here will incorrectly force the loss function to expect a reference model, leading to a runtime error or incorrect behavior. The reference_free parameter should be passed dynamically based on the LOSS_TYPE evaluation.

Suggested change:
- loss_fn = create_loss(LOSS_TYPE, DPO_BETA, reference_free=False)
+ loss_fn = create_loss(LOSS_TYPE, DPO_BETA, reference_free=reference_free)

Comment on lines +85 to +86
messages.append(Message(role=msg.get('role'), content=msg.get('content', '')))
return messages
Contributor

high

In _build_prompt_messages, when parsing a dictionary message, msg.get('role') is called without a default value. If the 'role' key is missing from the dictionary, this will result in None being assigned as the role, which could lead to unexpected behavior or type errors downstream. It's best to provide a default value or handle the missing key explicitly.

Suggested change:
- messages.append(Message(role=msg.get('role'), content=msg.get('content', '')))
+ messages.append(Message(role=msg.get('role', 'user'), content=msg.get('content', '')))
  return messages

Comment on lines +161 to +167
parts = text.split('\n\nHuman: ')
for i, part in enumerate(parts):
if i == 0 and not part.startswith('Human: '):
if part.strip():
if part.startswith('Human: '):
part = part[7:]
messages.append(Message(role='user', content=part.strip()))
Contributor

high

The logic for handling the first part of the conversation in _parse_hh_conversation contains a contradictory condition: if i == 0 and not part.startswith('Human: ') followed by if part.startswith('Human: '). This inner if statement will never be true if the outer if statement's condition not part.startswith('Human: ') is met. This indicates a logical error that might prevent correct parsing of certain conversation formats.
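A corrected parse could drop the contradictory guard and instead split each turn on its role prefix. Below is a minimal sketch, assuming the Anthropic HH format (turns separated by "\n\nHuman: " and "\n\nAssistant: "); the function name and the plain-dict message shape are illustrative, not the PR's actual `_parse_hh_conversation`:

```python
# Hypothetical corrected sketch of the HH conversation parser.
# Assumes turns are separated by "\n\nHuman: " / "\n\nAssistant: ".
def parse_hh_conversation(text):
    messages = []
    text = text.strip()
    # Normalize the leading marker so every block starts after its prefix.
    if text.startswith('Human: '):
        text = text[len('Human: '):]
    for human_block in text.split('\n\nHuman: '):
        # Each block may contain assistant replies after "\n\nAssistant: ".
        parts = human_block.split('\n\nAssistant: ')
        if parts[0].strip():
            messages.append({'role': 'user', 'content': parts[0].strip()})
        for reply in parts[1:]:
            if reply.strip():
                messages.append({'role': 'assistant', 'content': reply.strip()})
    return messages
```

With this structure, the first turn is handled the same way as every other turn, so no special-case branch on `i == 0` is needed.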

Comment on lines +35 to +36
BATCH_SIZE – global batch size (preference pairs) (default: 8)
MICRO_BATCH_SIZE – per-device micro batch size (default: 2)
Contributor

medium

The default BATCH_SIZE and MICRO_BATCH_SIZE values specified in the docstring comments here (BATCH_SIZE default: 8, MICRO_BATCH_SIZE default: 2) are inconsistent with their actual default values set in the code (lines 76-77, BATCH_SIZE: 4, MICRO_BATCH_SIZE: 4). Please update the docstring to reflect the correct default values.

Comment on lines +44 to +47
Dataset field mapping (for custom datasets):
PROMPT_KEY – key for prompt field (default: 'prompt')
CHOSEN_KEY – key for chosen response (default: 'answer_zh')
REJECTED_KEY – key for rejected response (default: 'answer_en')
Contributor

medium

The docstring mentions PROMPT_KEY, CHOSEN_KEY, and REJECTED_KEY as configurable environment variables for custom datasets. However, the EmojiDPOProcessor used in create_dpo_dataset (lines 94-97) does not read these values from environment variables or accept them as init_args. This creates a discrepancy between the documented configurability and the actual implementation, making these environment variables ineffective for EmojiDPOProcessor.

beta: Temperature parameter controlling how much to deviate from ref policy (default: 0.1).
label_smoothing: Label smoothing parameter for soft labels (default: 0.0).
ignore_index: Index to ignore in labels (default: -100).
loss_type: Type of DPO loss variant ('sigmoid', 'hinge', 'ipo', 'kto_pair') (default: 'sigmoid').
Contributor

medium

The docstring for DPOLoss lists 'kto_pair' as a supported loss_type. While the implementation for 'kto_pair' exists (line 228), this loss type is not exposed or used in the cookbook/rl/dpo.py script or dpo.sh script. This creates an inconsistency between the documentation and the practical usage examples provided.
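For readers comparing the variants, the loss types named in the docstring reduce to simple per-pair formulas over policy/reference log-probabilities. A hypothetical scalar sketch (pure Python floats; the PR's actual DPOLoss operates on batched tensors, and this function is illustrative only):

```python
import math

# Hypothetical per-pair sketch of the DPO loss variants named in the docstring.
def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp=0.0, ref_rejected_logp=0.0,
             beta=0.1, loss_type='sigmoid', reference_free=False):
    if reference_free:
        # SimPO/ORPO/CPO-style variants drop the reference-model term.
        logits = policy_chosen_logp - policy_rejected_logp
    else:
        # Difference of policy-vs-reference log-ratios for chosen and rejected.
        logits = ((policy_chosen_logp - ref_chosen_logp)
                  - (policy_rejected_logp - ref_rejected_logp))
    if loss_type == 'sigmoid':
        # Standard DPO: -log sigmoid(beta * logits)
        return -math.log(1.0 / (1.0 + math.exp(-beta * logits)))
    if loss_type == 'hinge':
        return max(0.0, 1.0 - beta * logits)
    if loss_type == 'ipo':
        # IPO: squared distance of the log-ratio gap from 1/(2*beta)
        return (logits - 1.0 / (2.0 * beta)) ** 2
    raise ValueError(f'unknown loss_type: {loss_type}')
```

Exposing `kto_pair` in the cookbook script would then be a matter of adding it to whatever `LOSS_TYPE` values the recipe accepts.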

Comment on lines +66 to +67
messages.append(Message(role=msg.get('role', 'assistant'), content=msg.get('content', '')))
return messages
Contributor

medium

In _parse_response, when a dictionary message is encountered, the role defaults to 'assistant' if not present. This might be too restrictive or incorrect if the message could originate from a user or another role. Consider making the default role more general or raising an error if the role is missing and cannot be inferred.

results = [self.preprocess(row) for row in rows]
results = [r for r in results if r is not None]
if not results:
return {}
Contributor

medium

If results is empty, the __call__ method returns an empty dictionary {}. This can cause issues for downstream consumers expecting a dictionary with positive and negative keys, even if those lists are empty. Returning {'positive': [], 'negative': []} would provide a more consistent and predictable output structure.

Suggested change:
  if not results:
-     return {}
+     return {'positive': [], 'negative': []}

Comment on lines +418 to +420
chosen_key: Key for chosen response (default: 'answer_zh').
rejected_key: Key for rejected response (default: 'answer_en').
prompt_key: Key for prompt (default: 'prompt').
Contributor

medium

The EmojiDPOProcessor hardcodes chosen_key, rejected_key, and prompt_key to specific values ('answer_zh', 'answer_en', 'prompt'). This contradicts the dpo.py script's docstring, which implies these keys can be configured via environment variables. To align with the documentation, these keys should either be configurable through init_args or the docstring should clarify that EmojiDPOProcessor uses fixed keys.
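One way to reconcile the docstring with the implementation is to thread the keys through constructor arguments with environment-variable fallbacks. A hypothetical sketch; the class name and `preprocess` contract are illustrative, not the PR's actual EmojiDPOProcessor:

```python
import os

# Hypothetical processor whose dataset keys are configurable, matching the
# PROMPT_KEY / CHOSEN_KEY / REJECTED_KEY variables the dpo.py docstring names.
class ConfigurableDPOProcessor:
    def __init__(self, prompt_key=None, chosen_key=None, rejected_key=None):
        # Explicit init_args win; otherwise fall back to the documented
        # environment variables, then to the current hardcoded defaults.
        self.prompt_key = prompt_key or os.environ.get('PROMPT_KEY', 'prompt')
        self.chosen_key = chosen_key or os.environ.get('CHOSEN_KEY', 'answer_zh')
        self.rejected_key = rejected_key or os.environ.get('REJECTED_KEY', 'answer_en')

    def preprocess(self, row):
        # Map the configured source keys onto a fixed output schema.
        return {
            'prompt': row[self.prompt_key],
            'chosen': row[self.chosen_key],
            'rejected': row[self.rejected_key],
        }
```

Alternatively, the docstring could simply state that EmojiDPOProcessor uses fixed keys and that the environment variables apply only to other preprocessors.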
