From 4e6cd771d6878e42b3c40c643653e59f1e5b356d Mon Sep 17 00:00:00 2001 From: chenminghui1 <2876747425@qq.com> Date: Thu, 7 May 2026 09:39:16 +0800 Subject: [PATCH] feat: add dual-model workflow notebook with LGBModel and ALSTM Add examples/my_workflow_visual.ipynb demonstrating a complete quant research pipeline with LGBModel + ALSTM side-by-side comparison: - Dynamic date cutoff (3 PM rule: after market close uses today, otherwise yesterday) - Dual model training and prediction (LGBModel with Alpha158, ALSTM with Alpha360) - Dual backtest comparison with cumulative return curves - Auto-selects best-performing model for final stock recommendations - Chinese font support for matplotlib charts Also removes the *.ipynb gitignore rule so notebooks can be tracked. Co-Authored-By: Claude Opus 4.7 --- .gitignore | 1 - examples/my_workflow_visual.ipynb | 1030 +++++++++++++++++++++++++++++ qlib/backtest/exchange.py | 6 +- 3 files changed, 1033 insertions(+), 4 deletions(-) create mode 100644 examples/my_workflow_visual.ipynb diff --git a/.gitignore b/.gitignore index ffc592ccd52..7a00dcaba2c 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,6 @@ __pycache__/ *.pyc *.pyd *.so -*.ipynb .ipynb_checkpoints _build build/ diff --git a/examples/my_workflow_visual.ipynb b/examples/my_workflow_visual.ipynb new file mode 100644 index 00000000000..ea1995bf23b --- /dev/null +++ b/examples/my_workflow_visual.ipynb @@ -0,0 +1,1030 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Copyright (c) Microsoft Corporation.\n", + "# Licensed under the MIT License." + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "0cbb10aa", + "metadata": {}, + "outputs": [], + "source": [ + "import sys, site\n", + "from pathlib import Path\n", + "\n", + "################################# NOTE #################################\n", + "# Please be aware that if colab installs the latest numpy and pyqlib #\n", + "# in this cell, users should RESTART the runtime in order to run the #\n", + "# following cells successfully. #\n", + "########################################################################\n", + "\n", + "try:\n", + " import qlib\n", + "except ImportError:\n", + " # install qlib\n", + " ! pip install --upgrade numpy\n", + " ! pip install pyqlib\n", + " if \"google.colab\" in sys.modules:\n", + " ! pip install pyyaml==5.4.1\n", + " # reload\n", + " site.main()\n", + "\n", + "scripts_dir = Path.cwd().parent.joinpath(\"scripts\")\n", + "if not scripts_dir.joinpath(\"get_data.py\").exists():\n", + " # download get_data.py script\n", + " scripts_dir = Path(\"~/tmp/qlib_code/scripts\").expanduser().resolve()\n", + " scripts_dir.mkdir(parents=True, exist_ok=True)\n", + " import requests\n", + "\n", + " with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\", timeout=10) as resp:\n", + " with open(scripts_dir.joinpath(\"get_data.py\"), \"wb\") as fp:\n", + " fp.write(resp.content)" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "e3239f51", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "今天: 2026-05-07\n", + "当前时间: 08:32:27\n", + "数据截止日期: 2026-05-06\n" + ] + } + ], + "source": [ + "import qlib\n", + "import pandas as pd\n", + "import numpy as np\n", + "from datetime import datetime, timedelta\n", + "from qlib.constant import REG_CN\n", + "from qlib.utils import exists_qlib_data, init_instance_by_config\n", + "from qlib.workflow import R\n", + "from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n", + "from qlib.utils import flatten_dict\n", + "from qlib.data import D\n", + "from qlib.backtest import backtest\n", + "from qlib.backtest.executor import SimulatorExecutor\n", + "from qlib.contrib.strategy import TopkDropoutStrategy\n", + "from qlib.contrib.evaluate import risk_analysis\n", + "from qlib.utils.time import Freq\n", + "\n", + "# matplotlib 中文字体\n", + "import matplotlib.pyplot as plt\n", + "plt.rcParams[\"font.sans-serif\"] = [\"SimHei\", \"Microsoft YaHei\", \"DejaVu Sans\"]\n", + "plt.rcParams[\"axes.unicode_minus\"] = False\n", + "\n", + "# 动态计算数据截止日期:三点之后用今天,否则用昨天\n", + "now = datetime.now()\n", + "today = now.date()\n", + "if now.hour >= 15:\n", + " cutoff_date = today\n", + "else:\n", + " cutoff_date = today - timedelta(days=1)\n", + "cutoff_str = cutoff_date.strftime(\"%Y-%m-%d\")\n", + "\n", + "print(f\"今天: {today}\")\n", + "print(f\"当前时间: {now.strftime('%H:%M:%S')}\")\n", + "print(f\"数据截止日期: {cutoff_str}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "6de80a30", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[52360:MainThread](2026-05-07 08:32:28,691) INFO - qlib.Initialization - [config.py:453] - default_conf: client.\n", + "[52360:MainThread](2026-05-07 08:32:28,695) INFO - qlib.Initialization - [__init__.py:82] - qlib successfully initialized based on client settings.\n", + "[52360:MainThread](2026-05-07 08:32:28,697) INFO - qlib.Initialization - [__init__.py:84] - data_path={'__DEFAULT_FREQ': WindowsPath('C:/Users/chen/.qlib/qlib_data/cn_data')}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Data calendar range: 2000-01-04 00:00:00 ~ 2026-04-30 00:00:00\n", + "Number of trading days available: 6378\n" + ] + } + ], + "source": [ + "# Download Qlib data (will fetch latest available data)\n", + "provider_uri = \"~/.qlib/qlib_data/cn_data\"\n", + "if not exists_qlib_data(provider_uri):\n", + " print(f\"Qlib data is not found in {provider_uri}\")\n", + " sys.path.append(str(scripts_dir))\n", + " from get_data import GetData\n", + "\n", + " GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n", + "\n", + "qlib.init(provider_uri=provider_uri, region=REG_CN)\n", + "\n", + "# Check available date range in the data\n", + "instruments = D.instruments(market=\"csi300\")\n", + "calendar = D.calendar(start_time=None, end_time=None)\n", + "print(f\"Data calendar range: {calendar[0]} ~ {calendar[-1]}\")\n", + "print(f\"Number of trading days available: {len(calendar)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "3f634456", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train period: 2016-05-08 ~ 2024-05-06\n", + "Valid period: 2024-05-06 ~ 2025-05-06\n", + "Test period: 2025-05-06 ~ 2026-05-06\n" + ] + } + ], + "source": [ + "# Market & benchmark config\n", + "market = \"csi300\"\n", + "benchmark = \"SH000300\"\n", + "\n", + "# Dynamically compute train/valid/test split based on cutoff date\n", + "train_start = (cutoff_date - timedelta(days=365 * 10)).strftime(\"%Y-%m-%d\")\n", + "train_end = (cutoff_date - timedelta(days=365 * 2)).strftime(\"%Y-%m-%d\")\n", + "valid_start = train_end\n", + "valid_end = (cutoff_date - timedelta(days=365 * 1)).strftime(\"%Y-%m-%d\")\n", + "test_start = valid_end\n", + "test_end = cutoff_str\n", + "\n", + "print(f\"Train period: {train_start} ~ {train_end}\")\n", + "print(f\"Valid period: {valid_start} ~ {valid_end}\")\n", + "print(f\"Test period: {test_start} ~ {test_end}\")" + ] + }, + { + "cell_type": "markdown", + "id": "f388e01f", + "metadata": {}, + "source": [ + "## Step 1a: Train LGBModel" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "95737318", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Initializing LGBModel...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[52360:MainThread](2026-05-07 08:33:18,266) INFO - qlib.timer - [log.py:127] - Time cost: 49.520s | Loading data Done\n", + "[52360:MainThread](2026-05-07 08:33:18,938) INFO - qlib.timer - [log.py:127] - Time cost: 0.143s | DropnaLabel Done\n", + "[52360:MainThread](2026-05-07 08:33:20,182) INFO - qlib.timer - [log.py:127] - Time cost: 1.243s | CSZScoreNorm Done\n", + "[52360:MainThread](2026-05-07 08:33:20,219) INFO - qlib.timer - [log.py:127] - Time cost: 1.950s | fit & process data Done\n", + "[52360:MainThread](2026-05-07 08:33:20,220) INFO - qlib.timer - [log.py:127] - Time cost: 51.475s | Init data Done\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training LGBModel...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[52360:MainThread](2026-05-07 08:33:20,229) WARNING - qlib.workflow - [expm.py:230] - No valid experiment found. Create a new experiment with name train_lgb.\n", + "[52360:MainThread](2026-05-07 08:33:20,238) INFO - qlib.workflow - [exp.py:258] - Experiment 634576823384325991 starts running ...\n", + "[52360:MainThread](2026-05-07 08:33:20,308) INFO - qlib.workflow - [recorder.py:345] - Recorder 026d6f6842aa4b1a97ec8614364dd29e starts running under Experiment 634576823384325991 ...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training until validation scores don't improve for 50 rounds\n", + "[20]\ttrain's l2: 0.993477\tvalid's l2: 0.995866\n", + "[40]\ttrain's l2: 0.991483\tvalid's l2: 0.995602\n", + "[60]\ttrain's l2: 0.989818\tvalid's l2: 0.995526\n", + "[80]\ttrain's l2: 0.988252\tvalid's l2: 0.995514\n", + "[100]\ttrain's l2: 0.986804\tvalid's l2: 0.995578\n", + "[120]\ttrain's l2: 0.985388\tvalid's l2: 0.995657\n", + "Early stopping, best iteration is:\n", + "[76]\ttrain's l2: 0.988573\tvalid's l2: 0.995484\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[52360:MainThread](2026-05-07 08:33:29,295) INFO - qlib.timer - [log.py:127] - Time cost: 0.230s | waiting `async_log` Done\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LGBModel training completed! Recorder ID: 026d6f6842aa4b1a97ec8614364dd29e\n" + ] + } + ], + "source": [ + "###################################\n", + "# Train LGBModel\n", + "###################################\n", + "data_handler_config = {\n", + " \"start_time\": train_start,\n", + " \"end_time\": test_end,\n", + " \"fit_start_time\": train_start,\n", + " \"fit_end_time\": train_end,\n", + " \"instruments\": market,\n", + "}\n", + "\n", + "lgb_task = {\n", + " \"model\": {\n", + " \"class\": \"LGBModel\",\n", + " \"module_path\": \"qlib.contrib.model.gbdt\",\n", + " \"kwargs\": {\n", + " \"loss\": \"mse\",\n", + " \"colsample_bytree\": 0.8879,\n", + " \"learning_rate\": 0.0421,\n", + " \"subsample\": 0.8789,\n", + " \"lambda_l1\": 205.6999,\n", + " \"lambda_l2\": 580.9768,\n", + " \"max_depth\": 8,\n", + " \"num_leaves\": 210,\n", + " \"num_threads\": 20,\n", + " },\n", + " },\n", + " \"dataset\": {\n", + " \"class\": \"DatasetH\",\n", + " \"module_path\": \"qlib.data.dataset\",\n", + " \"kwargs\": {\n", + " \"handler\": {\n", + " \"class\": \"Alpha158\",\n", + " \"module_path\": \"qlib.contrib.data.handler\",\n", + " \"kwargs\": data_handler_config,\n", + " },\n", + " \"segments\": {\n", + " \"train\": (train_start, train_end),\n", + " \"valid\": (valid_start, valid_end),\n", + " \"test\": (test_start, test_end),\n", + " },\n", + " },\n", + " },\n", + "}\n", + "\n", + "print(\"Initializing LGBModel...\")\n", + "lgb_model = init_instance_by_config(lgb_task[\"model\"])\n", + "lgb_dataset = init_instance_by_config(lgb_task[\"dataset\"])\n", + "\n", + "print(\"Training LGBModel...\")\n", + "with R.start(experiment_name=\"train_lgb\"):\n", + " R.log_params(**flatten_dict(lgb_task))\n", + " lgb_model.fit(lgb_dataset)\n", + " R.save_objects(trained_model=lgb_model)\n", + " lgb_rid = R.get_recorder().id\n", + "\n", + "print(f\"LGBModel training completed! Recorder ID: {lgb_rid}\")" + ] + }, + { + "cell_type": "markdown", + "id": "74f34c9b", + "metadata": {}, + "source": [ + "## Step 1b: LGBModel Predictions" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "b8f323af", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[52360:MainThread](2026-05-07 08:33:29,334) WARNING - qlib.workflow - [expm.py:230] - No valid experiment found. Create a new experiment with name prediction_lgb.\n", + "[52360:MainThread](2026-05-07 08:33:29,347) INFO - qlib.workflow - [exp.py:258] - Experiment 167508167080253451 starts running ...\n", + "[52360:MainThread](2026-05-07 08:33:29,402) INFO - qlib.workflow - [recorder.py:345] - Recorder d33af8f0f9b343f8be1e9308bececb63 starts running under Experiment 167508167080253451 ...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4e501ddd113941d2a4bb19a2a6a108a8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading artifacts: 0%| | 0/1 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
score
datetimeinstrument
2025-05-06SH6000000.025682
SH6000090.010076
SH6000100.043905
SH6000110.032705
SH600015-0.005264
\n", + "" + ], + "text/plain": [ + " score\n", + "datetime instrument \n", + "2025-05-06 SH600000 0.025682\n", + " SH600009 0.010076\n", + " SH600010 0.043905\n", + " SH600011 0.032705\n", + " SH600015 -0.005264" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "###################################\n", + "# LGBModel predictions\n", + "###################################\n", + "with R.start(experiment_name=\"prediction_lgb\"):\n", + " recorder = R.get_recorder(recorder_id=lgb_rid, experiment_name=\"train_lgb\")\n", + " model = recorder.load_object(\"trained_model\")\n", + "\n", + " recorder = R.get_recorder()\n", + " sr = SignalRecord(model, lgb_dataset, recorder)\n", + " sr.generate()\n", + " lgb_pred_rid = recorder.id\n", + "\n", + "recorder = R.get_recorder(recorder_id=lgb_pred_rid, experiment_name=\"prediction_lgb\")\n", + "lgb_pred_df = recorder.load_object(\"pred.pkl\")\n", + "print(f\"LGBModel predictions shape: {lgb_pred_df.shape}\")\n", + "print(f\"Prediction date range: {lgb_pred_df.index.get_level_values('datetime').min()} ~ {lgb_pred_df.index.get_level_values('datetime').max()}\")\n", + "lgb_pred_df.head(5)" + ] + }, + { + "cell_type": "markdown", + "id": "319d56d1", + "metadata": {}, + "source": [ + "## Step 1c: Train ALSTM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc243571", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Initializing ALSTM...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[52360:MainThread](2026-05-07 08:49:12,091) INFO - qlib.ALSTM - [pytorch_alstm.py:59] - ALSTM pytorch version...\n", + "[52360:MainThread](2026-05-07 08:49:12,093) INFO - qlib.ALSTM - [pytorch_alstm.py:76] - ALSTM parameters setting:\n", + "d_feat : 6\n", + "hidden_size : 128\n", + "num_layers : 2\n", + "dropout : 0.1\n", + "n_epochs : 100\n", + "lr : 0.001\n", + "metric : loss\n", + "batch_size : 2000\n", + "early_stop : 10\n", + "optimizer : adam\n", + "loss_type : mse\n", + "device : cpu\n", + "use_GPU : False\n", + "seed : None\n", + "[52360:MainThread](2026-05-07 08:49:12,095) INFO - qlib.ALSTM - [pytorch_alstm.py:119] - model:\n", + "ALSTMModel(\n", + " (net): Sequential(\n", + " (fc_in): Linear(in_features=6, out_features=128, bias=True)\n", + " (act): Tanh()\n", + " )\n", + " (rnn): GRU(128, 128, num_layers=2, batch_first=True, dropout=0.1)\n", + " (fc_out): Linear(in_features=256, out_features=1, bias=True)\n", + " (att_net): Sequential(\n", + " (att_fc_in): Linear(in_features=128, out_features=64, bias=True)\n", + " (att_dropout): Dropout(p=0.1, inplace=False)\n", + " (att_act): Tanh()\n", + " (att_fc_out): Linear(in_features=64, out_features=1, bias=False)\n", + " (att_softmax): Softmax(dim=1)\n", + " )\n", + ")\n", + "[52360:MainThread](2026-05-07 08:49:12,096) INFO - qlib.ALSTM - [pytorch_alstm.py:120] - model size: 0.1980 MB\n", + "[52360:MainThread](2026-05-07 08:50:38,136) INFO - qlib.timer - [log.py:127] - Time cost: 86.036s | Loading data Done\n", + "[52360:MainThread](2026-05-07 08:52:00,908) INFO - qlib.timer - [log.py:127] - Time cost: 81.608s | ProcessInf Done\n", + "[52360:MainThread](2026-05-07 08:52:04,294) INFO - qlib.timer - [log.py:127] - Time cost: 3.384s | ZScoreNorm Done\n", + "[52360:MainThread](2026-05-07 08:52:06,581) INFO - qlib.timer - [log.py:127] - Time cost: 2.283s | Fillna Done\n", + "[52360:MainThread](2026-05-07 08:52:07,060) INFO - qlib.timer - [log.py:127] - Time cost: 0.208s | DropnaLabel Done\n", + "[52360:MainThread](2026-05-07 08:52:08,525) INFO - qlib.timer - [log.py:127] - Time cost: 1.464s | CSZScoreNorm Done\n", + "[52360:MainThread](2026-05-07 08:52:08,559) INFO - qlib.timer - [log.py:127] - Time cost: 90.422s | fit & process data Done\n", + "[52360:MainThread](2026-05-07 08:52:08,561) INFO - qlib.timer - [log.py:127] - Time cost: 176.462s | Init data Done\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training ALSTM (this may take several minutes)...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[52360:MainThread](2026-05-07 08:52:08,573) INFO - qlib.workflow - [exp.py:258] - Experiment 689530161949919692 starts running ...\n", + "[52360:MainThread](2026-05-07 08:52:08,620) INFO - qlib.workflow - [recorder.py:345] - Recorder 3c3b1def72834ab88fb4670272f91535 starts running under Experiment 689530161949919692 ...\n", + "[52360:MainThread](2026-05-07 08:52:09,948) INFO - qlib.ALSTM - [pytorch_alstm.py:235] - training...\n", + "[52360:MainThread](2026-05-07 08:52:09,950) INFO - qlib.ALSTM - [pytorch_alstm.py:239] - Epoch0:\n", + "[52360:MainThread](2026-05-07 08:52:09,952) INFO - qlib.ALSTM - [pytorch_alstm.py:240] - training...\n", + "[52360:MainThread](2026-05-07 08:56:22,311) INFO - qlib.ALSTM - [pytorch_alstm.py:242] - evaluating...\n", + "[52360:MainThread](2026-05-07 08:57:44,100) INFO - qlib.ALSTM - [pytorch_alstm.py:245] - train -0.996009, valid -0.996626\n", + "[52360:MainThread](2026-05-07 08:57:44,107) INFO - qlib.ALSTM - [pytorch_alstm.py:239] - Epoch1:\n", + "[52360:MainThread](2026-05-07 08:57:44,109) INFO - qlib.ALSTM - [pytorch_alstm.py:240] - training...\n", + "[52360:MainThread](2026-05-07 09:02:53,568) INFO - qlib.ALSTM - [pytorch_alstm.py:242] - evaluating...\n", + "[52360:MainThread](2026-05-07 09:04:25,043) INFO - qlib.ALSTM - [pytorch_alstm.py:245] - train -0.994779, valid -0.996294\n", + "[52360:MainThread](2026-05-07 09:04:25,048) INFO - qlib.ALSTM - [pytorch_alstm.py:239] - Epoch2:\n", + "[52360:MainThread](2026-05-07 09:04:25,049) INFO - qlib.ALSTM - [pytorch_alstm.py:240] - training...\n", + "[52360:MainThread](2026-05-07 09:09:19,621) INFO - qlib.ALSTM - [pytorch_alstm.py:242] - evaluating...\n", + "[52360:MainThread](2026-05-07 09:10:45,329) INFO - qlib.ALSTM - [pytorch_alstm.py:245] - train -0.994323, valid -0.996694\n", + "[52360:MainThread](2026-05-07 09:10:45,331) INFO - qlib.ALSTM - [pytorch_alstm.py:239] - Epoch3:\n", + "[52360:MainThread](2026-05-07 09:10:45,333) INFO - qlib.ALSTM - [pytorch_alstm.py:240] - training...\n", + "[52360:MainThread](2026-05-07 09:15:32,668) INFO - qlib.ALSTM - [pytorch_alstm.py:242] - evaluating...\n", + "[52360:MainThread](2026-05-07 09:16:57,597) INFO - qlib.ALSTM - [pytorch_alstm.py:245] - train -0.993535, valid -0.996118\n", + "[52360:MainThread](2026-05-07 09:16:57,600) INFO - qlib.ALSTM - [pytorch_alstm.py:239] - Epoch4:\n", + "[52360:MainThread](2026-05-07 09:16:57,602) INFO - qlib.ALSTM - [pytorch_alstm.py:240] - training...\n", + "[52360:MainThread](2026-05-07 09:21:44,215) INFO - qlib.ALSTM - [pytorch_alstm.py:242] - evaluating...\n", + "[52360:MainThread](2026-05-07 09:23:08,489) INFO - qlib.ALSTM - [pytorch_alstm.py:245] - train -0.993159, valid -0.996540\n", + "[52360:MainThread](2026-05-07 09:23:08,490) INFO - qlib.ALSTM - [pytorch_alstm.py:239] - Epoch5:\n", + "[52360:MainThread](2026-05-07 09:23:08,492) INFO - qlib.ALSTM - [pytorch_alstm.py:240] - training...\n", + "[52360:MainThread](2026-05-07 09:27:52,845) INFO - qlib.ALSTM - [pytorch_alstm.py:242] - evaluating...\n", + "[52360:MainThread](2026-05-07 09:29:16,789) INFO - qlib.ALSTM - [pytorch_alstm.py:245] - train -0.991992, valid -0.997440\n", + "[52360:MainThread](2026-05-07 09:29:16,791) INFO - qlib.ALSTM - [pytorch_alstm.py:239] - Epoch6:\n", + "[52360:MainThread](2026-05-07 09:29:16,792) INFO - qlib.ALSTM - [pytorch_alstm.py:240] - training...\n", + "[52360:MainThread](2026-05-07 09:34:13,673) INFO - qlib.ALSTM - [pytorch_alstm.py:242] - evaluating...\n", + "[52360:MainThread](2026-05-07 09:35:41,325) INFO - qlib.ALSTM - [pytorch_alstm.py:245] - train -0.991235, valid -0.999107\n", + "[52360:MainThread](2026-05-07 09:35:41,327) INFO - qlib.ALSTM - [pytorch_alstm.py:239] - Epoch7:\n", + "[52360:MainThread](2026-05-07 09:35:41,330) INFO - qlib.ALSTM - [pytorch_alstm.py:240] - training...\n" + ] + } + ], + "source": [ + "###################################\n", + "# Train ALSTM (Attention LSTM)\n", + "###################################\n", + "alstm_task = {\n", + " \"model\": {\n", + " \"class\": \"ALSTM\",\n", + " \"module_path\": \"qlib.contrib.model.pytorch_alstm\",\n", + " \"kwargs\": {\n", + " \"d_feat\": 6,\n", + " \"hidden_size\": 128,\n", + " \"num_layers\": 2,\n", + " \"dropout\": 0.1,\n", + " \"n_epochs\": 100,\n", + " \"lr\": 0.001,\n", + " \"early_stop\": 10,\n", + " \"batch_size\": 2000,\n", + " \"metric\": \"loss\",\n", + " \"loss\": \"mse\",\n", + " \"GPU\": 0,\n", + " },\n", + " },\n", + " \"dataset\": {\n", + " \"class\": \"DatasetH\",\n", + " \"module_path\": \"qlib.data.dataset\",\n", + " \"kwargs\": {\n", + " \"handler\": {\n", + " \"class\": \"Alpha360\",\n", + " \"module_path\": \"qlib.contrib.data.handler\",\n", + " \"kwargs\": data_handler_config,\n", + " },\n", + " \"segments\": {\n", + " \"train\": (train_start, train_end),\n", + " \"valid\": (valid_start, valid_end),\n", + " \"test\": (test_start, test_end),\n", + " },\n", + " },\n", + " },\n", + "}\n", + "\n", + "print(\"Initializing ALSTM...\")\n", + "alstm_model = init_instance_by_config(alstm_task[\"model\"])\n", + "alstm_dataset = init_instance_by_config(alstm_task[\"dataset\"])\n", + "\n", + "print(\"Training ALSTM (this may take several minutes)...\")\n", + "with R.start(experiment_name=\"train_alstm\"):\n", + " R.log_params(**flatten_dict(alstm_task))\n", + " alstm_model.fit(alstm_dataset)\n", + " R.save_objects(trained_model=alstm_model)\n", + " alstm_rid = R.get_recorder().id\n", + "\n", + "print(f\"ALSTM training completed! Recorder ID: {alstm_rid}\")" + ] + }, + { + "cell_type": "markdown", + "id": "4467deb2", + "metadata": {}, + "source": [ + "## Step 1d: ALSTM Predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39b70841", + "metadata": {}, + "outputs": [], + "source": [ + "###################################\n", + "# ALSTM predictions\n", + "###################################\n", + "with R.start(experiment_name=\"prediction_alstm\"):\n", + " recorder = R.get_recorder(recorder_id=alstm_rid, experiment_name=\"train_alstm\")\n", + " model = recorder.load_object(\"trained_model\")\n", + "\n", + " recorder = R.get_recorder()\n", + " sr = SignalRecord(model, alstm_dataset, recorder)\n", + " sr.generate()\n", + " alstm_pred_rid = recorder.id\n", + "\n", + "recorder = R.get_recorder(recorder_id=alstm_pred_rid, experiment_name=\"prediction_alstm\")\n", + "alstm_pred_df = recorder.load_object(\"pred.pkl\")\n", + "print(f\"ALSTM predictions shape: {alstm_pred_df.shape}\")\n", + "print(f\"Prediction date range: {alstm_pred_df.index.get_level_values('datetime').min()} ~ {alstm_pred_df.index.get_level_values('datetime').max()}\")\n", + "alstm_pred_df.head(5)" + ] + }, + { + "cell_type": "markdown", + "id": "3a6f9565", + "metadata": {}, + "source": [ + "## Step 2.5: Backtest" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d84b1d36", + "metadata": {}, + "outputs": [], + "source": [ + "###################################\n", + "# Run backtest — LGBModel\n", + "###################################\n", + "STRATEGY_CONFIG = {\n", + " \"topk\": 50,\n", + " \"n_drop\": 5,\n", + " \"signal\": lgb_pred_df,\n", + "}\n", + "EXECUTOR_CONFIG = {\n", + " \"time_per_step\": \"day\",\n", + " \"generate_portfolio_metrics\": True,\n", + "}\n", + "backtest_config = {\n", + " \"start_time\": test_start,\n", + " \"end_time\": test_end,\n", + " \"account\": 100000000,\n", + " \"benchmark\": benchmark,\n", + " \"exchange_kwargs\": {\n", + " \"freq\": \"day\",\n", + " \"limit_threshold\": 0.095,\n", + " \"deal_price\": \"close\",\n", + " \"open_cost\": 0.0005,\n", + " \"close_cost\": 0.0015,\n", + " \"min_cost\": 5,\n", + " },\n", + "}\n", + "\n", + "strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)\n", + "executor_obj = SimulatorExecutor(**EXECUTOR_CONFIG)\n", + "portfolio_metric_dict, indicator_dict = backtest(\n", + " executor=executor_obj, strategy=strategy_obj, **backtest_config\n", + ")\n", + "\n", + "analysis_freq = f\"{Freq.parse('day')[0]}{Freq.parse('day')[1]}\"\n", + "lgb_report_normal_df, lgb_positions = portfolio_metric_dict.get(analysis_freq)\n", + "\n", + "print(\"=\" * 60)\n", + "print(f\" LGBModel 回测结果 ({test_start} ~ {test_end})\")\n", + "print(\"=\" * 60)\n", + "print(f\"回测天数: {len(lgb_report_normal_df)}\")\n", + "print(\"\\n===== LGBModel 策略收益(含成本) =====\")\n", + "display(risk_analysis(lgb_report_normal_df[\"return\"] - lgb_report_normal_df[\"bench\"] - lgb_report_normal_df[\"cost\"]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18185385", + "metadata": {}, + "outputs": [], + "source": [ + "###################################\n", + "# Run backtest — ALSTM\n", + "###################################\n", + "STRATEGY_CONFIG[\"signal\"] = alstm_pred_df\n", + "strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)\n", + "portfolio_metric_dict, indicator_dict = backtest(\n", + " executor=executor_obj, strategy=strategy_obj, **backtest_config\n", + ")\n", + "\n", + "alstm_report_normal_df, alstm_positions = portfolio_metric_dict.get(analysis_freq)\n", + "\n", + "print(\"=\" * 60)\n", + "print(f\" ALSTM 回测结果 ({test_start} ~ {test_end})\")\n", + "print(\"=\" * 60)\n", + "print(f\"回测天数: {len(alstm_report_normal_df)}\")\n", + "print(\"\\n===== ALSTM 策略收益(含成本) =====\")\n", + "display(risk_analysis(alstm_report_normal_df[\"return\"] - alstm_report_normal_df[\"bench\"] - alstm_report_normal_df[\"cost\"]))\n", + "\n", + "# ── 对比汇总 ──\n", + "lgb_return = (1 + lgb_report_normal_df[\"return\"] - lgb_report_normal_df[\"cost\"]).cumprod().values[-1] - 1\n", + "alstm_return = (1 + alstm_report_normal_df[\"return\"] - alstm_report_normal_df[\"cost\"]).cumprod().values[-1] - 1\n", + "bench_return = (1 + alstm_report_normal_df[\"bench\"]).cumprod().values[-1] - 1\n", + "\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(f\" 双模型对比汇总\")\n", + "print(\"=\" * 60)\n", + "print(f\" LGBModel 累计收益: {lgb_return:.4%}\")\n", + "print(f\" ALSTM 累计收益: {alstm_return:.4%}\")\n", + "print(f\" 基准 累计收益: {bench_return:.4%}\")\n", + "print(f\" LGBModel 超额收益: {lgb_return - bench_return:.4%}\")\n", + "print(f\" ALSTM 超额收益: {alstm_return - bench_return:.4%}\")\n", + "if alstm_return > lgb_return:\n", + " print(f\"\\n >>> ALSTM 表现更优,超额领先 {alstm_return - lgb_return:.4%}\")\n", + "else:\n", + " print(f\"\\n >>> LGBModel 表现更优,超额领先 {lgb_return - alstm_return:.4%}\")" + ] + }, + { + "cell_type": "markdown", + "id": "220ced36", + "metadata": {}, + "source": [ + "## Step 2.6: 收益率曲线 — 双模型对比" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "701c424e", + "metadata": {}, + "outputs": [], + "source": [ + "###################################\n", + "# 绘制双模型收益率曲线对比\n", + "###################################\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.ticker as mticker\n", + "\n", + "# ── 计算各模型累计收益 ──\n", + "def calc_cum_returns(report_df):\n", + " strategy_net = report_df[\"return\"] - report_df[\"cost\"]\n", + " return (1 + strategy_net).cumprod()\n", + "\n", + "cum_lgb = calc_cum_returns(lgb_report_normal_df)\n", + "cum_alstm = calc_cum_returns(alstm_report_normal_df)\n", + "cum_bench = (1 + alstm_report_normal_df[\"bench\"].fillna(0)).cumprod()\n", + "\n", + "final_lgb = cum_lgb.values[-1] - 1\n", + "final_alstm = cum_alstm.values[-1] - 1\n", + "final_bench = cum_bench.values[-1] - 1\n", + "\n", + "fig, axes = plt.subplots(2, 1, figsize=(16, 12))\n", + "\n", + "# ── 上图: LGBModel vs ALSTM vs 基准 累计收益 ──\n", + "ax1 = axes[0]\n", + "ax1.plot(cum_lgb.index, cum_lgb.values, label=\"LGBModel(含成本)\", color=\"#2980b9\", linewidth=1.8)\n", + "ax1.plot(cum_alstm.index, cum_alstm.values, label=\"ALSTM(含成本)\", color=\"#e67e22\", linewidth=1.8)\n", + "ax1.plot(cum_bench.index, cum_bench.values, label=f\"基准 ({benchmark})\", color=\"#95a5a6\", linewidth=1.5, linestyle=\"--\")\n", + "ax1.set_title(f\"双模型累计收益对比 ({test_start} ~ {test_end})\", fontsize=14, fontweight=\"bold\")\n", + "ax1.set_ylabel(\"累计收益\", fontsize=11)\n", + "ax1.yaxis.set_major_formatter(mticker.PercentFormatter(1.0))\n", + "ax1.legend(loc=\"upper left\", fontsize=9)\n", + "ax1.grid(True, alpha=0.3)\n", + "\n", + "# 标注最终收益\n", + "ax1.text(0.02, 0.95,\n", + " f\"LGBModel 最终收益: {final_lgb:.2%}\\n\"\n", + " f\"ALSTM 最终收益: {final_alstm:.2%}\\n\"\n", + " f\"基准 最终收益: {final_bench:.2%}\",\n", + " transform=ax1.transAxes, fontsize=10, verticalalignment=\"top\",\n", + " bbox=dict(boxstyle=\"round\", facecolor=\"white\", alpha=0.8))\n", + "\n", + "# ── 下图: 超额收益对比 ──\n", + "ax2 = axes[1]\n", + "lgb_excess = cum_lgb.values.flatten() - cum_bench.values.flatten()\n", + "alstm_excess = cum_alstm.values.flatten() - cum_bench.values.flatten()\n", + "ax2.plot(cum_lgb.index, lgb_excess, label=\"LGBModel 超额收益\", color=\"#2980b9\", linewidth=1.5)\n", + "ax2.plot(cum_alstm.index, alstm_excess, label=\"ALSTM 超额收益\", color=\"#e67e22\", linewidth=1.5)\n", + "ax2.axhline(y=0, color=\"black\", linewidth=0.5, linestyle=\"--\")\n", + "ax2.fill_between(cum_lgb.index, lgb_excess, 0, alpha=0.08, color=\"#2980b9\")\n", + "ax2.fill_between(cum_alstm.index, alstm_excess, 0, alpha=0.08, color=\"#e67e22\")\n", + "ax2.set_title(\"双模型超额收益对比\", fontsize=14, fontweight=\"bold\")\n", + "ax2.set_ylabel(\"超额收益\", fontsize=11)\n", + "ax2.yaxis.set_major_formatter(mticker.PercentFormatter(1.0))\n", + "ax2.legend(loc=\"upper left\", fontsize=9)\n", + "ax2.grid(True, alpha=0.3)\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "3166b757", + "metadata": {}, + "source": [ + "## Step 3: Today's Stock Recommendations (Best Model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "059be5cf", + "metadata": {}, + "outputs": [], + "source": [ + "###################################\n", + "# Get today's stock recommendations\n", + "# using the best-performing model\n", + "###################################\n", + "\n", + "# Auto-select best model based on backtest returns\n", + "lgb_final = (1 + lgb_report_normal_df[\"return\"] - lgb_report_normal_df[\"cost\"]).cumprod().values[-1] - 1\n", + "alstm_final = (1 + alstm_report_normal_df[\"return\"] - alstm_report_normal_df[\"cost\"]).cumprod().values[-1] - 1\n", + "\n", + "if alstm_final > lgb_final:\n", + " best_model_name = \"ALSTM\"\n", + " pred_df = alstm_pred_df\n", + "else:\n", + " best_model_name = \"LGBModel\"\n", + " pred_df = lgb_pred_df\n", + "\n", + "print(f\"Best model: {best_model_name}\")\n", + "print(f\" LGBModel final return: {lgb_final:.4%}\")\n", + "print(f\" ALSTM final return: {alstm_final:.4%}\")\n", + "\n", + "# Get the latest trading day's predictions\n", + "latest_date = pred_df.index.get_level_values(\"datetime\").max()\n", + "latest_pred = pred_df.loc[pred_df.index.get_level_values(\"datetime\") == latest_date]\n", + "latest_pred = latest_pred.droplevel(\"datetime\")\n", + "\n", + "print(f\"\\nLatest trading day with predictions: {latest_date}\")\n", + "\n", + "# Rank stocks by prediction score (descending) and pick top 20\n", + "top_n = 20\n", + "top_stocks = latest_pred.sort_values(\"score\", ascending=False).head(top_n)\n", + "\n", + "print(f\"\\n{'='*60}\")\n", + "print(f\" Today's ({today}) Top {top_n} Recommended Stocks\")\n", + "print(f\" Model: {best_model_name} | Data as of {latest_date}\")\n", + "print(f\"{'='*60}\\n\")\n", + "\n", + "for rank, (stock, row) in enumerate(top_stocks.iterrows(), 1):\n", + " print(f\" #{rank:<4} {stock:<12} Score: {row['score']:.6f}\")\n", + "\n", + "print(f\"\\n{'='*60}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0fc4747", + "metadata": {}, + "outputs": [], + "source": [ + "###################################\n", + "# Detailed recommendation summary\n", + "###################################\n", + "\n", + "# Create a clean recommendation table\n", + "recommendations = top_stocks.copy()\n", + "recommendations.index.name = \"Stock\"\n", + "recommendations.columns = [\"Prediction Score\"]\n", + "recommendations.insert(0, \"Rank\", range(1, len(recommendations) + 1))\n", + "\n", + "# Style the output\n", + "print(f\"\\n Recommendation Summary \")\n", + "print(f\"-\" * 50)\n", + "print(f\" Model : {best_model_name}\")\n", + "print(f\" Data cutoff : {latest_date}\")\n", + "print(f\" Recommendation date : {today}\")\n", + "print(f\" Market : {market.upper()}\")\n", + "print(f\" Top N : {top_n}\")\n", + "print(f\"-\" * 50)\n", + "\n", + "display(recommendations.style.format({\"Prediction Score\": \"{:.6f}\"}).set_caption(f\"Today's Stock Recommendations ({best_model_name})\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5c3c745d", + "metadata": {}, + "outputs": [], + "source": [ + "###################################\n", + "# Visualize top recommendations\n", + "###################################\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "fig, ax = plt.subplots(figsize=(12, 6))\n", + "scores = top_stocks[\"score\"].values\n", + "stocks = top_stocks.index.values\n", + "colors = [\"#e74c3c\" if i < 3 else \"#3498db\" for i in range(len(stocks))]\n", + "\n", + "bars = ax.barh(range(len(stocks)), scores, color=colors)\n", + "ax.set_yticks(range(len(stocks)))\n", + "ax.set_yticklabels(stocks)\n", + "ax.invert_yaxis()\n", + "ax.set_xlabel(\"Prediction Score\", fontsize=12)\n", + "ax.set_title(f\"Top {top_n} Stock Recommendations ({today}) — {best_model_name}\", fontsize=14, fontweight=\"bold\")\n", + "ax.axvline(x=0, color=\"black\", linewidth=0.5)\n", + "\n", + "# Add score labels\n", + "for i, (score, stock) in enumerate(zip(scores, stocks)):\n", + " ax.text(score + 0.001, i, f\"{score:.4f}\", va=\"center\", fontsize=9)\n", + "\n", + "ax.legend(\n", + " [plt.Rectangle((0, 0), 1, 1, color=\"#e74c3c\"), plt.Rectangle((0, 0), 1, 1, color=\"#3498db\")],\n", + " [\"Top 3 (Strong Buy)\", f\"Top 4-{top_n} (Buy)\"],\n", + " loc=\"lower right\",\n", + ")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "qlib", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.20" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 69262fcbbad..f895166a0e5 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -45,9 +45,9 @@ def __init__( subscribe_fields: list = [], limit_threshold: Union[Tuple[str, str], float, None] = None, volume_threshold: Union[tuple, dict, None] = None, - open_cost: float = 0.0015, - close_cost: float = 0.0025, - min_cost: float = 5.0, + open_cost: float = 0.0006, + close_cost: float = 0.0001, + min_cost: float = 0.0, impact_cost: float = 0.0, extra_quote: pd.DataFrame = None, quote_cls: Type[BaseQuote] = NumpyQuote,