Skip to content

Commit 61ca3d7

Browse files
committed
tighten validation
1 parent 0eeafe6 commit 61ca3d7

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

eval_protocol/cli_commands/create_rft.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -488,16 +488,21 @@ def create_rft_command(args) -> int:
488488

489489
# Build training config/body
490490
# Exactly one of base-model or warm-start-from must be provided
491-
base_model = getattr(args, "base_model", None)
492-
warm_start_from = getattr(args, "warm_start_from", None)
493-
if (base_model is None and warm_start_from is None) or (base_model is not None and warm_start_from is not None):
491+
base_model_raw = getattr(args, "base_model", None)
492+
warm_start_from_raw = getattr(args, "warm_start_from", None)
493+
# Treat empty/whitespace strings as not provided
494+
base_model = base_model_raw.strip() if isinstance(base_model_raw, str) else base_model_raw
495+
warm_start_from = warm_start_from_raw.strip() if isinstance(warm_start_from_raw, str) else warm_start_from_raw
496+
has_base_model = bool(base_model)
497+
has_warm_start = bool(warm_start_from)
498+
if (not has_base_model and not has_warm_start) or (has_base_model and has_warm_start):
494499
print("Error: exactly one of --base-model or --warm-start-from must be specified.")
495500
return 1
496501

497502
training_config: Dict[str, Any] = {}
498-
if base_model is not None:
503+
if has_base_model:
499504
training_config["baseModel"] = base_model
500-
if warm_start_from is not None:
505+
if has_warm_start:
501506
training_config["warmStartFrom"] = warm_start_from
502507

503508
# Optional hyperparameters

0 commit comments

Comments
 (0)