
Commit 6aca713

Removing max batch size handling from the client, handling it in the REST API
1 parent 2e34944 · commit 6aca713

5 files changed: 18 additions & 37 deletions
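
In effect, resolution of the "max" sentinel moves from the Python client to the REST API: previously the client substituted the model's maximum batch size locally before sending the request; now it forwards "max" and lets the server resolve and validate it. A rough before/after sketch of the request payload (model name and limit value are illustrative, not taken from this diff):

# Before: the client resolved "max" locally, e.g. from
# model_limits.full_training.max_batch_size, and sent a concrete number.
payload_before = {"model": "example-model", "batch_size": 96}

# After: the client forwards the sentinel; the REST API resolves and validates it.
payload_after = {"model": "example-model", "batch_size": "max"}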


src/together/cli/api/finetune.py

Lines changed: 0 additions & 14 deletions
@@ -304,13 +304,8 @@ def create(
             raise click.BadParameter(
                 f"LoRA fine-tuning is not supported for the model `{model}`"
             )
-        if training_method == "dpo":
-            default_batch_size = model_limits.lora_training.max_batch_size_dpo
-        else:
-            default_batch_size = model_limits.lora_training.max_batch_size
         default_values = {
             "lora_r": model_limits.lora_training.max_rank,
-            "batch_size": default_batch_size,
             "learning_rate": 1e-3,
         }
 
@@ -335,15 +330,6 @@ def create(
                 f"Please change the job type with --lora or remove `{param}` from the arguments"
             )
 
-    batch_size_source = ctx.get_parameter_source("batch_size")  # type: ignore[attr-defined]
-    if batch_size_source == ParameterSource.DEFAULT:
-        if training_method == "dpo":
-            training_args["batch_size"] = (
-                model_limits.full_training.max_batch_size_dpo
-            )
-        else:
-            training_args["batch_size"] = model_limits.full_training.max_batch_size
-
     if n_evals <= 0 and validation_file:
         log_warn(
             "Warning: You have specified a validation file but the number of evaluation loops is set to 0. No evaluations will be performed."

src/together/legacy/finetune.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@ def create(
     model: str,
     n_epochs: int = 1,
     n_checkpoints: int | None = 1,
-    batch_size: int | None = 32,
+    batch_size: int | Literal["max"] = 32,
     learning_rate: float = 0.00001,
     suffix: (
         str | None
@@ -43,7 +43,7 @@ def create(
         model=model,
         n_epochs=n_epochs,
         n_checkpoints=n_checkpoints,
-        batch_size=batch_size if isinstance(batch_size, int) else "max",
+        batch_size=batch_size,
         learning_rate=learning_rate,
         suffix=suffix,
         wandb_api_key=wandb_api_key,
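
For legacy callers the change is type-level: an int still means a concrete batch size, and "max" is now forwarded verbatim instead of being derived from a non-int value. A hedged usage sketch, assuming the legacy entry point is together.Finetune.create and that its first argument is the training file (both are assumptions about code not shown in this diff):

import together

# An explicit int behaves as before; "max" is sent through unchanged and
# resolved by the REST API. Other required arguments are elided.
together.Finetune.create(
    training_file="file-abc123",  # hypothetical file ID
    model="example-model",
    batch_size="max",
)

Note that the new annotation relies on Literal (typing.Literal) being imported in the module, which this hunk does not show.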

src/together/resources/finetune.py

Lines changed: 13 additions & 18 deletions
@@ -133,27 +133,22 @@ def create_finetune_request(
     min_batch_size = model_limits.full_training.min_batch_size
     max_batch_size_dpo = model_limits.full_training.max_batch_size_dpo
 
-    if batch_size == "max":
-        if training_method == "dpo":
-            batch_size = max_batch_size_dpo
-        else:
-            batch_size = max_batch_size
+    if batch_size != "max":
+        if training_method == "sft":
+            if batch_size > max_batch_size:
+                raise ValueError(
+                    f"Requested batch size of {batch_size} is higher than the maximum allowed value of {max_batch_size}."
+                )
+        elif training_method == "dpo":
+            if batch_size > max_batch_size_dpo:
+                raise ValueError(
+                    f"Requested batch size of {batch_size} is higher than the maximum allowed value of {max_batch_size_dpo}."
+                )
 
-    if training_method == "sft":
-        if batch_size > max_batch_size:
+        if batch_size < min_batch_size:
             raise ValueError(
-                f"Requested batch size of {batch_size} is higher than the maximum allowed value of {max_batch_size}."
+                f"Requested batch size of {batch_size} is lower than the minimum allowed value of {min_batch_size}."
             )
-    elif training_method == "dpo":
-        if batch_size > max_batch_size_dpo:
-            raise ValueError(
-                f"Requested batch size of {batch_size} is higher than the maximum allowed value of {max_batch_size_dpo}."
-            )
-
-    if batch_size < min_batch_size:
-        raise ValueError(
-            f"Requested batch size of {batch_size} is lower than the minimum allowed value of {min_batch_size}."
-        )
 
     if warmup_ratio > 1 or warmup_ratio < 0:
         raise ValueError(f"Warmup ratio should be between 0 and 1 (got {warmup_ratio})")

src/together/types/finetune.py

Lines changed: 1 addition & 1 deletion
@@ -195,7 +195,7 @@ class FinetuneRequest(BaseModel):
     # number of evaluation loops to run
     n_evals: int | None = None
     # training batch size
-    batch_size: int | None = None
+    batch_size: int | Literal["max"] | None = None
     # up to 40 character suffix for output model name
     suffix: str | None = None
     # weights & biases api key
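
On the wire this makes "max" a legal value for batch_size. A minimal standalone sketch of how the annotated field validates (assuming pydantic, with Literal imported; the class name here is invented):

from typing import Literal
from pydantic import BaseModel

class FinetuneRequestSketch(BaseModel):
    batch_size: int | Literal["max"] | None = None

print(FinetuneRequestSketch(batch_size=16).batch_size)     # 16
print(FinetuneRequestSketch(batch_size="max").batch_size)  # max
print(FinetuneRequestSketch().batch_size)                  # None
# FinetuneRequestSketch(batch_size="huge")  # would raise a ValidationError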

tests/unit/test_finetune_resources.py

Lines changed: 2 additions & 2 deletions
@@ -44,7 +44,7 @@ def test_simple_request():
     assert request.n_epochs > 0
     assert request.warmup_ratio == 0.0
     assert request.training_type.type == "Full"
-    assert request.batch_size == _MODEL_LIMITS.full_training.max_batch_size
+    assert request.batch_size == "max"
 
 
 def test_validation_file():
@@ -82,7 +82,7 @@ def test_lora_request():
     assert request.training_type.lora_alpha == _MODEL_LIMITS.lora_training.max_rank * 2
     assert request.training_type.lora_dropout == 0.0
     assert request.training_type.lora_trainable_modules == "all-linear"
-    assert request.batch_size == _MODEL_LIMITS.lora_training.max_batch_size
+    assert request.batch_size == "max"
 
 
 @pytest.mark.parametrize("lora_dropout", [-1, 0, 0.5, 1.0, 10.0])
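
The updated assertions pin the pass-through: the request builder no longer rewrites "max" into a model-specific number. A hypothetical companion test, not part of this commit, could cover the new rejection path; the keyword arguments below are guesses at the builder's signature, and the snippet assumes the module's existing imports of create_finetune_request and _MODEL_LIMITS:

import pytest

def test_batch_size_above_limit_rejected():
    # An explicit batch size above the model limit should now raise in
    # create_finetune_request rather than being clamped or resolved.
    with pytest.raises(ValueError, match="higher than the maximum"):
        create_finetune_request(
            model_limits=_MODEL_LIMITS,
            model="example-model",      # hypothetical arguments
            training_file="file-id",
            batch_size=10_000,
        )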
