Skip to content

Commit 0eeafe6

Browse files
committed
fix the base model requirement
1 parent fe39fd8 commit 0eeafe6

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

eval_protocol/cli_commands/create_rft.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -487,16 +487,18 @@ def create_rft_command(args) -> int:
487487
return 1
488488

489489
# Build training config/body
490-
# Ensure base model is explicitly provided for clarity
491-
if not getattr(args, "base_model", None):
492-
print(
493-
"Error: --base-model is required. Please specify the base model resource id (e.g., accounts/{account}/models/<model_id>)."
494-
)
490+
# Exactly one of base-model or warm-start-from must be provided
491+
base_model = getattr(args, "base_model", None)
492+
warm_start_from = getattr(args, "warm_start_from", None)
493+
if (base_model is None and warm_start_from is None) or (base_model is not None and warm_start_from is not None):
494+
print("Error: exactly one of --base-model or --warm-start-from must be specified.")
495495
return 1
496496

497-
training_config: Dict[str, Any] = {"baseModel": args.base_model}
498-
if getattr(args, "warm_start_from", None):
499-
training_config["warmStartFrom"] = args.warm_start_from
497+
training_config: Dict[str, Any] = {}
498+
if base_model is not None:
499+
training_config["baseModel"] = base_model
500+
if warm_start_from is not None:
501+
training_config["warmStartFrom"] = warm_start_from
500502

501503
# Optional hyperparameters
502504
for key, arg_name in [

0 commit comments

Comments
 (0)