Skip to content

Commit 2194d1d

Browse files
author
Tonny@Home
committed
fix: static prediction only can load previous model from source_record_id and save model for each prediction
1 parent e6c4a27 commit 2194d1d

2 files changed

Lines changed: 45 additions & 11 deletions

File tree

quantpits/scripts/static_train.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ def run_predict_only(args, targets):
367367
predict_single_model,
368368
print_model_table,
369369
make_model_key,
370+
resolve_model_key,
370371
PREDICTION_OUTPUT_DIR,
371372
RECORD_OUTPUT_FILE,
372373
)
@@ -388,8 +389,13 @@ def run_predict_only(args, targets):
388389

389390
# 检查哪些模型在源记录中存在
390391
source_models = source_records.get('models', {})
391-
available = {k: v for k, v in targets.items() if k in source_models}
392-
missing = {k: v for k, v in targets.items() if k not in source_models}
392+
available = {}
393+
missing = {}
394+
for k, v in targets.items():
395+
if resolve_model_key(k, source_models, default_mode='static'):
396+
available[k] = v
397+
else:
398+
missing[k] = v
393399

394400
if missing:
395401
print(f"\n⚠️ 以下模型不在源训练记录中,将跳过:")
@@ -466,15 +472,17 @@ def run_predict_only(args, targets):
466472
print("📊 Predict-Only 完成")
467473
print("=" * 60)
468474

469-
succeeded = [m for m in new_records['models']]
475+
succeeded = [m for m in available if m in new_performances]
470476
print(f" ✅ 成功: {len(succeeded)} 个模型")
471477
for name in succeeded:
472478
perf = new_performances.get(name, {})
473479
ic = perf.get('IC_Mean', 'N/A')
474480
icir = perf.get('ICIR', 'N/A')
475481
ic_str = f"{ic:.4f}" if isinstance(ic, float) else ic
476482
icir_str = f"{icir:.4f}" if isinstance(icir, float) else icir
477-
print(f" {name}: IC={ic_str}, ICIR={icir_str}")
483+
484+
model_key = make_model_key(name, 'static')
485+
print(f" {model_key}: IC={ic_str}, ICIR={icir_str}")
478486

479487
if failed_models:
480488
print(f" ❌ 失败: {len(failed_models)} 个模型")

quantpits/utils/train_utils.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,12 +1116,14 @@ def predict_single_model(model_name, model_info, params, experiment_name,
11161116

11171117
# 检查模型是否存在于源记录中
11181118
source_models = source_records.get('models', {})
1119-
if model_name not in source_models:
1119+
resolved_key = resolve_model_key(model_name, source_models, default_mode='static')
1120+
1121+
if not resolved_key:
11201122
result['error'] = f"模型 '{model_name}' 不在源训练记录中,无法加载已有模型"
11211123
print(f"!!! Error: {result['error']}")
11221124
return result
11231125

1124-
source_record_id = source_models[model_name]
1126+
source_record_id = source_models[resolved_key]
11251127
source_experiment = source_records.get('experiment_name', 'Weekly_Production_Train')
11261128

11271129
from qlib.utils import init_instance_by_config
@@ -1134,11 +1136,31 @@ def predict_single_model(model_name, model_info, params, experiment_name,
11341136
try:
11351137
# 1. 从源 recorder 加载模型
11361138
print(f"[{model_name}] Loading model from source recorder...")
1137-
source_recorder = R.get_recorder(
1138-
recorder_id=source_record_id,
1139-
experiment_name=source_experiment
1140-
)
1141-
model = source_recorder.load_object("model.pkl")
1139+
1140+
# 稳健加载:如果在当前 recorder 里没找到 model.pkl,则根据 source_record_id tag 向上溯源
1141+
current_id = source_record_id
1142+
current_exp = source_experiment
1143+
model = None
1144+
for _ in range(10):
1145+
source_recorder = R.get_recorder(
1146+
recorder_id=current_id,
1147+
experiment_name=current_exp
1148+
)
1149+
try:
1150+
model = source_recorder.load_object("model.pkl")
1151+
break
1152+
except Exception:
1153+
tags = source_recorder.list_tags()
1154+
if 'source_record_id' in tags and 'source_experiment' in tags:
1155+
print(f" [Fallback] model.pkl 不在 {current_id} 中,正在向上溯源到 {tags['source_record_id']}...")
1156+
current_id = tags['source_record_id']
1157+
current_exp = tags['source_experiment']
1158+
else:
1159+
raise ValueError(f"model.pkl not found in {current_id} and no parent tags available.")
1160+
1161+
if model is None:
1162+
raise ValueError(f"Exceeded max traceback depth of 10 for {model_name}.")
1163+
11421164
print(f"[{model_name}] Model loaded successfully")
11431165

11441166
# 2. 构建新的 dataset(使用新日期范围)
@@ -1176,6 +1198,10 @@ def predict_single_model(model_name, model_info, params, experiment_name,
11761198
r_obj = init_instance_by_config(r_cfg, recorder=recorder)
11771199
r_obj.generate()
11781200

1201+
# 重点修复:必须把模型也存入新的 recorder 里面,
1202+
# 否则下一个周期如果继续做仅预测,会因为上个仅预测的记录中没有 model.pkl 而失败
1203+
recorder.save_objects(**{"model.pkl": model})
1204+
11791205
# 获取 IC 指标
11801206
performance = {}
11811207
try:

0 commit comments

Comments
 (0)