11from __future__ import annotations
22
33from pathlib import Path
4- from typing import Literal
4+ from typing import Literal , Union
55
66from rich import print as rprint
77
2222 TrainingType ,
2323 FinetuneLRScheduler ,
2424 FinetuneLinearLRSchedulerArgs ,
25- DPOTrainingMethodType ,
25+ TrainingMethodDPO ,
26+ TrainingMethodSFT ,
2627)
2728from together .types .finetune import DownloadCheckpointType
2829from together .utils import log_warn_once , normalize_key
@@ -108,10 +109,13 @@ def createFinetuneRequest(
108109 lr_scheduler_args = FinetuneLinearLRSchedulerArgs (min_lr_ratio = min_lr_ratio ),
109110 )
110111
112+ training_method_cls : Union [TrainingMethodSFT , TrainingMethodDPO ] = (
113+ TrainingMethodSFT ()
114+ )
111115 if training_method == "dpo" :
112- training_method_args = DPOTrainingMethodType (dpo_beta = dpo_beta )
113- else :
114- training_method_args = None
116+ training_method_cls = TrainingMethodDPO (dpo_beta = dpo_beta )
117+
115119
116120 finetune_request = FinetuneRequest (
117121 model = model ,
@@ -133,8 +137,7 @@ def createFinetuneRequest(
133137 wandb_project_name = wandb_project_name ,
134138 wandb_name = wandb_name ,
135139 train_on_inputs = train_on_inputs ,
136- training_method = training_method ,
137- training_method_args = training_method_args ,
140+ training_method = training_method_cls ,
138141 )
139142
140143 return finetune_request
@@ -173,7 +176,7 @@ def create(
173176 model_limits : FinetuneTrainingLimits | None = None ,
174177 train_on_inputs : bool | Literal ["auto" ] = "auto" ,
175178 training_method : str = "sft" ,
176- dpo_beta : float = 0.1 ,
179+ dpo_beta : float | None = None ,
177180 ) -> FinetuneResponse :
178181 """
179182 Method to initiate a fine-tuning job
@@ -221,7 +224,7 @@ def create(
221224 Defaults to "auto".
222225 training_method (str, optional): Training method. Defaults to "sft".
223226 Supported methods: "sft", "dpo".
224- dpo_beta (float, optional): DPO beta parameter. Defaults to 0.1 .
227+ dpo_beta (float, optional): DPO beta parameter. Defaults to None .
225228
226229 Returns:
227230 FinetuneResponse: Object containing information about fine-tuning job.
@@ -233,7 +236,7 @@ def create(
233236
234237 if model_limits is None :
235238 model_limits = self .get_model_limits (model = model )
236239
237240 finetune_request = createFinetuneRequest (
238241 model_limits = model_limits ,
239242 training_file = training_file ,
@@ -268,6 +271,7 @@ def create(
268271 "Submitting a fine-tuning job with the following parameters:" ,
269272 finetune_request ,
270273 )
271275 parameter_payload = finetune_request .model_dump (exclude_none = True )
272276
273277 # Print the request payload before sending
@@ -525,7 +529,7 @@ async def create(
525529 model_limits : FinetuneTrainingLimits | None = None ,
526530 train_on_inputs : bool | Literal ["auto" ] = "auto" ,
527531 training_method : str = "sft" ,
528- dpo_beta : float = 0.1 ,
532+ dpo_beta : float | None = None ,
529533 ) -> FinetuneResponse :
530534 """
531535 Async method to initiate a fine-tuning job
@@ -573,7 +577,7 @@ async def create(
573577 Defaults to "auto".
574578 training_method (str, optional): Training method. Defaults to "sft".
575579 Supported methods: "sft", "dpo".
576- dpo_beta (float, optional): DPO beta parameter. Defaults to 0.1 .
580+ dpo_beta (float, optional): DPO beta parameter. Defaults to None .
577581
578582 Returns:
579583 FinetuneResponse: Object containing information about fine-tuning job.
0 commit comments