
Commit ee7e02d

Add dpo to cli, fix typing mismatch with API

1 parent: 3f1ec6a

4 files changed: 59 additions & 20 deletions


src/together/cli/api/finetune.py

Lines changed: 16 additions & 0 deletions
@@ -104,6 +104,18 @@ def fine_tuning(ctx: click.Context) -> None:
     default="all-linear",
     help="Trainable modules for LoRA adapters. For example, 'all-linear', 'q_proj,v_proj'",
 )
+@click.option(
+    "--training-method",
+    type=click.Choice(["sft", "dpo"]),
+    default="sft",
+    help="Training method to use. Options: sft (supervised fine-tuning), dpo (direct preference optimization)",
+)
+@click.option(
+    "--dpo-beta",
+    type=float,
+    default=0.1,
+    help="Beta parameter for DPO training (only used when training-method is 'dpo')",
+)
 @click.option(
     "--suffix", type=str, default=None, help="Suffix for the fine-tuned model name"
 )
@@ -152,6 +164,8 @@ def create(
     wandb_name: str,
     confirm: bool,
     train_on_inputs: bool | Literal["auto"],
+    training_method: str,
+    dpo_beta: float,
 ) -> None:
     """Start fine-tuning"""
     client: Together = ctx.obj
@@ -180,6 +194,8 @@ def create(
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
         train_on_inputs=train_on_inputs,
+        training_method=training_method,
+        dpo_beta=dpo_beta,
     )
 
     model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
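
With the two new options wired into the create command, a DPO job could be launched from the CLI roughly as follows. This is only an illustrative invocation: the model name and file ID are placeholders, and a preference-format training file is assumed to have been uploaded already.

together fine-tuning create \
    --training-file file-abc123 \
    --model meta-llama/Meta-Llama-3.1-8B-Instruct-Reference \
    --training-method dpo \
    --dpo-beta 0.1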

src/together/resources/finetune.py

Lines changed: 16 additions & 12 deletions
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from pathlib import Path
-from typing import Literal
+from typing import Literal, Union
 
 from rich import print as rprint
 
@@ -22,7 +22,8 @@
     TrainingType,
     FinetuneLRScheduler,
     FinetuneLinearLRSchedulerArgs,
-    DPOTrainingMethodType,
+    TrainingMethodDPO,
+    TrainingMethodSFT,
 )
 from together.types.finetune import DownloadCheckpointType
 from together.utils import log_warn_once, normalize_key
@@ -108,10 +109,13 @@ def createFinetuneRequest(
         lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
     )
 
+    training_method_cls: Union[TrainingMethodSFT, TrainingMethodDPO] = (
+        TrainingMethodSFT()
+    )
     if training_method == "dpo":
-        training_method_args = DPOTrainingMethodType(dpo_beta=dpo_beta)
-    else:
-        training_method_args = None
+        training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
+
+    print("\n TRAINING METHOD at CREATE FINE TUNE REQUEST", training_method)
 
     finetune_request = FinetuneRequest(
         model=model,
@@ -133,8 +137,7 @@ def createFinetuneRequest(
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
         train_on_inputs=train_on_inputs,
-        training_method=training_method,
-        training_method_args=training_method_args,
+        training_method=training_method_cls,
     )
 
     return finetune_request
@@ -173,7 +176,7 @@ def create(
         model_limits: FinetuneTrainingLimits | None = None,
         train_on_inputs: bool | Literal["auto"] = "auto",
         training_method: str = "sft",
-        dpo_beta: float = 0.1,
+        dpo_beta: float | None = None,
     ) -> FinetuneResponse:
         """
         Method to initiate a fine-tuning job
@@ -221,7 +224,7 @@ def create(
                 Defaults to "auto".
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
-            dpo_beta (float, optional): DPO beta parameter. Defaults to 0.1.
+            dpo_beta (float, optional): DPO beta parameter. Defaults to None.
 
         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
@@ -233,7 +236,7 @@ def create(
 
         if model_limits is None:
            model_limits = self.get_model_limits(model=model)
-
+        print("\n DPO BETA at CREATE FINE TUNE REQUEST", dpo_beta)
         finetune_request = createFinetuneRequest(
             model_limits=model_limits,
             training_file=training_file,
@@ -268,6 +271,7 @@ def create(
             "Submitting a fine-tuning job with the following parameters:",
             finetune_request,
         )
+        print("\n FINETUNE REQUEST before dump", finetune_request)
         parameter_payload = finetune_request.model_dump(exclude_none=True)
 
         # Print the request payload before sending
@@ -525,7 +529,7 @@ async def create(
         model_limits: FinetuneTrainingLimits | None = None,
         train_on_inputs: bool | Literal["auto"] = "auto",
         training_method: str = "sft",
-        dpo_beta: float = 0.1,
+        dpo_beta: float | None = None,
     ) -> FinetuneResponse:
         """
         Async method to initiate a fine-tuning job
@@ -573,7 +577,7 @@ async def create(
                 Defaults to "auto".
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
-            dpo_beta (float, optional): DPO beta parameter. Defaults to 0.1.
+            dpo_beta (float, optional): DPO beta parameter. Defaults to None.
 
         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
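
The same parameters are accepted by the Python client's create() shown above. A minimal sketch, assuming TOGETHER_API_KEY is set in the environment and using placeholder model and file identifiers:

from together import Together

client = Together()

# training_method and dpo_beta are forwarded into createFinetuneRequest,
# which wraps them in a TrainingMethodDPO payload for the API.
job = client.fine_tuning.create(
    training_file="file-abc123",  # placeholder file ID
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",  # placeholder model
    training_method="dpo",
    dpo_beta=0.1,
)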

src/together/types/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -31,7 +31,8 @@
     FileType,
 )
 from together.types.finetune import (
-    DPOTrainingMethodType,
+    TrainingMethodDPO,
+    TrainingMethodSFT,
     FinetuneDownloadResult,
     FinetuneLinearLRSchedulerArgs,
     FinetuneList,
@@ -80,7 +81,8 @@
     "TrainingType",
     "FullTrainingType",
     "LoRATrainingType",
-    "DPOTrainingMethodType",
+    "TrainingMethodDPO",
+    "TrainingMethodSFT",
     "RerankRequest",
     "RerankResponse",
     "FinetuneTrainingLimits",

src/together/types/finetune.py

Lines changed: 23 additions & 6 deletions
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from enum import Enum
-from typing import List, Literal
+from typing import List, Literal, Union
 
 from pydantic import StrictBool, Field, validator, field_validator
 
@@ -135,12 +135,29 @@ class LoRATrainingType(TrainingType):
     type: str = "Lora"
 
 
-class DPOTrainingMethodType(BaseModel):
+class TrainingMethod(BaseModel):
+    """
+    Training method type
+    """
+
+    method: str
+
+
+class TrainingMethodSFT(TrainingMethod):
+    """
+    Training method type for SFT training
+    """
+
+    method: str = "sft"
+
+
+class TrainingMethodDPO(TrainingMethod):
     """
     Training method type for DPO training
     """
 
-    dpo_beta: float
+    method: str = "dpo"
+    dpo_beta: float | None = None
 
 
 class FinetuneRequest(BaseModel):
@@ -187,9 +204,9 @@ class FinetuneRequest(BaseModel):
     # train on inputs
     train_on_inputs: StrictBool | Literal["auto"] = "auto"
     # training method
-    training_method: str = "sft"
-    # DPO params
-    training_method_args: DPOTrainingMethodType | None = None
+    training_method: Union[TrainingMethodSFT, TrainingMethodDPO] = Field(
+        default_factory=TrainingMethodSFT
+    )
 
 
 class FinetuneResponse(BaseModel):
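
Since training_method is now a nested Pydantic model rather than a plain string plus a separate training_method_args field, the method name and its parameters serialize as a single object, which appears to be the typing mismatch the commit message refers to. A quick sketch of the serialized forms:

from together.types import TrainingMethodDPO, TrainingMethodSFT

TrainingMethodSFT().model_dump()              # {'method': 'sft'}
TrainingMethodDPO(dpo_beta=0.1).model_dump()  # {'method': 'dpo', 'dpo_beta': 0.1}

# With exclude_none=True (as used on the request before submission),
# an unset dpo_beta is dropped from the payload:
TrainingMethodDPO().model_dump(exclude_none=True)  # {'method': 'dpo'}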
