22
33import re
44from pathlib import Path
5- from typing import List , Literal
5+ from typing import List , Literal , Union
66
77from rich import print as rprint
88
@@ -570,7 +570,7 @@ def download(
570570 * ,
571571 output : Path | str | None = None ,
572572 checkpoint_step : int | None = None ,
573- checkpoint_type : DownloadCheckpointType = DownloadCheckpointType .DEFAULT ,
573+ checkpoint_type : Union [ DownloadCheckpointType , str ] = DownloadCheckpointType .DEFAULT ,
574574 ) -> FinetuneDownloadResult :
575575 """
576576 Downloads compressed fine-tuned model or checkpoint to local disk.
@@ -583,7 +583,7 @@ def download(
583583 Defaults to None.
584584 checkpoint_step (int, optional): Specifies step number for checkpoint to download.
585585 Defaults to -1 (download the final model)
586- checkpoint_type (CheckpointType, optional): Specifies which checkpoint to download.
586+ checkpoint_type (Union[ CheckpointType, str] , optional): Specifies which checkpoint to download.
587587 Defaults to CheckpointType.DEFAULT.
588588
589589 Returns:
@@ -607,8 +607,12 @@ def download(
607607
608608 ft_job = self .retrieve (id )
609609
610+ # convert to str
611+ if isinstance (checkpoint_type , DownloadCheckpointType ):
612+ checkpoint_type = checkpoint_type .value
613+
610614 if isinstance (ft_job .training_type , FullTrainingType ):
611- if checkpoint_type != DownloadCheckpointType .DEFAULT :
615+ if checkpoint_type != DownloadCheckpointType .DEFAULT . value :
612616 raise ValueError (
613617 "Only DEFAULT checkpoint type is allowed for FullTrainingType"
614618 )
@@ -617,10 +621,8 @@ def download(
617621 if checkpoint_type == DownloadCheckpointType .DEFAULT .value :
618622 checkpoint_type = DownloadCheckpointType .MERGED .value
619623
620- if checkpoint_type == DownloadCheckpointType .MERGED .value :
621- url += f"&checkpoint={ DownloadCheckpointType .MERGED .value } "
622- elif checkpoint_type == DownloadCheckpointType .ADAPTER .value :
623- url += f"&checkpoint={ DownloadCheckpointType .ADAPTER .value } "
624+ if checkpoint_type in {DownloadCheckpointType .MERGED .value , DownloadCheckpointType .ADAPTER .value }:
625+ url += f"&checkpoint={ checkpoint_type } "
624626 else :
625627 raise ValueError (
626628 f"Invalid checkpoint type for LoRATrainingType: { checkpoint_type } "
0 commit comments