Skip to content

Commit 90d5a6d

Browse files
committed
Cast to enum instead of str
1 parent eec18ed commit 90d5a6d

1 file changed

Lines changed: 7 additions & 7 deletions

File tree

src/together/resources/finetune.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -608,21 +608,21 @@ def download(
608608
ft_job = self.retrieve(id)
609609

610610
# convert to str
611-
if isinstance(checkpoint_type, DownloadCheckpointType):
612-
checkpoint_type = checkpoint_type.value
611+
if isinstance(checkpoint_type, str):
612+
checkpoint_type = DownloadCheckpointType(checkpoint_type)
613613

614614
if isinstance(ft_job.training_type, FullTrainingType):
615-
if checkpoint_type != DownloadCheckpointType.DEFAULT.value:
615+
if checkpoint_type != DownloadCheckpointType.DEFAULT:
616616
raise ValueError(
617617
"Only DEFAULT checkpoint type is allowed for FullTrainingType"
618618
)
619619
url += "&checkpoint=model_output_path"
620620
elif isinstance(ft_job.training_type, LoRATrainingType):
621-
if checkpoint_type == DownloadCheckpointType.DEFAULT.value:
622-
checkpoint_type = DownloadCheckpointType.MERGED.value
621+
if checkpoint_type == DownloadCheckpointType.DEFAULT:
622+
checkpoint_type = DownloadCheckpointType.MERGED
623623

624-
if checkpoint_type in {DownloadCheckpointType.MERGED.value, DownloadCheckpointType.ADAPTER.value}:
625-
url += f"&checkpoint={checkpoint_type}"
624+
if checkpoint_type in {DownloadCheckpointType.MERGED, DownloadCheckpointType.ADAPTER}:
625+
url += f"&checkpoint={checkpoint_type.value}"
626626
else:
627627
raise ValueError(
628628
f"Invalid checkpoint type for LoRATrainingType: {checkpoint_type}"

0 commit comments

Comments
 (0)