Skip to content

Commit eec18ed

Browse files
committed
Change type hints to include str, convert to str
1 parent f70976b commit eec18ed

1 file changed

Lines changed: 10 additions & 8 deletions

File tree

src/together/resources/finetune.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import re
44
from pathlib import Path
5-
from typing import List, Literal
5+
from typing import List, Literal, Union
66

77
from rich import print as rprint
88

@@ -570,7 +570,7 @@ def download(
570570
*,
571571
output: Path | str | None = None,
572572
checkpoint_step: int | None = None,
573-
checkpoint_type: DownloadCheckpointType = DownloadCheckpointType.DEFAULT,
573+
checkpoint_type: Union[DownloadCheckpointType, str] = DownloadCheckpointType.DEFAULT,
574574
) -> FinetuneDownloadResult:
575575
"""
576576
Downloads compressed fine-tuned model or checkpoint to local disk.
@@ -583,7 +583,7 @@ def download(
583583
Defaults to None.
584584
checkpoint_step (int, optional): Specifies step number for checkpoint to download.
585585
Defaults to -1 (download the final model)
586-
checkpoint_type (CheckpointType, optional): Specifies which checkpoint to download.
586+
checkpoint_type (Union[CheckpointType, str], optional): Specifies which checkpoint to download.
587587
Defaults to CheckpointType.DEFAULT.
588588
589589
Returns:
@@ -607,8 +607,12 @@ def download(
607607

608608
ft_job = self.retrieve(id)
609609

610+
# convert to str
611+
if isinstance(checkpoint_type, DownloadCheckpointType):
612+
checkpoint_type = checkpoint_type.value
613+
610614
if isinstance(ft_job.training_type, FullTrainingType):
611-
if checkpoint_type != DownloadCheckpointType.DEFAULT:
615+
if checkpoint_type != DownloadCheckpointType.DEFAULT.value:
612616
raise ValueError(
613617
"Only DEFAULT checkpoint type is allowed for FullTrainingType"
614618
)
@@ -617,10 +621,8 @@ def download(
617621
if checkpoint_type == DownloadCheckpointType.DEFAULT.value:
618622
checkpoint_type = DownloadCheckpointType.MERGED.value
619623

620-
if checkpoint_type == DownloadCheckpointType.MERGED.value:
621-
url += f"&checkpoint={DownloadCheckpointType.MERGED.value}"
622-
elif checkpoint_type == DownloadCheckpointType.ADAPTER.value:
623-
url += f"&checkpoint={DownloadCheckpointType.ADAPTER.value}"
624+
if checkpoint_type in {DownloadCheckpointType.MERGED.value, DownloadCheckpointType.ADAPTER.value}:
625+
url += f"&checkpoint={checkpoint_type}"
624626
else:
625627
raise ValueError(
626628
f"Invalid checkpoint type for LoRATrainingType: {checkpoint_type}"

0 commit comments

Comments
 (0)