diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py index 1f903d4b..5828b3ea 100644 --- a/eval_protocol/cli_commands/create_rft.py +++ b/eval_protocol/cli_commands/create_rft.py @@ -487,16 +487,23 @@ def create_rft_command(args) -> int: return 1 # Build training config/body - # Ensure base model is explicitly provided for clarity - if not getattr(args, "base_model", None): - print( - "Error: --base-model is required. Please specify the base model resource id (e.g., accounts/{account}/models/)." - ) + # Exactly one of base-model or warm-start-from must be provided + base_model_raw = getattr(args, "base_model", None) + warm_start_from_raw = getattr(args, "warm_start_from", None) + # Treat empty/whitespace strings as not provided + base_model = base_model_raw.strip() if isinstance(base_model_raw, str) else base_model_raw + warm_start_from = warm_start_from_raw.strip() if isinstance(warm_start_from_raw, str) else warm_start_from_raw + has_base_model = bool(base_model) + has_warm_start = bool(warm_start_from) + if (not has_base_model and not has_warm_start) or (has_base_model and has_warm_start): + print("Error: exactly one of --base-model or --warm-start-from must be specified.") return 1 - training_config: Dict[str, Any] = {"baseModel": args.base_model} - if getattr(args, "warm_start_from", None): - training_config["warmStartFrom"] = args.warm_start_from + training_config: Dict[str, Any] = {} + if has_base_model: + training_config["baseModel"] = base_model + if has_warm_start: + training_config["warmStartFrom"] = warm_start_from # Optional hyperparameters for key, arg_name in [