Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions eval_protocol/cli_commands/create_rft.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,16 +487,23 @@ def create_rft_command(args) -> int:
return 1

# Build training config/body
# Ensure base model is explicitly provided for clarity
if not getattr(args, "base_model", None):
print(
"Error: --base-model is required. Please specify the base model resource id (e.g., accounts/{account}/models/<model_id>)."
)
# Exactly one of base-model or warm-start-from must be provided
base_model_raw = getattr(args, "base_model", None)
warm_start_from_raw = getattr(args, "warm_start_from", None)
# Treat empty/whitespace strings as not provided
base_model = base_model_raw.strip() if isinstance(base_model_raw, str) else base_model_raw
warm_start_from = warm_start_from_raw.strip() if isinstance(warm_start_from_raw, str) else warm_start_from_raw
has_base_model = bool(base_model)
has_warm_start = bool(warm_start_from)
if (not has_base_model and not has_warm_start) or (has_base_model and has_warm_start):
print("Error: exactly one of --base-model or --warm-start-from must be specified.")
return 1

training_config: Dict[str, Any] = {"baseModel": args.base_model}
if getattr(args, "warm_start_from", None):
training_config["warmStartFrom"] = args.warm_start_from
training_config: Dict[str, Any] = {}
if has_base_model:
training_config["baseModel"] = base_model
if has_warm_start:
training_config["warmStartFrom"] = warm_start_from

# Optional hyperparameters
for key, arg_name in [
Expand Down
Loading