diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index ee97bb8..468d419 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -6,7 +6,6 @@ name: Publish SimTradeLab Package on: release: types: [published] - # 可选:允许手动触发 workflow_dispatch: jobs: @@ -25,46 +24,9 @@ jobs: uses: snok/install-poetry@v1 with: version: latest - virtualenvs-create: true - virtualenvs-in-project: true - - - name: Install system dependencies - run: | - sudo apt-get update - - # 从源码编译安装ta-lib - wget http://prdownloads.sourceforge.net/ta-lib/ta-lib-0.4.0-src.tar.gz - tar -xzf ta-lib-0.4.0-src.tar.gz - cd ta-lib/ - ./configure --prefix=/usr - make - sudo make install - cd .. - rm -rf ta-lib ta-lib-0.4.0-src.tar.gz - - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@v4 - with: - path: .venv - key: venv-${{ runner.os }}-3.11-${{ hashFiles('**/poetry.lock') }} - - - name: Install dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root - - - name: Install project - run: poetry install --no-interaction - name: Build release distributions - run: | - poetry build - - - name: Check build artifacts - run: | - ls -la dist/ - echo "Built packages:" - find dist/ -name "*.whl" -o -name "*.tar.gz" + run: poetry build - name: Upload distributions uses: actions/upload-artifact@v4 @@ -92,13 +54,10 @@ jobs: post-publish: runs-on: ubuntu-latest - needs: - - pypi-publish + needs: pypi-publish if: github.event_name == 'release' permissions: contents: write - pull-requests: write - issues: write steps: - uses: actions/checkout@v4 @@ -108,36 +67,25 @@ jobs: with: python-version: "3.11" + - name: Install TA-Lib + run: | + wget -q http://prdownloads.sourceforge.net/ta-lib/ta-lib-0.4.0-src.tar.gz + tar -xzf ta-lib-0.4.0-src.tar.gz + cd ta-lib && ./configure --prefix=/usr && make && sudo make install + - name: Verify PyPI publication run: | - # 等待PyPI索引更新 sleep 30 - - # 获取不带v前缀的版本号 VERSION=${{ github.event.release.tag_name }} VERSION=${VERSION#v} - - # 尝试安装刚发布的包 pip install simtradelab==$VERSION + python -c 'import simtradelab; print("✅ SimTradeLab " + simtradelab.__version__ + " installed successfully")' - # 验证安装 - python -c 'import simtradelab; from simtradelab.backtest.runner import BacktestRunner; print("✅ SimTradeLab " + simtradelab.__version__ + " installed successfully")' - - - name: Install script dependencies + - name: Generate and Update Release Notes run: | pip install gitpython - - - name: Generate Release Notes - id: release_notes - run: | - # 生成Release Notes python scripts/generate_release_notes.py ${{ github.event.release.tag_name }} --output release_notes.md - # 读取生成的内容并设置为输出 - echo "RELEASE_NOTES<> $GITHUB_OUTPUT - cat release_notes.md >> $GITHUB_OUTPUT - echo "EOF" >> $GITHUB_OUTPUT - - name: Update Release with Generated Notes uses: actions/github-script@v7 with: @@ -145,37 +93,9 @@ jobs: script: | const fs = require('fs'); const releaseNotes = fs.readFileSync('release_notes.md', 'utf8'); - - // 更新Release的描述 await github.rest.repos.updateRelease({ owner: context.repo.owner, repo: context.repo.repo, release_id: context.payload.release.id, body: releaseNotes }); - - - name: Create success comment - uses: actions/github-script@v7 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - script: | - const tagName = '${{ github.event.release.tag_name }}'; - const version = tagName.replace(/^v/, ''); - - github.rest.issues.createComment({ - issue_number: context.payload.release.id, - owner: context.repo.owner, - repo: context.repo.repo, - body: `🎉 SimTradeLab ${tagName} has been successfully published to PyPI! - - 📦 **Installation:** - \`\`\`bash - pip install simtradelab==${version} - \`\`\` - - 🔗 **PyPI Link:** https://pypi.org/project/simtradelab/${version}/ - - ✅ **Verification:** Package installation verified successfully. - - 📋 **Release Notes:** Automatically generated and updated in the release description.` - }) diff --git a/README.md b/README.md index d1dee58..981cc4c 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ SimTradeLab(深测Lab) 是一个由社区独立开发的开源策略回测框架,灵感来源于 PTrade 的事件驱动架构。它具备完全自主的实现与出色的扩展能力,为策略开发者提供一个轻量级、结构清晰、模块可插拔的策略验证环境。框架无需依赖 PTrade 即可独立运行,但与其语法保持高度兼容。所有在 SimTradeLab 中编写的策略可无缝迁移至 PTrade 平台,反之亦然,两者之间的 API 可直接互通使用。详情参考:https://github.com/kay-ou/ptradeAPI 项目。 **核心特性:** -- ✅ **52个核心API** - 股票交易、数据查询、技术指标完整支持(34% PTrade API完成度) +- ✅ **46个回测/研究API** - 股票日/分钟线回测场景100%覆盖 - ⚡ **20-30倍性能提升** - 本地回测比PTrade平台快20-30倍 - 🚀 **数据常驻内存** - 单例模式,首次加载后常驻,二次运行秒级启动 - 💾 **多级智能缓存** - LRU缓存(MA/VWAP/复权/历史数据),命中率>95% @@ -59,9 +59,12 @@ pip install simtradelab[optimizer] 将数据文件放到 `data/` 目录: ``` data/ -├── price/ # 股票价格数据 -├── fundamentals/ # 基本面数据 -└── exrights/ # 除权除息数据 +├── stocks/ # 股票日线数据 +├── stocks_1m/ # 股票分钟数据(分钟回测需要) +├── valuation/ # 估值数据 +├── fundamentals/ # 财务数据 +├── exrights/ # 除权数据 +└── metadata/ # 元数据 ``` **数据获取:** 推荐使用 [SimTradeData](https://github.com/kay-ou/SimTradeData) 项目获取A股历史数据 @@ -77,7 +80,7 @@ def initialize(context): context.stocks = ['600519.SS', '000858.SZ'] def handle_data(context, data): - """每日交易逻辑""" + """交易逻辑(日线每日调用,分钟线每分钟调用)""" for stock in context.stocks: hist = get_history(20, '1d', 'close', [stock], is_dict=True) if stock not in hist: @@ -105,14 +108,18 @@ def handle_data(context, data): ```python # run_backtest.py from simtradelab.backtest.runner import BacktestRunner +from simtradelab.backtest.config import BacktestConfig -runner = BacktestRunner() -runner.run( +config = BacktestConfig( strategy_name='my_strategy', start_date='2024-01-01', end_date='2024-12-31', - initial_capital=1000000.0 + initial_capital=1000000.0, + frequency='1d' # '1d'日线回测(默认),'1m'分钟回测 ) + +runner = BacktestRunner() +runner.run(config=config) ``` 执行: @@ -164,17 +171,17 @@ stocks = api.get_index_stocks('000300.SS', date='2024-01-01') ## 📚 API文档 -已实现52个核心API(34% PTrade API完成度),涵盖股票交易、数据查询、技术指标、策略配置等核心功能。 +已实现46个回测/研究API,股票回测场景100%覆盖。 **核心API分类:** | 类别 | 完成度 | 说明 | |------|--------|------| -| 交易API | ✅ | order, order_target, order_value, order_target_value, cancel_order | +| 交易API | ✅ | order, order_target, order_value, order_target_value, cancel_order, get_positions, get_trades | | 数据查询 | ✅ | get_price, get_history, get_fundamentals, get_stock_info | | 板块信息 | ✅ | get_index_stocks, get_industry_stocks, get_stock_blocks | | 技术指标 | ✅ | get_MACD, get_KDJ, get_RSI, get_CCI | -| 策略配置 | ✅ | set_benchmark, set_commission, set_slippage, set_universe | +| 策略配置 | ✅ | set_benchmark, set_commission, set_slippage, set_universe, set_parameters, set_yesterday_position | | 生命周期 | ✅ | initialize, before_trading_start, handle_data, after_trading_end | | 融资融券 | ❌ | 19个API未实现 | | 期货/期权 | ❌ | 22个API未实现 | @@ -211,7 +218,7 @@ stocks = api.get_index_stocks('000300.SS', date='2024-01-01') ``` SimTradeLab/ ├── src/simtradelab/ -│ ├── ptrade/ # PTrade API模拟层(52个核心API) +│ ├── ptrade/ # PTrade API模拟层(日/分钟线100%覆盖回测场景API) │ ├── backtest/ # 回测引擎(统计、优化、配置) │ ├── research/ # Research模式(无生命周期限制) │ ├── service/ # 核心服务(数据常驻) @@ -249,10 +256,9 @@ SimTradeLab/ ## 🚧 待改进与已知问题 ### 主要限制 -- ❌ 不支持分钟线数据(仅日线) - ❌ 不支持实盘交易(仅回测) - ⚠️ 测试覆盖不全面(策略驱动测试中) -- ⏳ 99个PTrade API未实现(融资融券、期货、期权等) +- ⏳ 实盘PTrade API未实现(融资融券、期货、期权等) ### 计划改进 - 🔧 命令行工具(目前需要修改Python文件) @@ -267,8 +273,24 @@ SimTradeLab/ **Q: 如何修改初始资金?** ```python -runner.run(initial_capital=2000000.0) # 修改这里 +config = BacktestConfig( + strategy_name='my_strategy', + start_date='2024-01-01', + end_date='2024-12-31', + initial_capital=2000000.0 # 修改这里 +) +``` + +**Q: 如何使用分钟回测?** +```python +config = BacktestConfig( + strategy_name='my_strategy', + start_date='2024-01-01', + end_date='2024-12-31', + frequency='1m' # 设置为分钟回测 +) ``` +注意:分钟回测需要在 `data/stocks_1m/` 目录下准备分钟数据。 **Q: 回测太慢怎么办?** - 减少股票数量或缩短回测时间 @@ -284,7 +306,7 @@ runner.run(initial_capital=2000000.0) # 修改这里 可能是缓存问题,尝试清理并重建: ```bash cd data -rm -rf .keys_cache/ ptrade_adj_pre.h5 ptrade_dividend_cache.h5 +rm -rf .keys_cache/ ``` 详见 [INSTALLATION.md - Q7](docs/INSTALLATION.md#q7-数据加载异常或缓存问题) diff --git a/scripts/setup_typeshed.sh b/scripts/setup_typeshed.sh new file mode 100755 index 0000000..cea529b --- /dev/null +++ b/scripts/setup_typeshed.sh @@ -0,0 +1,116 @@ +#!/bin/bash +# 在 pyright 自带 typeshed 的 builtins.pyi 末尾追加 Ptrade API 类型声明 +# 用途:让 VS Code Pylance 识别策略代码中的 Ptrade API 全局函数 +# pyright 更新后需重跑 +set -e + +PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)" +BUILTINS="$PROJECT_ROOT/.venv/lib/python3.12/site-packages/pyright/dist/dist/typeshed-fallback/stdlib/builtins.pyi" + +if [ ! -f "$BUILTINS" ]; then + echo "ERROR: pyright builtins.pyi not found at $BUILTINS" + echo "Run: poetry install" + exit 1 +fi + +# 幂等:已有标记则跳过 +if grep -q '# PTrade API' "$BUILTINS"; then + echo "Ptrade API stubs already present, skipping." + exit 0 +fi + +cat >> "$BUILTINS" << 'PTRADE_EOF' + + +# ============================================================ +# PTrade API - 注入到策略全局命名空间的 API 函数和对象 +# ============================================================ + +import logging as _logging +import numpy as _np +import pandas as _pd + +class _Position: + security: str + amount: int + cost_basis: float + avg_cost: float + def __getattr__(self, name: str) -> Any: ... + +class _Portfolio: + cash: float + portfolio_value: float + positions: dict[str, _Position] + starting_cash: float + positions_value: float + returns: float + def __getattr__(self, name: str) -> Any: ... + +class _Context: + current_dt: _pd.Timestamp + portfolio: _Portfolio + benchmark: str + def __setattr__(self, name: str, value: Any) -> None: ... + def __getattr__(self, name: str) -> Any: ... + +class _PanelLike: + def __getitem__(self, key: str) -> Any: ... + def __getattr__(self, name: str) -> Any: ... + +context: _Context +g: types.SimpleNamespace +log: _logging.Logger + +FUNDAMENTAL_TABLES: dict[str, list[str]] + +def get_research_path() -> str: ... +def get_Ashares(date: str = ...) -> list[str]: ... +def get_trade_days(start_date: str = ..., end_date: str = ..., count: int = ...) -> list[str]: ... +def get_all_trades_days(date: str = ...) -> list[str]: ... +def get_trading_day(day: int = 0) -> str | None: ... +def is_trade() -> bool: ... +def get_fundamentals(security: str | list[str], table: str, fields: list[str], date: str = ...) -> _pd.DataFrame: ... +def get_price(security: str | list[str], start_date: str = ..., end_date: str = ..., frequency: str = "1d", fields: str | list[str] = ..., fq: str = ..., count: int = ...) -> _pd.DataFrame | _PanelLike: ... +def get_history(count: int, frequency: str = "1d", field: str | list[str] = "close", security_list: str | list[str] = ..., fq: str = ..., include: bool = False, fill: str = "nan", is_dict: bool = False) -> _pd.DataFrame | dict | _PanelLike: ... +def get_snapshot(security: str | list[str]) -> dict | _pd.DataFrame: ... +def get_stock_blocks(stock: str) -> dict: ... +def get_stock_info(stocks: str | list[str], field: str | list[str] = ...) -> dict[str, dict]: ... +def get_stock_name(stocks: str | list[str]) -> str | dict[str, str]: ... +def get_stock_status(stocks: str | list[str], query_type: str = "ST", query_date: str = ...) -> dict[str, bool]: ... +def get_stock_exrights(stock_code: str, date: str = ...) -> _pd.DataFrame | None: ... +def get_index_stocks(index_code: str, date: str = ...) -> list[str]: ... +def get_industry_stocks(industry_code: str = ...) -> dict | list[str]: ... +def check_limit(security: str | list[str], query_date: str = ...) -> dict[str, int]: ... +def order(security: str, amount: int, limit_price: float = ...) -> str | None: ... +def order_target(security: str, amount: int, limit_price: float = ...) -> str | None: ... +def order_value(security: str, value: float, limit_price: float = ...) -> str | None: ... +def order_target_value(security: str, value: float, limit_price: float = ...) -> str | None: ... +def cancel_order(order: object) -> bool: ... +def get_open_orders() -> list: ... +def get_orders(security: str = ...) -> list: ... +def get_order(order_id: str) -> object | None: ... +def get_trades(security: str = ...) -> list: ... +def get_position(security: str) -> _Position | None: ... +def get_positions(security_list: list[str] = ...) -> dict[str, _Position]: ... +def set_benchmark(benchmark: str) -> None: ... +def set_universe(stocks: str | list[str]) -> None: ... +def set_commission(commission_ratio: float = 0.0003, min_commission: float = 5.0, type: str = "STOCK") -> None: ... +def set_slippage(slippage: float = 0.0) -> None: ... +def set_fixed_slippage(fixedslippage: float = 0.001) -> None: ... +def set_limit_mode(limit_mode: str = "LIMIT") -> None: ... +def set_volume_ratio(volume_ratio: float = 0.25) -> None: ... +def set_yesterday_position(poslist: list[dict]) -> None: ... +def set_parameters(params: dict) -> None: ... +def run_interval(context: object, func: Callable[..., object], seconds: int = 10) -> None: ... +def run_daily(context: object, func: Callable[..., object], time: str = "9:31") -> None: ... +def get_user_name() -> str: ... +def convert_position_from_csv(file_path: str) -> list[dict]: ... +def get_MACD(close: _np.ndarray, short: int = 12, long: int = 26, m: int = 9) -> tuple[_np.ndarray, _np.ndarray, _np.ndarray]: ... +def get_KDJ(high: _np.ndarray, low: _np.ndarray, close: _np.ndarray, n: int = 9, m1: int = 3, m2: int = 3) -> tuple[_np.ndarray, _np.ndarray, _np.ndarray]: ... +def get_RSI(close: _np.ndarray, n: int = 6) -> _np.ndarray: ... +def get_CCI(high: _np.ndarray, low: _np.ndarray, close: _np.ndarray, n: int = 14) -> _np.ndarray: ... +def prebuild_date_index(stocks: list[str] | None = ...) -> None: ... +def get_stock_date_index(stock: str) -> tuple[dict, list]: ... +PTRADE_EOF + +echo "Ptrade API stubs appended to builtins.pyi" diff --git a/src/simtradelab/backtest/config.py b/src/simtradelab/backtest/config.py index a3eb98b..b5b7a7f 100644 --- a/src/simtradelab/backtest/config.py +++ b/src/simtradelab/backtest/config.py @@ -41,6 +41,9 @@ class BacktestConfig(BaseModel): initial_capital: float = Field(default=100000.0, gt=0, description="初始资金必须大于0") use_data_server: bool = True + # 回测频率配置 + frequency: str = Field(default='1d', description="回测频率: '1d'日线, '1m'分钟线") + # 性能优化配置 enable_multiprocessing: bool = True num_workers: Optional[int] = Field(default=None, ge=1, description="多进程worker数量") diff --git a/src/simtradelab/backtest/optimizer_framework.py b/src/simtradelab/backtest/optimizer_framework.py index f960f01..19b4724 100644 --- a/src/simtradelab/backtest/optimizer_framework.py +++ b/src/simtradelab/backtest/optimizer_framework.py @@ -111,7 +111,7 @@ def suggest_parameters(cls, trial: optuna.Trial) -> dict[str, Any]: trial: Optuna trial对象 Returns: - Dict[str, Any]: 参数字典 + dict[str, Any]: 参数字典 """ choices = cls.get_parameter_choices() return { @@ -120,7 +120,7 @@ def suggest_parameters(cls, trial: optuna.Trial) -> dict[str, Any]: } @classmethod - def get_extreme_params(cls) -> Dict[str, Tuple[Any, Any]]: + def get_extreme_params(cls) -> dict[str, tuple[Any, Any]]: """自动推导极端参数范围(框架实现,子类无需覆盖) 设计理念: @@ -129,7 +129,7 @@ def get_extreme_params(cls) -> Dict[str, Tuple[Any, Any]]: - 避免人为限制参数搜索空间 Returns: - Dict[str, Tuple]: {param_name: (min_value, max_value)} + dict[str, tuple]: {param_name: (min_value, max_value)} """ choices = cls.get_parameter_choices() extreme_params = {} @@ -140,7 +140,7 @@ def get_extreme_params(cls) -> Dict[str, Tuple[Any, Any]]: return extreme_params @staticmethod - def validate(params: Dict[str, Any]) -> Dict[str, Any]: + def validate(params: dict[str, Any]) -> dict[str, Any]: """验证参数(可选,子类可覆盖) 默认实现:不做任何验证,直接返回原参数 @@ -153,7 +153,7 @@ def validate(params: Dict[str, Any]) -> Dict[str, Any]: params: 参数字典 Returns: - Dict[str, Any]: 验证后的参数字典 + dict[str, Any]: 验证后的参数字典 Raises: ValueError: 参数不合法时抛出 @@ -170,7 +170,7 @@ def validate(params): # ==================== 参数映射辅助函数 ==================== -def resolve_variable_name(param_name: str, custom_mapping: Optional[Dict[str, str]] = None) -> str: +def resolve_variable_name(param_name: str, custom_mapping: Optional[dict[str, str]] = None) -> str: """解析参数对应的策略变量名 Args: @@ -195,8 +195,8 @@ def resolve_variable_name(param_name: str, custom_mapping: Optional[Dict[str, st def apply_parameter_replacement( original_code: str, - params: Dict[str, Any], - custom_mapping: Optional[Dict[str, str]] = None + params: dict[str, Any], + custom_mapping: Optional[dict[str, str]] = None ) -> str: """统一的参数替换逻辑(消除代码重复) @@ -235,7 +235,7 @@ class ScoringStrategy: """评分策略基类""" @staticmethod - def calculate_score(metrics: Dict[str, float]) -> float: + def calculate_score(metrics: dict[str, float]) -> float: """计算综合得分(提供默认实现,子类可选覆盖) 默认策略(改进版,避免指标冗余): @@ -264,7 +264,7 @@ class MyStrategy(ScoringStrategy): # 自定义评分 class MyStrategy(ScoringStrategy): @staticmethod - def calculate_score(metrics: Dict[str, float]) -> float: + def calculate_score(metrics: dict[str, float]) -> float: return metrics['annual_return'] * 0.5 + metrics['sharpe_ratio'] * 0.5 """ score = ( @@ -276,11 +276,11 @@ def calculate_score(metrics: Dict[str, float]) -> float: return score @staticmethod - def get_tracked_metrics() -> List[str]: + def get_tracked_metrics() -> list[str]: """获取需要跟踪的指标列表(可选) Returns: - List[str]: 指标名称列表 + list[str]: 指标名称列表 """ return [ 'total_return', 'annual_return', 'sharpe_ratio', @@ -289,7 +289,7 @@ def get_tracked_metrics() -> List[str]: ] @staticmethod - def calculate_regularization_penalty(params: Dict[str, Any], extreme_params: Optional[Dict[str, Tuple[float, float]]] = None) -> float: + def calculate_regularization_penalty(params: dict[str, Any], extreme_params: Optional[dict[str, tuple[float, float]]] = None) -> float: """计算正则化惩罚(防止参数极值) Args: @@ -337,7 +337,7 @@ def __init__( start_date: str = DEFAULT_START_DATE, end_date: str = DEFAULT_END_DATE, initial_capital: float = DEFAULT_INITIAL_CAPITAL, - custom_mapping: Optional[Dict[str, str]] = None, + custom_mapping: Optional[dict[str, str]] = None, use_walk_forward: bool = True, train_months: int = DEFAULT_TRAIN_MONTHS, test_months: int = DEFAULT_TEST_MONTHS, @@ -409,7 +409,7 @@ def __init__( self._no_improvement_count = 0 # 缓存Walk-Forward时间窗口(避免每个trial重复计算) - self._cached_time_windows: Optional[List[Tuple[str, str, str, str]]] = None + self._cached_time_windows: Optional[list[tuple[str, str, str, str]]] = None if self.use_walk_forward: self._cached_time_windows = self._generate_time_windows() @@ -424,12 +424,12 @@ def original_strategy_code(self) -> str: self._cached_strategy_code = f.read() return self._cached_strategy_code - def create_strategy_code(self, params: Dict[str, Any]) -> str: + def create_strategy_code(self, params: dict[str, Any]) -> str: """基于参数创建策略代码""" # 使用统一的参数替换函数(使用缓存的策略代码) return apply_parameter_replacement(self.original_strategy_code, params, self.custom_mapping) - def run_backtest_with_params(self, params: Dict[str, Any], start_date: Optional[str] = None, end_date: Optional[str] = None) -> Tuple[float, Dict[str, Any]]: + def run_backtest_with_params(self, params: dict[str, Any], start_date: Optional[str] = None, end_date: Optional[str] = None) -> tuple[float, dict[str, Any]]: """使用给定参数运行回测(支持缓存)""" import hashlib @@ -459,7 +459,7 @@ def run_backtest_with_params(self, params: Dict[str, Any], start_date: Optional[ return result - def _run_backtest_impl(self, params: Dict[str, Any], start_date: Optional[str] = None, end_date: Optional[str] = None) -> Tuple[float, Dict[str, Any]]: + def _run_backtest_impl(self, params: dict[str, Any], start_date: Optional[str] = None, end_date: Optional[str] = None) -> tuple[float, dict[str, Any]]: """实际执行回测的内部方法""" temp_strategy_dir = None try: @@ -534,7 +534,7 @@ def _generate_time_windows(self) -> list[tuple[str, str, str, str]]: """生成Walk-Forward时间窗口 Returns: - List[Tuple]: [(train_start, train_end, test_start, test_end), ...] + list[tuple]: [(train_start, train_end, test_start, test_end), ...] """ from dateutil.relativedelta import relativedelta @@ -806,7 +806,7 @@ def update_progress(study, trial): return study - def validate_on_holdout(self, best_params: Dict[str, Any], holdout_start: str, holdout_end: str) -> Dict[str, float]: + def validate_on_holdout(self, best_params: dict[str, Any], holdout_start: str, holdout_end: str) -> dict[str, float]: """在留存集上验证最佳参数 Args: @@ -815,7 +815,7 @@ def validate_on_holdout(self, best_params: Dict[str, Any], holdout_start: str, h holdout_end: 留存集结束日期 Returns: - Dict[str, float]: 留存集指标 + dict[str, float]: 留存集指标 """ print(f"\n样本外验证: {holdout_start} 至 {holdout_end}") score, metrics = self.run_backtest_with_params(best_params, holdout_start, holdout_end) @@ -973,7 +973,7 @@ def create_optimized_strategy( best_params_file: str, original_strategy_path: str, output_path: str, - custom_mapping: Optional[Dict[str, str]] = None + custom_mapping: Optional[dict[str, str]] = None ): """基于最佳参数创建优化后的策略文件""" # 读取最佳参数 @@ -997,16 +997,16 @@ def create_optimized_strategy( # ==================== 简化的顶层API ==================== def optimize_strategy( parameter_space: type, - optimization_period: Optional[Tuple[str, str]] = None, - holdout_period: Optional[Tuple[str, str]] = None, + optimization_period: Optional[tuple[str, str]] = None, + holdout_period: Optional[tuple[str, str]] = None, initial_capital: float = DEFAULT_INITIAL_CAPITAL, scoring_strategy: Optional[Type[ScoringStrategy]] = None, - walk_forward_config: Optional[Dict[str, int]] = None, + walk_forward_config: Optional[dict[str, int]] = None, use_optimal_stopping: bool = DEFAULT_USE_OPTIMAL_STOPPING, patience: Optional[int] = None, regularization_weight: float = DEFAULT_REGULARIZATION_WEIGHT, stability_weight: float = DEFAULT_STABILITY_WEIGHT, - custom_mapping: Optional[Dict[str, str]] = None, + custom_mapping: Optional[dict[str, str]] = None, resume: bool = True, verbose: bool = False ): diff --git a/src/simtradelab/backtest/runner.py b/src/simtradelab/backtest/runner.py index 6844f10..2e39eb3 100644 --- a/src/simtradelab/backtest/runner.py +++ b/src/simtradelab/backtest/runner.py @@ -19,7 +19,8 @@ import os from simtradelab.ptrade.context import Context -from simtradelab.ptrade.object import Global, Portfolio, BacktestContext +from simtradelab.ptrade.object import Portfolio, BacktestContext +from simtradelab.ptrade.config_manager import config as ptrade_config from simtradelab.backtest.stats import generate_backtest_report, generate_backtest_charts, print_backtest_report from simtradelab.ptrade.api import PtradeAPI from simtradelab.service.data_server import DataServer @@ -40,6 +41,7 @@ def __init__(self): # 数据容器(延迟加载) self._data_loaded = False self.stock_data_dict = None + self.stock_data_dict_1m = None self.valuation_dict = None self.fundamentals_dict = None self.exrights_dict = None @@ -76,7 +78,7 @@ def run(self, config: BacktestConfig) -> dict: if not is_valid: print("\n策略验证失败:") for error in errors: - print(" - {}".format(error)) + print(f" - {error}") return {} if fixed_code: print("已自动修复Python 3.5兼容性问题") @@ -105,10 +107,7 @@ def run(self, config: BacktestConfig) -> dict: try: # 加载数据 - benchmark_df = self._load_data(required_data) - - # 创建全局对象(每次run都新建,避免状态污染) - g = Global() + benchmark_df = self._load_data(required_data, config.frequency) # 初始化日志 self._setup_logging(config) @@ -136,8 +135,8 @@ def run(self, config: BacktestConfig) -> dict: context=context, api=api, stats_collector=stats_collector, - g=g, - log=log + log=log, + frequency=config.frequency ) # 加载策略 @@ -166,11 +165,12 @@ def run(self, config: BacktestConfig) -> dict: self._cleanup() @timer(name="数据加载") - def _load_data(self, required_data=None) -> pd.DataFrame: + def _load_data(self, required_data=None, frequency='1d') -> pd.DataFrame: """加载数据 Args: required_data: 需要加载的数据集合 + frequency: 回测频率 '1d'日线 '1m'分钟线 Returns: 基准数据DataFrame @@ -182,10 +182,11 @@ def _load_data(self, required_data=None) -> pd.DataFrame: return next(iter(self.benchmark_data.values())) # type: ignore # 使用多进程安全的DataServer - data_server = DataServer(required_data) + data_server = DataServer(required_data, frequency) # 绑定到runner实例 self.stock_data_dict = data_server.stock_data_dict + self.stock_data_dict_1m = data_server.stock_data_dict_1m self.valuation_dict = data_server.valuation_dict self.fundamentals_dict = data_server.fundamentals_dict self.exrights_dict = data_server.exrights_dict @@ -211,11 +212,13 @@ def _setup_logging(self, config: BacktestConfig): handlers = [logging.StreamHandler(sys.stdout)] # 仅在启用日志时创建文件handler + os.makedirs(config.log_dir, exist_ok=True) if config.enable_logging: - log_filename = config.get_log_filename() - os.makedirs(config.log_dir, exist_ok=True) - print(f"日志文件: {log_filename}") - handlers.append(logging.FileHandler(log_filename, mode='w', encoding='utf-8')) + self._log_filename = config.get_log_filename() + print(f"日志文件: {self._log_filename}") + handlers.append(logging.FileHandler(self._log_filename, mode='w', encoding='utf-8')) + if config.enable_charts: + self._chart_filename = config.get_chart_filename() logging.basicConfig( level=logging.INFO, @@ -253,9 +256,13 @@ def _initialize_context(self, config: BacktestConfig, start_date, log) -> tuple: """ from simtradelab.ptrade.data_context import DataContext + # 重置全局交易配置,避免前次回测的残留设置污染本次回测 + ptrade_config.reset_to_defaults() + # 创建组合和上下文 portfolio = Portfolio(config.initial_capital) - context = Context(portfolio=portfolio, current_dt=start_date) + context = Context(portfolio=portfolio, current_dt=start_date, + frequency=config.frequency) # 设置portfolio的context引用 portfolio._context = context @@ -273,7 +280,8 @@ def _initialize_context(self, config: BacktestConfig, start_date, log) -> tuple: adj_pre_cache=self.adj_pre_cache, adj_post_cache=self.adj_post_cache, dividend_cache=self.dividend_cache, - trade_days=self.trade_days + trade_days=self.trade_days, + stock_data_dict_1m=self.stock_data_dict_1m ) # 创建API @@ -379,18 +387,17 @@ def _generate_reports( # 生成图表 if config.enable_charts: chart_benchmark_data = {benchmark_code: actual_benchmark_df} - chart_filename = generate_backtest_charts( + generate_backtest_charts( stats, config.start_date, config.end_date, chart_benchmark_data, - config.get_chart_filename(), + self._chart_filename, benchmark_code=benchmark_code ) - print(f"图表已保存至: {chart_filename}") + print(f"图表已保存至: {self._chart_filename}") if config.enable_logging: - log_filename = config.get_log_filename() - print(f"\n日志已保存至: {log_filename}") + print(f"\n日志已保存至: {self._log_filename}") return report diff --git a/src/simtradelab/backtest/stats.py b/src/simtradelab/backtest/stats.py index 62e6f49..4e75952 100644 --- a/src/simtradelab/backtest/stats.py +++ b/src/simtradelab/backtest/stats.py @@ -15,8 +15,6 @@ import os import json import numpy as np -import matplotlib.pyplot as plt -import matplotlib.dates as mdates def _load_index_names(): @@ -33,28 +31,15 @@ def _load_index_names(): return {} -def _get_benchmark_name(benchmark_code, use_english=False): +def _get_benchmark_name(benchmark_code): """获取基准名称 Args: benchmark_code: 基准代码 - use_english: 是否使用英文名称(用于图表显示) Returns: str: 基准名称,如果找不到则返回代码本身 """ - if use_english: - english_names = { - '000300.SS': 'CSI 300', - '000905.SZ': 'CSI 500', - '000001.SZ': 'SZE Component', - '399001.SZ': 'SZE Component', - '399006.SZ': 'ChiNext', - '399101.SZ': 'SME Board', - '000001.SS': 'SSE Composite' - } - return english_names.get(benchmark_code, benchmark_code) - index_names = _load_index_names() return index_names.get(benchmark_code, benchmark_code) @@ -349,10 +334,10 @@ def _plot_nav_curve(ax, dates, portfolio_values, daily_buy, daily_sell, benchmar """ # 策略净值曲线 strategy_nav = portfolio_values / portfolio_values[0] - ax.plot(dates, strategy_nav, linewidth=2, label='Strategy NAV', color='#1f77b4') + ax.plot(dates, strategy_nav, linewidth=2, label='策略净值', color='#1f77b4') # 基准净值曲线 - benchmark_name = _get_benchmark_name(benchmark_code, use_english=True) + benchmark_name = _get_benchmark_name(benchmark_code) if benchmark_code in benchmark_data and not benchmark_data[benchmark_code].empty: benchmark_df_data = benchmark_data[benchmark_code] benchmark_slice = benchmark_df_data.loc[ @@ -367,35 +352,36 @@ def _plot_nav_curve(ax, dates, portfolio_values, daily_buy, daily_sell, benchmar # 标注买卖点 buy_dates = dates[daily_buy > 0] buy_navs = strategy_nav[daily_buy > 0] - ax.scatter(buy_dates, buy_navs, marker='^', color='red', s=50, label='Buy', zorder=5) + ax.scatter(buy_dates, buy_navs, marker='^', color='red', s=50, label='买入', zorder=5) sell_dates = dates[daily_sell > 0] sell_navs = strategy_nav[daily_sell > 0] - ax.scatter(sell_dates, sell_navs, marker='v', color='green', s=50, label='Sell', zorder=5) + ax.scatter(sell_dates, sell_navs, marker='v', color='green', s=50, label='卖出', zorder=5) - ax.set_title('Portfolio Value vs Benchmark', fontsize=14, fontweight='bold') - ax.set_ylabel('Net Asset Value', fontsize=12) + ax.set_title('策略净值 vs 基准', fontsize=14, fontweight='bold') + ax.set_ylabel('净值', fontsize=12) ax.legend(loc='best', fontsize=10) ax.grid(True, alpha=0.3) -def _plot_daily_pnl(ax, dates, daily_pnl): +def _plot_daily_pnl(ax, dates, daily_pnl, bar_width): """绘制每日盈亏子图 Args: ax: matplotlib axes对象 dates: 日期数组 daily_pnl: 每日盈亏数组 + bar_width: 柱宽 """ colors = ['red' if pnl >= 0 else 'green' for pnl in daily_pnl] - ax.bar(dates, daily_pnl, color=colors, alpha=0.7, width=0.8) - ax.axhline(y=0, color='black', linestyle='-', linewidth=1) - ax.set_title('Daily P&L', fontsize=14, fontweight='bold') - ax.set_ylabel('P&L (CNY)', fontsize=12) + ax.bar(dates, daily_pnl, color=colors, alpha=0.7, width=bar_width) + ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5) + ax.set_title('每日盈亏', fontsize=14, fontweight='bold') + ax.set_ylabel('盈亏(元)', fontsize=12) ax.grid(True, alpha=0.3, axis='y') -def _plot_trade_amounts(ax, dates, daily_buy, daily_sell): +def _plot_trade_amounts(ax, dates, daily_buy, daily_sell, bar_width): """绘制交易金额子图 Args: @@ -403,13 +389,13 @@ def _plot_trade_amounts(ax, dates, daily_buy, daily_sell): dates: 日期数组 daily_buy: 每日买入金额 daily_sell: 每日卖出金额 + bar_width: 柱宽 """ - width = 0.4 - ax.bar(dates, daily_buy, color='red', alpha=0.7, width=width, label='Buy Amount') - ax.bar(dates, -daily_sell, color='green', alpha=0.7, width=width, label='Sell Amount') - ax.axhline(y=0, color='black', linestyle='-', linewidth=1) - ax.set_title('Daily Buy/Sell Amount', fontsize=14, fontweight='bold') - ax.set_ylabel('Amount (CNY)', fontsize=12) + ax.bar(dates, daily_buy, color='red', alpha=0.7, width=bar_width, label='买入金额') + ax.bar(dates, -daily_sell, color='green', alpha=0.7, width=bar_width, label='卖出金额') + ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5) + ax.set_title('每日买卖金额', fontsize=14, fontweight='bold') + ax.set_ylabel('金额(元)', fontsize=12) ax.legend(loc='best', fontsize=10) ax.grid(True, alpha=0.3, axis='y') @@ -423,10 +409,10 @@ def _plot_positions_value(ax, dates, daily_positions_val): daily_positions_val: 每日持仓市值数组 """ ax.fill_between(dates, daily_positions_val, alpha=0.3, color='#9467bd') - ax.plot(dates, daily_positions_val, linewidth=2, color='#9467bd', label='Positions Value') - ax.set_title('Daily Positions Value', fontsize=14, fontweight='bold') - ax.set_xlabel('Date', fontsize=12) - ax.set_ylabel('Value (CNY)', fontsize=12) + ax.plot(dates, daily_positions_val, linewidth=2, color='#9467bd', label='持仓市值') + ax.set_title('每日持仓市值', fontsize=14, fontweight='bold') + ax.set_xlabel('日期', fontsize=12) + ax.set_ylabel('市值(元)', fontsize=12) ax.legend(loc='best', fontsize=10) ax.grid(True, alpha=0.3) @@ -445,20 +431,26 @@ def generate_backtest_charts(backtest_stats, start_date, end_date, benchmark_dat Returns: str: 图表文件路径 """ + import matplotlib.pyplot as plt + import matplotlib.dates as mdates + # 设置字体 - 使用系统可用字体 - plt.rcParams['font.sans-serif'] = ['Ubuntu', 'DejaVu Sans'] + plt.rcParams['font.sans-serif'] = ['WenQuanYi Micro Hei', 'Ubuntu', 'DejaVu Sans'] plt.rcParams['axes.unicode_minus'] = False # 验证并提取数据 dates, portfolio_values, daily_pnl, daily_buy, daily_sell, daily_positions_val = _validate_chart_data(backtest_stats) + # 统一柱宽 + bar_width = 4 + # 创建图表 - 4行1列布局 - fig, axes = plt.subplots(4, 1, figsize=(16, 20), sharex=True) + _, axes = plt.subplots(4, 1, figsize=(16, 20), sharex=True) # 绘制4个子图 _plot_nav_curve(axes[0], dates, portfolio_values, daily_buy, daily_sell, benchmark_data, start_date, end_date, benchmark_code) - _plot_daily_pnl(axes[1], dates, daily_pnl) - _plot_trade_amounts(axes[2], dates, daily_buy, daily_sell) + _plot_daily_pnl(axes[1], dates, daily_pnl, bar_width) + _plot_trade_amounts(axes[2], dates, daily_buy, daily_sell, bar_width) _plot_positions_value(axes[3], dates, daily_positions_val) # 设置x轴日期格式 @@ -467,15 +459,12 @@ def generate_backtest_charts(backtest_stats, start_date, end_date, benchmark_dat ax.xaxis.set_major_locator(mdates.MonthLocator()) plt.setp(ax.xaxis.get_majorticklabels(), rotation=45) - # 调整布局 - plt.tight_layout() - # 自动创建目录 chart_dir = os.path.dirname(chart_filename) os.makedirs(chart_dir, exist_ok=True) - # 保存图表 - plt.savefig(chart_filename, dpi=150, bbox_inches='tight') + # 保存图表(tight_layout 由 bbox_inches='tight' 隐式完成) + plt.savefig(chart_filename, dpi=100, bbox_inches='tight') plt.close() return chart_filename diff --git a/src/simtradelab/backtest/stats_collector.py b/src/simtradelab/backtest/stats_collector.py index f20411b..de303dd 100644 --- a/src/simtradelab/backtest/stats_collector.py +++ b/src/simtradelab/backtest/stats_collector.py @@ -42,11 +42,12 @@ def collect_pre_trading(self, context: Context, current_date): ) self._stats['trade_dates'].append(current_date) - def collect_trading_amounts(self, prev_cash: float, current_cash: float): - """收集交易金额""" - cash_change = current_cash - prev_cash - self._stats['daily_buy_amount'].append(max(0, -cash_change)) - self._stats['daily_sell_amount'].append(max(0, cash_change)) + def collect_trading_amounts(self, context: Context): + """收集交易金额(从OrderProcessor累计的gross金额)""" + self._stats['daily_buy_amount'].append(context._daily_buy_total) + self._stats['daily_sell_amount'].append(context._daily_sell_total) + context._daily_buy_total = 0.0 + context._daily_sell_total = 0.0 def collect_post_trading(self, context: Context, prev_portfolio_value: float): """收集交易后数据""" diff --git a/src/simtradelab/cli/data_tools.py b/src/simtradelab/cli/data_tools.py index 5c7e1a7..bd886a4 100644 --- a/src/simtradelab/cli/data_tools.py +++ b/src/simtradelab/cli/data_tools.py @@ -49,8 +49,8 @@ def unpack_all(self, download_dir, verify=True): with open(manifest_file, 'r') as f: manifest = json.load(f) - print("数据版本: {}".format(manifest['version'])) - print("导出日期: {}".format(manifest['export_date'])) + print(f"数据版本: {manifest['version']}") + print(f"导出日期: {manifest['export_date']}") print("=" * 70) # 确保数据目录存在 @@ -61,7 +61,7 @@ def unpack_all(self, download_dir, verify=True): pkg_path = download_dir / pkg_info['name'] if not pkg_path.exists(): - print("\n警告:文件不存在 {}".format(pkg_info['name'])) + print(f"\n警告:文件不存在 {pkg_info['name']}") continue # 解压 @@ -81,7 +81,7 @@ def unpack_all(self, download_dir, verify=True): json.dump(version_info, f, ensure_ascii=False, indent=2) print("\n" + "=" * 70) - print("解包完成!数据目录: {}".format(self.data_dir)) + print(f"解包完成!数据目录: {self.data_dir}") print("=" * 70) @@ -116,5 +116,5 @@ def unpack_command(download_dir, data_dir=None): sys.exit(1) unpack_command(sys.argv[2]) else: - print("未知命令: {}".format(command)) + print(f"未知命令: {command}") sys.exit(1) diff --git a/src/simtradelab/ptrade/__init__.py b/src/simtradelab/ptrade/__init__.py index 5206ac1..9d01717 100644 --- a/src/simtradelab/ptrade/__init__.py +++ b/src/simtradelab/ptrade/__init__.py @@ -16,7 +16,6 @@ from .context import ( Context, PTradeMode, - StrategyLifecycleManager, create_backtest_context, create_research_context, create_trading_context, @@ -51,7 +50,6 @@ BacktestContext, Blotter, Data, - Global, LazyDataDict, Order, Portfolio, @@ -76,7 +74,6 @@ # Context related "Context", "PTradeMode", - "StrategyLifecycleManager", "create_backtest_context", "create_research_context", "create_trading_context", @@ -102,7 +99,6 @@ "BacktestContext", "Blotter", "Data", - "Global", "LazyDataDict", "Order", "Portfolio", diff --git a/src/simtradelab/ptrade/adj_cache.py b/src/simtradelab/ptrade/adj_cache.py index 7c8764c..62e8f50 100644 --- a/src/simtradelab/ptrade/adj_cache.py +++ b/src/simtradelab/ptrade/adj_cache.py @@ -29,10 +29,9 @@ def _calculate_adj_factors_from_events(stock, stock_df, exrights_events): - 最新的价格不变 (adj_a=1, adj_b=0) - 越往过去,调整幅度越大 - 除权除息调整规则(从最新往历史回推): - 每遇到一个除权日,该日期之前的价格需要调整: - - adj_a[旧] = adj_a[新] × (1 + allotted_ps + rationed_ps) - - adj_b[旧] = adj_b[新] × (1 + allotted_ps + rationed_ps) + bonus_ps - rationed_ps × rationed_px + 如果除权数据包含平台预计算的 exer_forward_a/b,直接使用(精度更高)。 + 平台公式: P_adj = exer_forward_a * P + exer_forward_b + 转换关系: adj_a = 1/exer_forward_a, adj_b = -exer_forward_b/exer_forward_a """ if stock_df is None or stock_df.empty: return None @@ -48,33 +47,48 @@ def _calculate_adj_factors_from_events(stock, stock_df, exrights_events): try: ex_dates_int = exrights_events.index.tolist() ex_dates_dt = pd.to_datetime(ex_dates_int, format="%Y%m%d") + n_events = len(ex_dates_int) - # 获取除权除息数据 - allotted_ps = exrights_events["allotted_ps"].values # 送转股比例 - bonus_ps = exrights_events["bonus_ps"].values # 现金分红(元/股) - rationed_ps = exrights_events["rationed_ps"].values # 配股比例 - rationed_px = exrights_events["rationed_px"].values # 配股价格 - - # 计算每个除权日之后的前复权因子 - n_events = len(allotted_ps) - forward_a_array = np.ones(n_events + 1, dtype="float64") - forward_b_array = np.zeros(n_events + 1, dtype="float64") - - # 最新时刻(index=n_events): 不调整 - forward_a_array[n_events] = 1.0 - forward_b_array[n_events] = 0.0 - - # 从最新往历史回推计算前复权因子 - for i in range(n_events - 1, -1, -1): - total_ratio = allotted_ps[i] + rationed_ps[i] - multiplier = 1.0 + total_ratio + has_platform_factors = ( + 'exer_forward_a' in exrights_events.columns + and exrights_events['exer_forward_a'].notna().all() + ) - forward_a_array[i] = forward_a_array[i + 1] * multiplier - forward_b_array[i] = ( - forward_b_array[i + 1] * multiplier - + bonus_ps[i] - - rationed_ps[i] * rationed_px[i] - ) + if has_platform_factors: + # 直接使用平台预计算的前复权因子 + ef_a = exrights_events['exer_forward_a'].values + ef_b = exrights_events['exer_forward_b'].values + + forward_a_array = np.empty(n_events + 1, dtype="float64") + forward_b_array = np.empty(n_events + 1, dtype="float64") + + for i in range(n_events): + forward_a_array[i] = 1.0 / ef_a[i] + forward_b_array[i] = -ef_b[i] / ef_a[i] + + # 最新时刻不调整 + forward_a_array[n_events] = 1.0 + forward_b_array[n_events] = 0.0 + else: + # 从原始事件计算 + allotted_ps = exrights_events["allotted_ps"].values + bonus_ps = exrights_events["bonus_ps"].values + rationed_ps = exrights_events["rationed_ps"].values + rationed_px = exrights_events["rationed_px"].values + + forward_a_array = np.ones(n_events + 1, dtype="float64") + forward_b_array = np.zeros(n_events + 1, dtype="float64") + + for i in range(n_events - 1, -1, -1): + total_ratio = allotted_ps[i] + rationed_ps[i] + multiplier = 1.0 + total_ratio + + forward_a_array[i] = forward_a_array[i + 1] * multiplier + forward_b_array[i] = ( + forward_b_array[i + 1] * multiplier + + bonus_ps[i] + - rationed_ps[i] * rationed_px[i] + ) # 向量化操作 trade_dates_np = stock_df.index.values @@ -97,13 +111,13 @@ def _calculate_adj_factors_from_events(stock, stock_df, exrights_events): except (ValueError, KeyError, IndexError, pd.errors.EmptyDataError) as e: import logging logger = logging.getLogger(__name__) - logger.error("计算 {} 前复权因子失败: {}".format(stock, e)) + logger.error(f"计算 {stock} 前复权因子失败: {e}") return None except Exception as e: import logging import traceback logger = logging.getLogger(__name__) - logger.error("计算 {} 前复权因子失败: {}".format(stock, e)) + logger.error(f"计算 {stock} 前复权因子失败: {e}") logger.debug(traceback.format_exc()) return None @@ -181,11 +195,9 @@ def create_adj_pre_cache(data_context): if not ex_df.empty: exrights_cache[stock] = ex_df - logger.info(" 已加载 {} 只股票的除权数据".format(len(exrights_cache))) + logger.info(f" 已加载 {len(exrights_cache)} 只股票的除权数据") - logger.info(" 并行计算前复权因子({} 进程)...".format( - num_workers if num_workers > 0 else "auto" - )) + logger.info(f" 并行计算前复权因子({num_workers if num_workers > 0 else 'auto'} 进程)...") results = Parallel(n_jobs=num_workers, backend="loky", verbose=0)( delayed(_calculate_adj_factors_from_events)( @@ -212,17 +224,17 @@ def create_adj_pre_cache(data_context): file_size = os.path.getsize(ADJ_PRE_CACHE_PATH) / 1024 / 1024 logger.info("✓ 前复权因子缓存创建完成!") - logger.info(" 处理: {} 只股票".format(total_stocks)) - logger.info(" 保存: {} 只(有除权数据或价格数据)".format(saved_count)) + logger.info(f" 处理: {total_stocks} 只股票") + logger.info(f" 保存: {saved_count} 只(有除权数据或价格数据)") if failed_stocks: - logger.warning(" 失败股票: {} 只".format(len(failed_stocks))) - logger.info(" 文件: {} ({:.1f}MB)".format(ADJ_PRE_CACHE_PATH, file_size)) + logger.warning(f" 失败股票: {len(failed_stocks)} 只") + logger.info(f" 文件: {ADJ_PRE_CACHE_PATH} ({file_size:.1f}MB)") except OSError as e: - logger.error("创建前复权因子缓存失败: {}".format(e)) + logger.error(f"创建前复权因子缓存失败: {e}") raise except Exception as e: - logger.error("创建前复权因子缓存时发生未预期错误: {}".format(e)) + logger.error(f"创建前复权因子缓存时发生未预期错误: {e}") import traceback logger.debug(traceback.format_exc()) raise @@ -241,7 +253,7 @@ def load_adj_pre_cache(data_context): try: create_adj_pre_cache(data_context) except Exception as e: - logger.error("创建前复权因子缓存失败: {}".format(e)) + logger.error(f"创建前复权因子缓存失败: {e}") raise logger.info("正在加载前复权因子缓存...") @@ -252,22 +264,22 @@ def load_adj_pre_cache(data_context): if adj_factors_cache is None: raise FileNotFoundError("缓存文件为空") - logger.info("✓ 前复权因子缓存加载完成!共 {} 只股票".format(len(adj_factors_cache))) + logger.info(f"✓ 前复权因子缓存加载完成!共 {len(adj_factors_cache)} 只股票") return adj_factors_cache except FileNotFoundError: - logger.error("缓存文件不存在: {}".format(ADJ_PRE_CACHE_PATH)) + logger.error(f"缓存文件不存在: {ADJ_PRE_CACHE_PATH}") create_adj_pre_cache(data_context) return load_adj_pre_cache(data_context) except Exception as e: - logger.error("缓存文件损坏或格式错误: {}".format(e)) + logger.error(f"缓存文件损坏或格式错误: {e}") try: os.remove(ADJ_PRE_CACHE_PATH) logger.info("已删除损坏的缓存文件,重新创建...") create_adj_pre_cache(data_context) return load_adj_pre_cache(data_context) except OSError as remove_error: - logger.error("删除损坏的缓存文件失败: {}".format(remove_error)) + logger.error(f"删除损坏的缓存文件失败: {remove_error}") raise @@ -336,13 +348,13 @@ def _calculate_adj_post_factors_from_events(stock, stock_df, exrights_events): except (ValueError, KeyError, IndexError, pd.errors.EmptyDataError) as e: import logging logger = logging.getLogger(__name__) - logger.error("计算 {} 后复权因子失败: {}".format(stock, e)) + logger.error(f"计算 {stock} 后复权因子失败: {e}") return None except Exception as e: import logging import traceback logger = logging.getLogger(__name__) - logger.error("计算 {} 后复权因子失败: {}".format(stock, e)) + logger.error(f"计算 {stock} 后复权因子失败: {e}") logger.debug(traceback.format_exc()) return None @@ -377,11 +389,9 @@ def create_adj_post_cache(data_context): if not ex_df.empty: exrights_cache[stock] = ex_df - logger.info(" 已加载 {} 只股票的除权数据".format(len(exrights_cache))) + logger.info(f" 已加载 {len(exrights_cache)} 只股票的除权数据") - logger.info(" 并行计算后复权因子({} 进程)...".format( - num_workers if num_workers > 0 else "auto" - )) + logger.info(f" 并行计算后复权因子({num_workers if num_workers > 0 else 'auto'} 进程)...") results = Parallel(n_jobs=num_workers, backend="loky", verbose=0)( delayed(_calculate_adj_post_factors_from_events)( @@ -408,17 +418,17 @@ def create_adj_post_cache(data_context): file_size = os.path.getsize(ADJ_POST_CACHE_PATH) / 1024 / 1024 logger.info("✓ 后复权因子缓存创建完成!") - logger.info(" 处理: {} 只股票".format(total_stocks)) - logger.info(" 保存: {} 只(有除权数据或价格数据)".format(saved_count)) + logger.info(f" 处理: {total_stocks} 只股票") + logger.info(f" 保存: {saved_count} 只(有除权数据或价格数据)") if failed_stocks: - logger.warning(" 失败股票: {} 只".format(len(failed_stocks))) - logger.info(" 文件: {} ({:.1f}MB)".format(ADJ_POST_CACHE_PATH, file_size)) + logger.warning(f" 失败股票: {len(failed_stocks)} 只") + logger.info(f" 文件: {ADJ_POST_CACHE_PATH} ({file_size:.1f}MB)") except OSError as e: - logger.error("创建后复权因子缓存失败: {}".format(e)) + logger.error(f"创建后复权因子缓存失败: {e}") raise except Exception as e: - logger.error("创建后复权因子缓存时发生未预期错误: {}".format(e)) + logger.error(f"创建后复权因子缓存时发生未预期错误: {e}") import traceback logger.debug(traceback.format_exc()) raise @@ -437,7 +447,7 @@ def load_adj_post_cache(data_context): try: create_adj_post_cache(data_context) except Exception as e: - logger.error("创建后复权因子缓存失败: {}".format(e)) + logger.error(f"创建后复权因子缓存失败: {e}") raise logger.info("正在加载后复权因子缓存...") @@ -448,22 +458,22 @@ def load_adj_post_cache(data_context): if adj_factors_cache is None: raise FileNotFoundError("缓存文件为空") - logger.info("✓ 后复权因子缓存加载完成!共 {} 只股票".format(len(adj_factors_cache))) + logger.info(f"✓ 后复权因子缓存加载完成!共 {len(adj_factors_cache)} 只股票") return adj_factors_cache except FileNotFoundError: - logger.error("缓存文件不存在: {}".format(ADJ_POST_CACHE_PATH)) + logger.error(f"缓存文件不存在: {ADJ_POST_CACHE_PATH}") create_adj_post_cache(data_context) return load_adj_post_cache(data_context) except Exception as e: - logger.error("缓存文件损坏或格式错误: {}".format(e)) + logger.error(f"缓存文件损坏或格式错误: {e}") try: os.remove(ADJ_POST_CACHE_PATH) logger.info("已删除损坏的缓存文件,重新创建...") create_adj_post_cache(data_context) return load_adj_post_cache(data_context) except OSError as remove_error: - logger.error("删除损坏的缓存文件失败: {}".format(remove_error)) + logger.error(f"删除损坏的缓存文件失败: {remove_error}") raise diff --git a/src/simtradelab/ptrade/api.py b/src/simtradelab/ptrade/api.py index 8d2cc50..375fc23 100644 --- a/src/simtradelab/ptrade/api.py +++ b/src/simtradelab/ptrade/api.py @@ -38,6 +38,7 @@ def validate_lifecycle(func: Callable) -> Callable: """ @wraps(func) def wrapper(self, *args, **kwargs): + controller = None # 如果context有lifecycle_controller,进行验证 if hasattr(self, 'context') and self.context and hasattr(self.context, '_lifecycle_controller'): controller = self.context._lifecycle_controller @@ -46,11 +47,15 @@ def wrapper(self, *args, **kwargs): validation_result = controller.validate_api_call(api_name) if not validation_result.is_valid: raise PTradeLifecycleError(validation_result.error_message) - # 记录API调用 - controller.record_api_call(api_name, success=True) # 执行原函数 - return func(self, *args, **kwargs) + result = func(self, *args, **kwargs) + + # 执行成功后记录 + if controller: + controller.record_api_call(func.__name__, success=True) + + return result return wrapper @@ -77,11 +82,13 @@ def __init__(self, data_context: Any, context: Any, log: Any) -> None: # 缓存 - 使用统一缓存管理器 self._stock_status_cache: dict[tuple, bool] = {} + self._stock_status_cache_max = 50000 self._stock_date_index: dict[str, tuple[dict, list]] = {} self._prebuilt_index: bool = False self._sorted_status_dates: Optional[list[str]] = None self._history_cache: dict = cache_manager.get_namespace('history')._cache # 使用LRUCache - self._fundamentals_cache: dict = cache_manager.get_namespace('fundamentals')._cache # 使用全局缓存 + self._fundamentals_cache: dict = {} # 独立管理的基本面索引缓存 + self._fundamentals_cache_max = 500 @property def order_processor(self) -> OrderProcessor: @@ -119,8 +126,6 @@ def prebuild_date_index(self, stocks: Optional[list[str]] = None) -> None: def get_stock_date_index(self, stock: str) -> tuple[dict, list]: """获取股票日期索引,返回 (date_dict, sorted_dates) 元组""" if stock not in self._stock_date_index: - # 延迟构建单只股票索引 - # 检查stock_data_dict和benchmark_data stock_df = None if stock in self.data_context.stock_data_dict: stock_df = self.data_context.stock_data_dict[stock] @@ -128,80 +133,62 @@ def get_stock_date_index(self, stock: str) -> tuple[dict, list]: stock_df = self.data_context.benchmark_data[stock] if stock_df is not None and isinstance(stock_df, pd.DataFrame) and isinstance(stock_df.index, pd.DatetimeIndex): - date_dict = {date: idx for idx, date in enumerate(stock_df.index)} + # 用 numpy 数组直接构建 dict,避免 Python 层迭代 DatetimeIndex + idx_array = stock_df.index.values # numpy datetime64 array + date_dict = {pd.Timestamp(idx_array[i]): i for i in range(len(idx_array))} sorted_dates = list(stock_df.index) self._stock_date_index[stock] = (date_dict, sorted_dates) else: self._stock_date_index[stock] = ({}, []) return self._stock_date_index.get(stock, ({}, [])) - def get_adjusted_price(self, stock: str, date: str, price_type: str = 'close', fq: str = None) -> float: - """获取复权后的价格 + def _apply_adj_factors(self, stock_df: pd.DataFrame, stock: str, fq: str) -> pd.DataFrame: + """对DataFrame应用复权因子(向量化) Args: + stock_df: 股票数据DataFrame stock: 股票代码 - date: 日期 - price_type: 价格类型 (close/open/high/low) - fq: 复权类型 (None-不复权, 'pre'-前复权, 'post'-后复权) + fq: 复权类型 ('pre'前复权, 'post'后复权) Returns: - 复权后价格 + 复权后的DataFrame(copy),无复权因子时返回原DataFrame """ - if fq is None or stock not in self.data_context.stock_data_dict: - # 不复权,直接返回原始价格 - try: - stock_df = self.data_context.stock_data_dict[stock] - return stock_df.loc[date, price_type] - except (KeyError, IndexError, AttributeError): - return np.nan - if fq == 'pre': - # 前复权:使用adj_pre_cache - try: - stock_df = self.data_context.stock_data_dict[stock] - original_price = stock_df.loc[date, price_type] - - # 使用预计算的复权因子缓存 - if self.data_context.adj_pre_cache and stock in self.data_context.adj_pre_cache: - adj_factors = self.data_context.adj_pre_cache[stock] - date_ts = pd.Timestamp(date) - if date_ts in adj_factors.index: - adj_a = adj_factors.loc[date_ts, 'adj_a'] - adj_b = adj_factors.loc[date_ts, 'adj_b'] - # 前复权公式: 前复权价 = (未复权价 - adj_b) / adj_a - return (original_price - adj_b) / adj_a - - # 缓存不存在或无对应日期,返回原始价 - return original_price - except (KeyError, IndexError, AttributeError): - return np.nan - - if fq == 'post': - # 后复权:使用adj_post_cache - try: - stock_df = self.data_context.stock_data_dict[stock] - original_price = stock_df.loc[date, price_type] - - if self.data_context.adj_post_cache and stock in self.data_context.adj_post_cache: - adj_factors = self.data_context.adj_post_cache[stock] - date_ts = pd.Timestamp(date) - if date_ts in adj_factors.index: - adj_a = adj_factors.loc[date_ts, 'adj_a'] - adj_b = adj_factors.loc[date_ts, 'adj_b'] - # 后复权公式: 后复权价 = adj_a * 未复权价 + adj_b - return adj_a * original_price + adj_b - - return original_price - except (KeyError, IndexError, AttributeError): - return np.nan - - # 其他情况返回原始价 - try: - stock_df = self.data_context.stock_data_dict[stock] - return stock_df.loc[date, price_type] - except (KeyError, IndexError, AttributeError): - return np.nan - + adj_cache = self.data_context.adj_pre_cache + elif fq == 'post': + adj_cache = self.data_context.adj_post_cache + else: + return stock_df + + if not adj_cache or stock not in adj_cache: + return stock_df + + adj_factors = adj_cache[stock] + common_idx = stock_df.index.intersection(adj_factors.index) + if len(common_idx) == 0: + return stock_df + + adjusted_df = stock_df.copy() + adj_a = adj_factors.loc[common_idx, 'adj_a'] + adj_b = adj_factors.loc[common_idx, 'adj_b'] + price_cols = ['open', 'high', 'low', 'close'] + + for col in price_cols: + if col not in adjusted_df.columns: + continue + if fq == 'pre': + # 前复权: (未复权价 - adj_b) / adj_a + adjusted_df.loc[common_idx, col] = np.round( + (adjusted_df.loc[common_idx, col] - adj_b) / adj_a, 2 + ) + else: + # 后复权: adj_a * 未复权价 + adj_b + adjusted_df.loc[common_idx, col] = ( + adj_a * adjusted_df.loc[common_idx, col] + adj_b + ) + + return adjusted_df + # ==================== 基础API ==================== def get_research_path(self) -> str: @@ -245,8 +232,9 @@ def get_trade_days(self, start_date: str = None, end_date: str = None, count: in end_date: 结束日期(默认当前回测日期) count: 往前count个交易日(与start_date二选一) """ - trade_days_df = self.data_context.stock_data_store['/trade_days'] - all_trade_days = trade_days_df.index + if self.data_context.trade_days is None: + raise RuntimeError("交易日历数据未加载") + all_trade_days = self.data_context.trade_days if end_date is None: end_dt = self.context.current_dt @@ -372,8 +360,7 @@ def get_fundamentals(self, security: str | list[str], table: str, fields: list[s # 获取或创建日期索引缓存(增量更新) if cache_key not in self._fundamentals_cache: self._fundamentals_cache[cache_key] = {} - # 限制缓存条目数量 - if len(self._fundamentals_cache) > 500: + if len(self._fundamentals_cache) > self._fundamentals_cache_max: self._fundamentals_cache.pop(next(iter(self._fundamentals_cache))) date_indices = self._fundamentals_cache[cache_key] @@ -434,7 +421,7 @@ def get_fundamentals(self, security: str | list[str], table: str, fields: list[s if isinstance(stock_df, pd.DataFrame) and not stock_df.empty: idx = stock_df.index.searchsorted(query_ts, side='right') if idx > 0: - close_prices[stock] = stock_df.iloc[idx - 1]['close'] + close_prices[stock] = stock_df['close'].values[idx - 1] self._fundamentals_cache[price_cache_key] = close_prices for stock in stocks: @@ -447,33 +434,31 @@ def get_fundamentals(self, security: str | list[str], table: str, fields: list[s continue idx = date_indices[stock] - nearest_date = df.index[idx] - row = df.loc[nearest_date] stock_data = {} for field in fields: if field == 'total_value' and need_realtime_total_value: - # 实时计算总市值 = 查询日收盘价 × 总股本 - total_shares = row.get('total_shares') + col_idx = df.columns.get_loc('total_shares') if 'total_shares' in df.columns else -1 + total_shares = df.iat[idx, col_idx] if col_idx >= 0 else None if total_shares is not None and not pd.isna(total_shares) and stock in close_prices: stock_data[field] = close_prices[stock] * total_shares - else: - stock_data[field] = row.get(field) + elif field in df.columns: + stock_data[field] = df[field].values[idx] elif field == 'float_value' and need_realtime_float_value: - # 实时计算流通市值 = 查询日收盘价 × 流通股本 - a_floats = row.get('a_floats') + col_idx = df.columns.get_loc('a_floats') if 'a_floats' in df.columns else -1 + a_floats = df.iat[idx, col_idx] if col_idx >= 0 else None if a_floats is not None and not pd.isna(a_floats) and stock in close_prices: stock_data[field] = close_prices[stock] * a_floats - else: - stock_data[field] = row.get(field) - else: - stock_data[field] = row.get(field) + elif field in df.columns: + stock_data[field] = df[field].values[idx] + elif field in df.columns: + stock_data[field] = df[field].values[idx] if stock_data: result_data[stock] = stock_data except Exception as e: - print("读取{}数据失败: stock={}, fields={}, error={}".format(table, stock, fields, e)) + print(f"读取{table}数据失败: stock={stock}, fields={fields}, error={e}") traceback.print_exc() raise @@ -517,13 +502,20 @@ def columns(self) -> pd.Index: first_df = next(iter(self.values())) return first_df.columns + def _get_data_source(self, frequency: str): + """根据frequency获取对应的数据源""" + if frequency == '1m': + if self.data_context.stock_data_dict_1m is None: + raise ValueError("分钟数据未加载,请确保data/stocks_1m/目录存在分钟数据") + return self.data_context.stock_data_dict_1m + return self.data_context.stock_data_dict + def get_price(self, security: str | list[str], start_date: str = None, end_date: str = None, frequency: str = '1d', fields: str | list[str] = None, fq: str = None, count: int = None) -> pd.DataFrame | PtradeAPI.PanelLike: """获取历史行情数据""" - # 验证fq参数 - valid_fq = ['pre', 'post', 'dypre', None] + # 验证fq参数(get_price不支持dypre,仅get_history支持) + valid_fq = ['pre', 'post', None] if fq not in valid_fq: - raise ValueError("function get_price: invalid fq argument, valid: {}, got {} (type: {})".format( - valid_fq, fq, type(fq))) + raise ValueError(f"function get_price: invalid fq argument, valid: {valid_fq}, got {fq} (type: {type(fq)})") if isinstance(fields, str): fields_list = [fields] @@ -535,26 +527,36 @@ def get_price(self, security: str | list[str], start_date: str = None, end_date: is_single_stock = isinstance(security, str) stocks = [security] if is_single_stock else security + # 根据frequency选择数据源 + data_source = self._get_data_source(frequency) + if count is not None: end_dt = pd.Timestamp(end_date) if end_date else self.context.current_dt result = {} for stock in stocks: - if stock not in self.data_context.stock_data_dict: + if stock not in data_source: continue - stock_df = self.data_context.stock_data_dict[stock] + stock_df = data_source[stock] if not isinstance(stock_df, pd.DataFrame): continue try: - date_dict, _ = self.get_stock_date_index(stock) - current_idx = date_dict.get(end_dt) - if current_idx is None: - current_idx = stock_df.index.get_loc(end_dt) + if frequency == '1m': + # 分钟数据:直接使用index查找 + idx = stock_df.index.searchsorted(end_dt, side='right') - 1 + if idx < 0: + continue + current_idx = idx + else: + date_dict, _ = self.get_stock_date_index(stock) + current_idx = date_dict.get(end_dt) + if current_idx is None: + current_idx = stock_df.index.get_loc(end_dt) except (KeyError, IndexError): continue - # Ptrade API语义: count=N 返回截止到end_date的N天数据(包含end_date) + # Ptrade API语义: count=N 返回截止到end_date的N条数据(包含end_date) slice_df = stock_df.iloc[max(0, current_idx - count + 1):current_idx + 1] result[stock] = slice_df else: @@ -563,10 +565,10 @@ def get_price(self, security: str | list[str], start_date: str = None, end_date: result = {} for stock in stocks: - if stock not in self.data_context.stock_data_dict: + if stock not in data_source: continue - stock_df = self.data_context.stock_data_dict[stock] + stock_df = data_source[stock] if not isinstance(stock_df, pd.DataFrame): continue @@ -578,48 +580,12 @@ def get_price(self, security: str | list[str], start_date: str = None, end_date: slice_df = stock_df[mask] result[stock] = slice_df - if fq == 'pre': + # 复权处理(仅日线数据支持) + if frequency != '1m' and fq in ('pre', 'post'): for stock in list(result.keys()): stock_df = result[stock] if isinstance(stock_df, pd.DataFrame) and not stock_df.empty: - # 向量化复权计算 - if self.data_context.adj_pre_cache and stock in self.data_context.adj_pre_cache: - adj_factors = self.data_context.adj_pre_cache[stock] - # 找到stock_df中有复权因子的日期 - common_idx = stock_df.index.intersection(adj_factors.index) - if len(common_idx) > 0: - adjusted_df = stock_df.copy() - adj_a = adj_factors.loc[common_idx, 'adj_a'] - adj_b = adj_factors.loc[common_idx, 'adj_b'] - price_cols = ['open', 'high', 'low', 'close'] - for col in price_cols: - if col in adjusted_df.columns: - # 前复权公式: 前复权价 = (未复权价 - adj_b) / adj_a - adjusted_df.loc[common_idx, col] = ( - adjusted_df.loc[common_idx, col] - adj_b - ) / adj_a - result[stock] = adjusted_df - - if fq == 'post': - for stock in list(result.keys()): - stock_df = result[stock] - if isinstance(stock_df, pd.DataFrame) and not stock_df.empty: - # 向量化复权计算 - if self.data_context.adj_post_cache and stock in self.data_context.adj_post_cache: - adj_factors = self.data_context.adj_post_cache[stock] - common_idx = stock_df.index.intersection(adj_factors.index) - if len(common_idx) > 0: - adjusted_df = stock_df.copy() - adj_a = adj_factors.loc[common_idx, 'adj_a'] - adj_b = adj_factors.loc[common_idx, 'adj_b'] - price_cols = ['open', 'high', 'low', 'close'] - for col in price_cols: - if col in adjusted_df.columns: - # 后复权公式: 后复权价 = adj_a * 未复权价 + adj_b - adjusted_df.loc[common_idx, col] = ( - adj_a * adjusted_df.loc[common_idx, col] + adj_b - ) - result[stock] = adjusted_df + result[stock] = self._apply_adj_factors(stock_df, stock, fq) if not result: return pd.DataFrame() @@ -648,8 +614,7 @@ def get_history(self, count: int, frequency: str = '1d', field: str | list[str] # 验证fq参数 valid_fq = ['pre', 'post', 'dypre', None] if fq not in valid_fq: - raise ValueError("function get_history: invalid fq argument, valid: {}, got {} (type: {})".format( - valid_fq, fq, type(fq))) + raise ValueError(f"function get_history: invalid fq argument, valid: {valid_fq}, got {fq} (type: {type(fq)})") if isinstance(field, str): fields = [field] @@ -668,20 +633,25 @@ def get_history(self, count: int, frequency: str = '1d', field: str | list[str] # 缓存键:使用frozen set避免列表顺序问题,但这里保持tuple更快 field_key = tuple(fields) if len(fields) > 1 else fields[0] - cache_key = (tuple(sorted(stocks)), count, field_key, fq, current_dt, include, is_dict) + cache_key = (tuple(sorted(stocks)), count, field_key, fq, current_dt, include, is_dict, frequency) # 检查缓存 if cache_key in self._history_cache: return self._history_cache[cache_key] + # 根据frequency选择数据源 + if frequency == '1m': + stock_data_dict = self._get_data_source(frequency) + else: + stock_data_dict = self.data_context.stock_data_dict + benchmark_data = self.data_context.benchmark_data + # 优化1: 批量预加载股票数据(减少LazyDataDict的重复加载) stock_dfs = {} - stock_data_dict = self.data_context.stock_data_dict - benchmark_data = self.data_context.benchmark_data for stock in stocks: - data_source = stock_data_dict.get(stock) - if data_source is None: + data_source = stock_data_dict.get(stock) if stock_data_dict else None + if data_source is None and frequency != '1m': data_source = benchmark_data.get(stock) if data_source is not None: stock_dfs[stock] = data_source @@ -692,18 +662,27 @@ def get_history(self, count: int, frequency: str = '1d', field: str | list[str] if not isinstance(data_source, pd.DataFrame): continue try: - date_dict, _ = self.get_stock_date_index(stock) - current_idx = date_dict.get(current_dt) - if current_idx is None: - current_idx = data_source.index.get_loc(current_dt) + if frequency == '1m': + # 分钟数据:使用searchsorted查找 + idx = data_source.index.searchsorted(current_dt, side='right') - 1 + if idx < 0: + continue + current_idx = idx + else: + date_dict, _ = self.get_stock_date_index(stock) + current_idx = date_dict.get(current_dt) + if current_idx is None: + current_idx = data_source.index.get_loc(current_dt) stock_info[stock] = (data_source, current_idx) except (KeyError, IndexError): continue # 优化3+4: 批量切片+复权(减少循环开销) result = {} - needs_adj_pre = fq == 'pre' and self.data_context.adj_pre_cache - needs_adj_post = fq == 'post' and self.data_context.adj_post_cache + # 分钟数据不支持复权 + needs_adj_pre = frequency != '1m' and fq == 'pre' and self.data_context.adj_pre_cache + needs_adj_dypre = frequency != '1m' and fq == 'dypre' and self.data_context.adj_pre_cache + needs_adj_post = frequency != '1m' and fq == 'post' and self.data_context.adj_post_cache price_fields = {'open', 'high', 'low', 'close'} # 预先构建集合,提升查找速度 for stock, (data_source, current_idx) in stock_info.items(): @@ -718,45 +697,91 @@ def get_history(self, count: int, frequency: str = '1d', field: str | list[str] if end_idx == 0 and current_idx == 0: end_idx = 1 + # 单字段不复权快路径:直接用 numpy 切片,绕过 pandas DataFrame 开销 + if not needs_adj_pre and not needs_adj_dypre and not needs_adj_post and len(fields) == 1: + field_name = fields[0] + if field_name in data_source.columns: + result[stock] = {field_name: data_source[field_name].values[start_idx:end_idx]} + continue + slice_df = data_source.iloc[start_idx:end_idx] if len(slice_df) == 0: continue # 前复权处理: 前复权价 = (未复权价 - adj_b) / adj_a - if needs_adj_pre and stock in self.data_context.adj_pre_cache: + if (needs_adj_pre or needs_adj_dypre) and stock in self.data_context.adj_pre_cache: adj_factors = self.data_context.adj_pre_cache[stock] - slice_adj = adj_factors.iloc[start_idx:end_idx] - adj_a = slice_adj['adj_a'].values - adj_b = slice_adj['adj_b'].values + # 快路径: adj_factors 与 data_source 长度一致时直接 iloc(常见情况) + if len(adj_factors) == len(data_source): + slice_adj = adj_factors.iloc[start_idx:end_idx] + adj_a = slice_adj['adj_a'].values + adj_b = slice_adj['adj_b'].values + aligned_df = slice_df + else: + # 慢路径: 索引不对齐时用 loc + common_idx = slice_df.index.intersection(adj_factors.index) + if len(common_idx) == 0: + stock_result = {f: slice_df[f].values for f in fields if f in slice_df.columns} + if stock_result: + result[stock] = stock_result + continue + adj_a = adj_factors.loc[common_idx, 'adj_a'].values + adj_b = adj_factors.loc[common_idx, 'adj_b'].values + aligned_df = slice_df.loc[common_idx] + + # dypre: 以当前日为基准 + if needs_adj_dypre: + base_date = current_dt + if base_date in adj_factors.index: + adj_a_base = adj_factors.loc[base_date, 'adj_a'] + adj_b_base = adj_factors.loc[base_date, 'adj_b'] + else: + adj_a_base = 1.0 + adj_b_base = 0.0 stock_result = {} for field_name in fields: - if field_name not in slice_df.columns: + if field_name not in aligned_df.columns: continue if field_name in price_fields: - # 前复权价 = (未复权价 - adj_b) / adj_a - stock_result[field_name] = (slice_df[field_name].values - adj_b) / adj_a + pre_price = (aligned_df[field_name].values - adj_b) / adj_a + if needs_adj_dypre: + stock_result[field_name] = np.round(pre_price * adj_a_base + adj_b_base, 2) + else: + stock_result[field_name] = np.round(pre_price, 2) else: - stock_result[field_name] = slice_df[field_name].values + stock_result[field_name] = aligned_df[field_name].values # 后复权处理: 后复权价 = adj_a * 未复权价 + adj_b elif needs_adj_post and stock in self.data_context.adj_post_cache: adj_factors = self.data_context.adj_post_cache[stock] - slice_adj = adj_factors.iloc[start_idx:end_idx] - adj_a = slice_adj['adj_a'].values - adj_b = slice_adj['adj_b'].values + # 快路径 / 慢路径 + if len(adj_factors) == len(data_source): + slice_adj = adj_factors.iloc[start_idx:end_idx] + adj_a = slice_adj['adj_a'].values + adj_b = slice_adj['adj_b'].values + aligned_df = slice_df + else: + common_idx = slice_df.index.intersection(adj_factors.index) + if len(common_idx) == 0: + stock_result = {f: slice_df[f].values for f in fields if f in slice_df.columns} + if stock_result: + result[stock] = stock_result + continue + adj_a = adj_factors.loc[common_idx, 'adj_a'].values + adj_b = adj_factors.loc[common_idx, 'adj_b'].values + aligned_df = slice_df.loc[common_idx] stock_result = {} for field_name in fields: - if field_name not in slice_df.columns: + if field_name not in aligned_df.columns: continue if field_name in price_fields: - # 后复权价 = adj_a * 未复权价 + adj_b - stock_result[field_name] = adj_a * slice_df[field_name].values + adj_b + stock_result[field_name] = adj_a * aligned_df[field_name].values + adj_b else: - stock_result[field_name] = slice_df[field_name].values + stock_result[field_name] = aligned_df[field_name].values else: # 不复权: 直接提取 stock_result = {field_name: slice_df[field_name].values @@ -894,6 +919,8 @@ def get_stock_status(self, stocks: str | list[str], query_type: str = 'ST', quer if nearest_date and query_type in self.data_context.stock_status_history[nearest_date]: status_dict = self.data_context.stock_status_history[nearest_date][query_type] is_problematic = status_dict.get(stock, False) is True + if len(self._stock_status_cache) >= self._stock_status_cache_max: + self._stock_status_cache.clear() self._stock_status_cache[cache_key] = is_problematic result[stock] = is_problematic continue @@ -902,6 +929,14 @@ def get_stock_status(self, stocks: str | list[str], query_type: str = 'ST', quer stock_name = self.data_context.stock_metadata.loc[stock, 'stock_name'] is_problematic = 'ST' in str(stock_name) + elif query_type == 'HALT': + # 用当日成交量=0判断停牌 + stock_df = self.data_context.stock_data_dict.get(stock) + if stock_df is not None: + query_date_only = query_dt.normalize() + if query_date_only in stock_df.index: + is_problematic = stock_df.loc[query_date_only, 'volume'] == 0 + elif query_type == 'DELISTING' and not self.data_context.stock_metadata.empty and stock in self.data_context.stock_metadata.index: try: de_listed_date = pd.Timestamp(self.data_context.stock_metadata.loc[stock, 'de_listed_date']) @@ -909,6 +944,8 @@ def get_stock_status(self, stocks: str | list[str], query_type: str = 'ST', quer except (KeyError, ValueError): pass + if len(self._stock_status_cache) >= self._stock_status_cache_max: + self._stock_status_cache.clear() self._stock_status_cache[cache_key] = is_problematic result[stock] = is_problematic @@ -916,22 +953,18 @@ def get_stock_status(self, stocks: str | list[str], query_type: str = 'ST', quer def get_stock_exrights(self, stock_code: str, date: str = None) -> Optional[pd.DataFrame]: """获取股票除权除息信息""" - try: - exrights_df = self.data_context.stock_data_store[f'/exrights/{stock_code}'] - - if date is not None: - query_date = pd.Timestamp(date) - if query_date in exrights_df.index: - return exrights_df.loc[[query_date]] - else: - return None - else: - return exrights_df - except KeyError: + exrights_df = self.data_context.exrights_dict.get(stock_code) + if exrights_df is None or exrights_df.empty: return None - except Exception: + + if date is not None: + query_date = pd.Timestamp(date) + if query_date in exrights_df.index: + return exrights_df.loc[[query_date]] return None + return exrights_df + # ==================== 指数/行业API ==================== def get_index_stocks(self, index_code: str, date: str = None) -> list[str]: @@ -944,13 +977,14 @@ def get_index_stocks(self, index_code: str, date: str = None) -> list[str]: # 如果未指定日期,使用回测当前日期 if date is None: - query_date = str(self.context.current_dt.date()) + query_date = self.context.current_dt.strftime('%Y%m%d') else: - query_date = date + # 统一日期格式为YYYYMMDD + query_date = date.replace('-', '') # 使用 bisect 找到小于等于 date 的最近日期 idx = bisect.bisect_right(available_dates, query_date) - + if idx > 0: # 向前查找包含该指数数据的最近日期 for i in range(idx - 1, -1, -1): @@ -958,35 +992,38 @@ def get_index_stocks(self, index_code: str, date: str = None) -> list[str]: if index_code in self.data_context.index_constituents[nearest_date]: result = self.data_context.index_constituents[nearest_date][index_code] return list(result) if hasattr(result, '__iter__') else [] - + return [] def get_industry_stocks(self, industry_code: str = None) -> dict | list[str]: - """推导行业成份股""" + """推导行业成份股(带缓存)""" if self.data_context.stock_metadata.empty: return {} if industry_code is None else [] - industries = {} - for stock_code, row in self.data_context.stock_metadata.iterrows(): - try: - blocks = json.loads(row['blocks']) - if 'HY' in blocks and blocks['HY']: - ind_code = blocks['HY'][0][0] - ind_name = blocks['HY'][0][1] - - if ind_code not in industries: - industries[ind_code] = { - 'name': ind_name, - 'stocks': [] - } - industries[ind_code]['stocks'].append(stock_code) - except (KeyError, json.JSONDecodeError, IndexError, TypeError): - pass + # 使用 DataContext._industry_index 缓存 + if self.data_context._industry_index is None: + industries = {} + for stock_code, row in self.data_context.stock_metadata.iterrows(): + try: + blocks = json.loads(row['blocks']) + if 'HY' in blocks and blocks['HY']: + ind_code = blocks['HY'][0][0] + ind_name = blocks['HY'][0][1] + + if ind_code not in industries: + industries[ind_code] = { + 'name': ind_name, + 'stocks': [] + } + industries[ind_code]['stocks'].append(stock_code) + except (KeyError, json.JSONDecodeError, IndexError, TypeError): + pass + self.data_context._industry_index = industries if industry_code is None: - return industries + return self.data_context._industry_index else: - return industries.get(industry_code, {}).get('stocks', []) + return self.data_context._industry_index.get(industry_code, {}).get('stocks', []) # ==================== 涨跌停API ==================== @@ -1036,10 +1073,10 @@ def check_limit(self, security: str | list[str], query_date: str = None) -> dict result[stock] = status continue - current_close = stock_df.iloc[idx]['close'] - current_high = stock_df.iloc[idx]['high'] - current_low = stock_df.iloc[idx]['low'] - prev_close = stock_df.iloc[idx-1]['close'] + current_close = stock_df['close'].values[idx] + current_high = stock_df['high'].values[idx] + current_low = stock_df['low'].values[idx] + prev_close = stock_df['close'].values[idx-1] if np.isnan(prev_close) or prev_close <= 0: # type: ignore result[stock] = status @@ -1051,7 +1088,7 @@ def check_limit(self, security: str | list[str], query_date: str = None) -> dict # 回测中不能使用当天收盘价判断涨停(会产生未来数据泄露) # 只检查一字涨停(开盘=最高=最低=涨停价) - current_open = stock_df.iloc[idx]['open'] + current_open = stock_df['open'].values[idx] # 涨停判断:一字涨停(无法买入) is_one_word_up_limit = ( @@ -1095,13 +1132,65 @@ def order(self, security: str, amount: int, limit_price: float = None) -> Option if amount == 0: return None - # 使用blotter创建订单 + # 获取执行价格(根据买卖方向计算滑点) + is_buy = amount > 0 + execution_price = self.order_processor.get_execution_price(security, limit_price, is_buy) + if execution_price is None: + self.log.warning(f"订单失败 {security} | 原因: 无法获取价格") + return None + + # 检查涨跌停 + limit_status = self.check_limit(security, self.context.current_dt)[security] + if not self.order_processor.check_limit_status(security, amount, limit_status): + return None + + # 买入时资金不足自动调整数量(Ptrade行为) + if amount > 0: + available_cash = self.context.portfolio._cash + cost = amount * execution_price + commission = self.order_processor.calculate_commission(amount, execution_price, is_sell=False) + total_cost = cost + commission + + if total_cost > available_cash: + # 确定最小交易单位 + min_lot = 200 if security.startswith('688') else 100 + adjusted = int(available_cash / execution_price / min_lot) * min_lot + + # 迭代调整确保含手续费后不超预算 + while adjusted >= min_lot: + test_cost = adjusted * execution_price + test_commission = self.order_processor.calculate_commission(adjusted, execution_price, is_sell=False) + if test_cost + test_commission <= available_cash: + break + adjusted -= min_lot + + if adjusted < min_lot: + self.log.warning(f"【买入失败】{security} | 原因: 现金不足") + return None + + self.log.warning(f"当前账户资金不足,调整{security}下单数量为{adjusted}") + amount = adjusted + + # 创建订单 + order_id, order = self.order_processor.create_order(security, amount, execution_price) + # 注册到blotter if self.context and self.context.blotter: - order = self.context.blotter.create_order(security, amount) - if limit_price is not None: - order.limit = limit_price - return order.id - return None + self.context.blotter.all_orders.append(order) + + if amount > 0: + self.log.info(f"生成订单,订单号:{order_id},股票代码:{security},数量:买入{amount}股") + success = self.order_processor.execute_buy(security, amount, execution_price) + else: + self.log.info(f"生成订单,订单号:{order_id},股票代码:{security},数量:卖出{abs(amount)}股") + success = self.order_processor.execute_sell(security, abs(amount), execution_price) + + if success: + order.status = '8' + order.filled = amount + if self.context and self.context.blotter: + self.context.blotter.filled_orders.append(order) + + return order.id if success else None @validate_lifecycle def order_target(self, security: str, amount: int, limit_price: float = None) -> Optional[str]: @@ -1128,7 +1217,7 @@ def order_target(self, security: str, amount: int, limit_price: float = None) -> is_buy = delta > 0 execution_price = self.order_processor.get_execution_price(security, limit_price, is_buy) if execution_price is None: - self.log.warning("订单失败 {} | 原因: 无法获取价格".format(security)) + self.log.warning(f"订单失败 {security} | 原因: 无法获取价格") return None # 检查涨跌停 @@ -1138,17 +1227,21 @@ def order_target(self, security: str, amount: int, limit_price: float = None) -> # 创建订单 order_id, order = self.order_processor.create_order(security, delta, execution_price) + if self.context and self.context.blotter: + self.context.blotter.all_orders.append(order) if delta > 0: - self.log.info("生成订单,订单号:{},股票代码:{},数量:买入{}股".format(order_id, security, delta)) + self.log.info(f"生成订单,订单号:{order_id},股票代码:{security},数量:买入{delta}股") success = self.order_processor.execute_buy(security, delta, execution_price) else: - self.log.info("生成订单,订单号:{},股票代码:{},数量:卖出{}股".format(order_id, security, abs(delta))) + self.log.info(f"生成订单,订单号:{order_id},股票代码:{security},数量:卖出{abs(delta)}股") success = self.order_processor.execute_sell(security, abs(delta), execution_price) if success: order.status = '8' order.filled = delta + if self.context and self.context.blotter: + self.context.blotter.filled_orders.append(order) return order.id if success else None @@ -1158,83 +1251,102 @@ def order_value(self, security: str, value: float, limit_price: float = None) -> Args: security: 股票代码 - value: 股票价值 + value: 股票价值,正数买入,负数卖出 limit_price: 买卖限价 Returns: 订单id或None """ + if abs(value) < 1: + return None + + is_buy = value > 0 + # 获取执行价格 - current_price = self.order_processor.get_execution_price(security, limit_price) + current_price = self.order_processor.get_execution_price(security, limit_price, is_buy) if current_price is None: self.log.warning(f"【下单失败】{security} | 原因: 获取价格数据失败") return None - # 检查涨停 + # 检查涨跌停 limit_status = self.check_limit(security, self.context.current_dt)[security] - if limit_status == 1: - self.log.warning(f"【买入失败】{security} | 原因: 涨停") + delta = 1 if is_buy else -1 + if not self.order_processor.check_limit_status(security, delta, limit_status): return None # 确定最小交易单位 min_lot = 200 if security.startswith('688') else 100 - # 先按目标value计算数量 - target_amount = int(value / current_price / min_lot) * min_lot - available_cash = self.context.portfolio._cash + if is_buy: + # 买入逻辑 + target_amount = int(value / current_price / min_lot) * min_lot + available_cash = self.context.portfolio._cash - # 如果目标数量 >= 最小单位,尝试按目标买入 - if target_amount >= min_lot: - # 检查现金是否足够(含手续费) - cost = target_amount * current_price - commission = self.order_processor.calculate_commission(target_amount, current_price, is_sell=False) - total_cost = cost + commission - - if total_cost <= available_cash: - # 现金足够,按目标数量买入 - amount = target_amount - else: - # 现金不足目标金额,自动调整到可买的最大数量 - max_affordable_amount = int(available_cash / current_price / min_lot) * min_lot + if target_amount >= min_lot: + cost = target_amount * current_price + commission = self.order_processor.calculate_commission(target_amount, current_price, is_sell=False) + total_cost = cost + commission - # 迭代调整,确保包含手续费后不超预算 - while max_affordable_amount >= min_lot: - test_cost = max_affordable_amount * current_price - test_commission = self.order_processor.calculate_commission(max_affordable_amount, current_price, is_sell=False) - test_total = test_cost + test_commission + if total_cost <= available_cash: + amount = target_amount + else: + max_affordable_amount = int(available_cash / current_price / min_lot) * min_lot - if test_total <= available_cash: - break - max_affordable_amount -= min_lot + while max_affordable_amount >= min_lot: + test_cost = max_affordable_amount * current_price + test_commission = self.order_processor.calculate_commission(max_affordable_amount, current_price, is_sell=False) + if test_cost + test_commission <= available_cash: + break + max_affordable_amount -= min_lot - if max_affordable_amount < min_lot: - self.log.warning("【买入失败】{} | 原因: 现金不足 (需要{:.2f}, 可用{:.2f})".format(security, total_cost, available_cash)) - return None + if max_affordable_amount < min_lot: + self.log.warning(f"【买入失败】{security} | 原因: 现金不足 (需要{total_cost:.2f}, 可用{available_cash:.2f})") + return None - self.log.warning("当前账户资金不足,调整{}下单数量为{}股(目标{:.2f}元,实际{:.2f}元)".format( - security, max_affordable_amount, value, max_affordable_amount * current_price)) - amount = max_affordable_amount + self.log.warning(f"当前账户资金不足,调整{security}下单数量为{max_affordable_amount}股(目标{value:.2f}元,实际{max_affordable_amount * current_price:.2f}元)") + amount = max_affordable_amount + else: + self.log.warning(f"【下单失败】{security} | 原因: 分配金额不足{min_lot}股 (分配{value:.2f}元, 价格{current_price:.2f}元, 可用现金{available_cash:.2f}元)") + return None + + order_id, order = self.order_processor.create_order(security, amount, current_price) + if self.context and self.context.blotter: + self.context.blotter.all_orders.append(order) + self.log.info(f"生成订单,订单号:{order_id},股票代码:{security},数量:买入{amount}股") + success = self.order_processor.execute_buy(security, amount, current_price) else: - # 目标数量 < 最小单位,直接取消交易(避免资金分配失衡) - self.log.warning("【下单失败】{} | 原因: 分配金额不足{}股 (分配{:.2f}元, 价格{:.2f}元, 可用现金{:.2f}元)".format( - security, min_lot, value, current_price, available_cash)) - return None + # 卖出逻辑 + sell_value = abs(value) + target_amount = int(sell_value / current_price / min_lot) * min_lot - # 创建订单 - order_id, order = self.order_processor.create_order(security, amount, current_price) + # 检查持仓 + if security not in self.context.portfolio.positions: + self.log.warning(f"【卖出失败】{security} | 原因: 无持仓") + return None + + position = self.context.portfolio.positions[security] - self.log.info("生成订单,订单号:{},股票代码:{},数量:买入{}股".format(order_id, security, amount)) + # 如果目标卖出量 >= 持仓量,卖出全部(不受min_lot限制) + if target_amount >= position.amount: + target_amount = position.amount + elif target_amount < min_lot: + self.log.warning(f"【下单失败】{security} | 原因: 卖出金额不足{min_lot}股 (金额{sell_value:.2f}元, 价格{current_price:.2f}元)") + return None - # 执行订单 - success = self.order_processor.execute_buy(security, amount, current_price) + order_id, order = self.order_processor.create_order(security, -target_amount, current_price) + if self.context and self.context.blotter: + self.context.blotter.all_orders.append(order) + self.log.info(f"生成订单,订单号:{order_id},股票代码:{security},数量:卖出{target_amount}股") + success = self.order_processor.execute_sell(security, target_amount, current_price) if success: order.status = '8' - order.filled = amount + order.filled = order.amount + if self.context and self.context.blotter: + self.context.blotter.filled_orders.append(order) return order.id if success else None - @validate_lifecycle def order_target_value(self, security: str, value: float, limit_price: float = None) -> Optional[str]: """调整股票持仓市值到目标价值 @@ -1246,20 +1358,31 @@ def order_target_value(self, security: str, value: float, limit_price: float = N Returns: 订单id或None """ - # 获取当前持仓市值 - current_value = 0.0 + # 获取执行价格 + current_amount = 0 if security in self.context.portfolio.positions: - position = self.context.portfolio.positions[security] - current_value = position.amount * position.last_sale_price + current_amount = self.context.portfolio.positions[security].amount + + # 获取价格计算目标数量 + is_buy = value > current_amount * (self.context.portfolio.positions[security].last_sale_price if current_amount > 0 else 0) + execution_price = self.order_processor.get_execution_price(security, limit_price, is_buy) + if execution_price is None: + self.log.warning(f"订单失败 {security} | 原因: 无法获取价格") + return None - # 计算需要调整的价值 - delta_value = value - current_value + # 计算目标持仓数量 + min_lot = 200 if security.startswith('688') else 100 + if value <= 0: + target_amount = 0 + else: + target_amount = int(value / execution_price / min_lot) * min_lot - if abs(delta_value) < 1: # 价值差异小于1元,不调整 + delta = target_amount - current_amount + if delta == 0: return None - # 使用 order_value 下单 - return self.order_value(security, delta_value, limit_price) + # 委托给 order_target 按数量交易 + return self.order_target(security, target_amount, limit_price) def get_open_orders(self) -> list: """获取未成交订单""" @@ -1279,10 +1402,11 @@ def get_orders(self, security: str = None) -> list: if not self.context or not self.context.blotter: return [] + all_orders = self.context.blotter.all_orders if security is None: - return self.context.blotter.open_orders + return all_orders else: - return [o for o in self.context.blotter.open_orders if o.symbol == security] + return [o for o in all_orders if o.symbol == security] def get_order(self, order_id: str) -> Optional[Any]: """获取指定订单 @@ -1296,20 +1420,27 @@ def get_order(self, order_id: str) -> Optional[Any]: if not self.context or not self.context.blotter: return None - for order in self.context.blotter.open_orders: + for order in self.context.blotter.all_orders: if order.id == order_id: return order return None - def get_trades(self) -> list: + def get_trades(self, security: str = None) -> list: """获取当日成交订单 + Args: + security: 股票代码,None表示获取所有成交 + Returns: 成交订单列表 """ - # 回测中所有已成交订单都会从open_orders移除,需要单独记录 - # 这里简化实现,返回空列表 - return [] + if not self.context or not self.context.blotter: + return [] + + filled = getattr(self.context.blotter, 'filled_orders', []) + if security is None: + return filled + return [o for o in filled if o.symbol == security] def get_position(self, security: str) -> Optional[Position]: """获取持仓信息 @@ -1324,6 +1455,24 @@ def get_position(self, security: str) -> Optional[Position]: return self.context.portfolio.positions.get(security) return None + def get_positions(self, security_list: list[str] = None) -> dict[str, Position]: + """获取多支股票持仓信息 + + Args: + security_list: 股票代码列表,None表示获取所有持仓 + + Returns: + dict: {stock: Position对象} + """ + if not self.context or not self.context.portfolio: + return {} + + positions = self.context.portfolio.positions + if security_list is None: + return positions.copy() + + return {s: positions[s] for s in security_list if s in positions} + def cancel_order(self, order: Any) -> bool: """取消订单""" if self.context and self.context.blotter: @@ -1461,9 +1610,47 @@ def run_daily(self, context: Any, func: Callable, time: str = '9:31') -> None: _ = (context, func, time) # 回测中不执行 pass - # ==================== 技术指标API ==================== + @validate_lifecycle + def set_parameters(self, params: dict) -> None: + """设置策略配置参数 + + Args: + params: dict,策略参数字典 + """ + if not hasattr(self.context, 'params'): + self.context.params = {} + self.context.params.update(params) @validate_lifecycle + def convert_position_from_csv(self, file_path: str) -> list[dict]: + """从CSV文件获取设置底仓的参数列表 + + Args: + file_path: CSV文件路径 + + Returns: + list: 持仓列表,格式为 [{'security': 股票代码, 'amount': 数量, 'cost_basis': 成本价}, ...] + """ + df = pd.read_csv(file_path) + positions = [] + for _, row in df.iterrows(): + positions.append({ + 'security': row.get('security', row.get('stock', row.get('code'))), + 'amount': int(row.get('amount', row.get('qty', 0))), + 'cost_basis': float(row.get('cost_basis', row.get('cost', row.get('price', 0)))) + }) + return positions + + def get_user_name(self) -> str: + """获取登录终端的资金账号 + + Returns: + str: 资金账号(回测返回模拟账号) + """ + return 'backtest_user' + + # ==================== 技术指标API ==================== + def get_MACD(self, close: np.ndarray, short: int = 12, long: int = 26, m: int = 9) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """计算MACD指标(异同移动平均线) @@ -1492,7 +1679,6 @@ def get_MACD(self, close: np.ndarray, short: int = 12, long: int = 26, m: int = return dif, dea, macd - @validate_lifecycle def get_KDJ(self, high: np.ndarray, low: np.ndarray, close: np.ndarray, n: int = 9, m1: int = 3, m2: int = 3) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """计算KDJ指标(随机指标) @@ -1537,7 +1723,6 @@ def get_KDJ(self, high: np.ndarray, low: np.ndarray, close: np.ndarray, return k, d, j - @validate_lifecycle def get_RSI(self, close: np.ndarray, n: int = 6) -> np.ndarray: """计算RSI指标(相对强弱指标) @@ -1561,7 +1746,6 @@ def get_RSI(self, close: np.ndarray, n: int = 6) -> np.ndarray: return rsi - @validate_lifecycle def get_CCI(self, high: np.ndarray, low: np.ndarray, close: np.ndarray, n: int = 14) -> np.ndarray: """计算CCI指标(顺势指标) diff --git a/src/simtradelab/ptrade/cache_manager.py b/src/simtradelab/ptrade/cache_manager.py index 51d2305..3f968c2 100644 --- a/src/simtradelab/ptrade/cache_manager.py +++ b/src/simtradelab/ptrade/cache_manager.py @@ -129,7 +129,7 @@ def __init__(self): def get_namespace(self, name: str) -> CacheNamespace: """获取缓存命名空间""" if name not in self._namespaces: - raise ValueError("未知的缓存命名空间: {}".format(name)) + raise ValueError(f"未知的缓存命名空间: {name}") return self._namespaces[name] def get(self, namespace: str, key: Any) -> Optional[Any]: @@ -167,4 +167,3 @@ def clear_daily_cache(self, current_date: Optional[datetime] = None) -> None: # 全局单例实例 cache_manager = UnifiedCacheManager() - diff --git a/src/simtradelab/ptrade/config_manager.py b/src/simtradelab/ptrade/config_manager.py index f3997ac..1f90646 100644 --- a/src/simtradelab/ptrade/config_manager.py +++ b/src/simtradelab/ptrade/config_manager.py @@ -58,6 +58,16 @@ class TradingConfig(BaseModel): default="STOCK", description="佣金类型" ) + transfer_fee_rate: float = Field( + default=0.0000487, + ge=0, + description="经手费率(万分之0.487,证监会规定)" + ) + stamp_tax_rate: float = Field( + default=0.001, + ge=0, + description="印花税率(千分之一,卖出时收取)" + ) model_config = {"frozen": True} # 配置不可变,确保线程安全 diff --git a/src/simtradelab/ptrade/context.py b/src/simtradelab/ptrade/context.py index 6190c52..1af0aa1 100644 --- a/src/simtradelab/ptrade/context.py +++ b/src/simtradelab/ptrade/context.py @@ -18,14 +18,13 @@ from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Any, Callable, Optional +from typing import Any, Optional -from .lifecycle_controller import LifecycleController, LifecyclePhase +from .lifecycle_controller import LifecycleController from .config_manager import config from simtradelab.ptrade.object import ( Blotter, Portfolio, - Position ) class PTradeMode(Enum): @@ -37,48 +36,6 @@ class PTradeMode(Enum): MARGIN_TRADING = "margin_trading" # 融资融券交易模式 -class StrategyLifecycleManager: - """策略生命周期管理器 - - 负责管理策略的生命周期函数调用和阶段转换 - """ - - def __init__(self, lifecycle_controller: LifecycleController): - self._lifecycle_controller = lifecycle_controller - self._strategy_functions: dict[str, Callable] = {} - - def register_strategy_function(self, phase: str, func: Callable) -> None: - """注册策略生命周期函数""" - self._strategy_functions[phase] = func - - def execute_lifecycle_phase( - self, phase: str, context: Context, data: Optional[Any] = None - ) -> Any: - """执行指定的生命周期阶段""" - # 设置当前阶段 - phase_enum = LifecyclePhase(phase) - self._lifecycle_controller.set_phase(phase_enum) - - # 获取并执行策略函数 - strategy_func = self._strategy_functions.get(phase) - if strategy_func: - if phase in ["initialize"]: - return strategy_func(context) - elif phase in [ - "handle_data", - "before_trading_start", - "after_trading_end", - "tick_data", - ]: - return strategy_func(context, data) - elif phase in ["on_order_response"]: - return strategy_func(context, data) # data是Order对象 - elif phase in ["on_trade_response"]: - return strategy_func(context, data) # data是Trade对象 - - return None - - @dataclass class Context: """PTrade策略上下文对象 - 完全符合PTrade规范 @@ -101,27 +58,20 @@ class Context: universe: list[str] = field(default_factory=list) # 股票池 benchmark: Optional[str] = None # 基准 current_dt: Optional[datetime] = None # 当前时间 + frequency: str = '1d' # 回测频率 '1d'日线 '1m'分钟线 - # === 内部配置属性 === - _parameters: dict[str, Any] = field(default_factory=dict) # 策略参数 - _yesterday_position: dict[str, Any] = field(default_factory=dict) # 底仓 - - # 配置通过全局config管理,不再通过context转发 - # 策略代码应使用: from simtradelab.ptrade.config_manager import config - - # === 生命周期管理属性 === + # === 生命周期管理 === _lifecycle_controller: Optional[LifecycleController] = None - _lifecycle_manager: Optional[StrategyLifecycleManager] = None - - # === 调度任务属性 === - _daily_tasks: list[dict[str, Any]] = field(default_factory=list) # 日级任务 - _interval_tasks: list[dict[str, Any]] = field(default_factory=list) # 间隔任务 def __post_init__(self) -> None: """初始化后处理""" # === 初始化g全局对象 === self.g = types.SimpleNamespace() # 全局变量容器 + # === 每日买卖金额累计(由OrderProcessor写入) === + self._daily_buy_total = 0.0 + self._daily_sell_total = 0.0 + # === 设置时间 === if self.current_dt is None: self.current_dt = datetime.now() @@ -130,153 +80,12 @@ def __post_init__(self) -> None: if self.blotter is None: self.blotter = Blotter(self.current_dt) - # === 初始化生命周期管理 === + # === 初始化生命周期控制器 === if self._lifecycle_controller is None: self._lifecycle_controller = LifecycleController(self.mode.value) - if self._lifecycle_manager is None: - self._lifecycle_manager = StrategyLifecycleManager( - self._lifecycle_controller - ) - self.blotter.current_dt = self.current_dt - # ========================================== - # 策略生命周期函数注册接口 - # ========================================== - - def register_initialize(self, func: Callable[[Context], None]) -> None: - """注册initialize函数""" - self._lifecycle_manager.register_strategy_function("initialize", func) - - def register_handle_data( - self, func: Callable[[Context, Any], None] - ) -> None: - """注册handle_data函数""" - self._lifecycle_manager.register_strategy_function("handle_data", func) - - def register_before_trading_start( - self, func: Callable[[Context, Any], None] - ) -> None: - """注册before_trading_start函数""" - self._lifecycle_manager.register_strategy_function("before_trading_start", func) - - def register_after_trading_end( - self, func: Callable[[Context, Any], None] - ) -> None: - """注册after_trading_end函数""" - self._lifecycle_manager.register_strategy_function("after_trading_end", func) - - def register_tick_data(self, func: Callable[[Context, Any], None]) -> None: - """注册tick_data函数""" - self._lifecycle_manager.register_strategy_function("tick_data", func) - - def register_on_order_response( - self, func: Callable[[Context, Any], None] - ) -> None: - """注册on_order_response函数""" - self._lifecycle_manager.register_strategy_function("on_order_response", func) - - def register_on_trade_response( - self, func: Callable[[Context, Any], None] - ) -> None: - """注册on_trade_response函数""" - self._lifecycle_manager.register_strategy_function("on_trade_response", func) - - # ========================================== - # 策略生命周期执行接口 - # ========================================== - - def execute_initialize(self) -> None: - """执行初始化阶段""" - result = self._lifecycle_manager.execute_lifecycle_phase("initialize", self) - self.initialized = True - return result - - def execute_handle_data(self, data: Any) -> None: - """执行主策略逻辑阶段""" - return self._lifecycle_manager.execute_lifecycle_phase( - "handle_data", self, data - ) - - def execute_before_trading_start(self, data: Any) -> None: - """执行盘前处理阶段""" - return self._lifecycle_manager.execute_lifecycle_phase( - "before_trading_start", self, data - ) - - def execute_after_trading_end(self, data: Any) -> None: - """执行盘后处理阶段""" - return self._lifecycle_manager.execute_lifecycle_phase( - "after_trading_end", self, data - ) - - def execute_tick_data(self, data: Any) -> None: - """执行tick数据处理阶段""" - return self._lifecycle_manager.execute_lifecycle_phase("tick_data", self, data) - - def execute_on_order_response(self, order: Any) -> None: - """执行委托回报处理阶段""" - return self._lifecycle_manager.execute_lifecycle_phase( - "on_order_response", self, order - ) - - def execute_on_trade_response(self, trade: Any) -> None: - """执行成交回报处理阶段""" - return self._lifecycle_manager.execute_lifecycle_phase( - "on_trade_response", self, trade - ) - - # ========================================== - # 策略配置管理接口 - # ========================================== - - def set_universe(self, securities: list[str]) -> None: - """设置股票池""" - self.universe = securities - - def set_benchmark(self, benchmark: str) -> None: - """设置基准""" - self.benchmark = benchmark - - def set_commission(self, commission: float) -> None: - """设置佣金费率""" - config.update_trading_config(commission_ratio=commission) - - def set_slippage(self, slippage: float) -> None: - """设置滑点""" - config.update_trading_config(slippage=slippage) - - def set_volume_ratio(self, ratio: float) -> None: - """设置成交比例""" - config.update_trading_config(volume_ratio=ratio) - - def set_limit_mode(self, mode: str) -> None: - """设置成交限制模式""" - config.update_trading_config(limit_mode=mode) - - def set_yesterday_position(self, positions: dict[str, Any]) -> None: - """设置底仓""" - self._yesterday_position = positions - - def set_parameters(self, params: dict[str, Any]) -> None: - """设置策略参数""" - self._parameters.update(params) - - # ========================================== - # 调度任务管理接口 - # ========================================== - - def run_daily(self, func: Callable, time: str) -> None: - """注册日级调度任务""" - task = {"function": func, "time": time, "type": "daily"} - self._daily_tasks.append(task) - - def run_interval(self, func: Callable, interval: int) -> None: - """注册间隔调度任务""" - task = {"function": func, "interval": interval, "type": "interval"} - self._interval_tasks.append(task) - # ========================================== # 生命周期状态查询接口 # ========================================== @@ -291,11 +100,8 @@ def get_lifecycle_statistics(self) -> dict[str, Any]: def is_api_allowed(self, api_name: str) -> bool: """检查API是否在当前阶段允许调用""" - try: - self._lifecycle_controller.validate_api_call(api_name) - return True - except Exception: - return False + result = self._lifecycle_controller.validate_api_call(api_name) + return result.is_valid # ========================================== # 工具方法 @@ -309,16 +115,10 @@ def record(self, name: str, value: Any) -> None: """记录收益曲线值""" self.recorded_vars[name] = value - def get_position(self, security: str) -> Optional[Position]: - """获取持仓信息""" - return self.portfolio.positions.get(security) - def reset_for_new_strategy(self) -> None: """为新策略重置上下文状态""" self.initialized = False self.recorded_vars.clear() - self._daily_tasks.clear() - self._interval_tasks.clear() self._lifecycle_controller.reset() # 重置blotter @@ -349,8 +149,8 @@ def create_backtest_context( context = Context( portfolio=portfolio, mode=PTradeMode.BACKTEST, capital_base=capital_base ) - context.set_commission(commission_rate) - context.set_slippage(slippage_rate) + config.update_trading_config(commission_ratio=commission_rate) + config.update_trading_config(slippage=slippage_rate) return context diff --git a/src/simtradelab/ptrade/data_context.py b/src/simtradelab/ptrade/data_context.py index cd74e01..c5aaa88 100644 --- a/src/simtradelab/ptrade/data_context.py +++ b/src/simtradelab/ptrade/data_context.py @@ -31,7 +31,8 @@ def __init__( adj_pre_cache, adj_post_cache=None, dividend_cache=None, - trade_days=None + trade_days=None, + stock_data_dict_1m=None ): """初始化数据上下文 @@ -48,6 +49,7 @@ def __init__( adj_post_cache: 后复权因子缓存 dividend_cache: 分红事件缓存 trade_days: 交易日历(DatetimeIndex) + stock_data_dict_1m: 分钟数据字典(LazyDataDict) """ self.stock_data_dict = stock_data_dict self.valuation_dict = valuation_dict @@ -61,6 +63,7 @@ def __init__( self.adj_post_cache = adj_post_cache self.dividend_cache = dividend_cache if dividend_cache is not None else {} self.trade_days = trade_days + self.stock_data_dict_1m = stock_data_dict_1m # 预解析 stock_metadata 日期列为 Timestamp(优化 get_Ashares 性能) if stock_metadata is not None and not stock_metadata.empty: diff --git a/src/simtradelab/ptrade/object.py b/src/simtradelab/ptrade/object.py index 2ef825b..274d5f3 100644 --- a/src/simtradelab/ptrade/object.py +++ b/src/simtradelab/ptrade/object.py @@ -14,23 +14,35 @@ from __future__ import annotations -from collections import OrderedDict import bisect -import pandas as pd -import numpy as np +from collections import OrderedDict +from datetime import datetime from functools import wraps +from typing import Any, Optional + +import numpy as np +import pandas as pd from joblib import Parallel, delayed -from tqdm import tqdm from pydantic import BaseModel, Field -from typing import Optional, Any -from datetime import datetime +from tqdm import tqdm from ..utils.performance_config import get_performance_config from .cache_manager import cache_manager -from .config_manager import config from .lifecycle_controller import LifecyclePhase +def _get_load_map(): + """获取数据类型到加载函数的映射(延迟导入避免循环依赖)""" + from . import storage + return { + 'stock': storage.load_stock, + 'stock_1m': storage.load_stock_1m, + 'valuation': storage.load_valuation, + 'fundamentals': storage.load_fundamentals, + 'exrights': lambda data_dir, k: storage.load_exrights(data_dir, k).get('exrights_events', pd.DataFrame()) + } + + # ==================== 多进程worker函数 ==================== def _load_data_chunk(data_dir, data_type, keys_chunk) -> dict[str, Any]: """多进程worker:加载一批数据 @@ -43,16 +55,7 @@ def _load_data_chunk(data_dir, data_type, keys_chunk) -> dict[str, Any]: Returns: dict: {key: dataframe} """ - from . import storage - - load_map = { - 'stock': storage.load_stock, - 'valuation': storage.load_valuation, - 'fundamentals': storage.load_fundamentals, - 'exrights': lambda data_dir, k: storage.load_exrights(data_dir, k).get('exrights_events', pd.DataFrame()) - } - - load_func = load_map[data_type] + load_func = _get_load_map()[data_type] result: dict[str, Any] = {} for key in keys_chunk: @@ -99,20 +102,14 @@ def __init__(self, data_dir, data_type, all_keys_list, max_cache_size=6000, prel preload: 是否预加载所有数据 use_multiprocessing: 是否使用多进程加载 """ - from . import storage - self.data_dir = data_dir self.data_type = data_type - # 数据类型到加载方法的映射 - self._load_map = { - 'stock': storage.load_stock, - 'valuation': storage.load_valuation, - 'fundamentals': storage.load_fundamentals, - 'exrights': lambda data_dir, k: storage.load_exrights(data_dir, k).get('exrights_events', pd.DataFrame()) - } + # 使用公共加载映射 + self._load_map = _get_load_map() self._cache = OrderedDict() # 使用OrderedDict实现LRU self._all_keys = all_keys_list + self._all_keys_set = set(all_keys_list) # O(1) 查找 self._max_cache_size = max_cache_size # 最大缓存数量 self._preload = preload self._access_count = 0 # 访问计数器 @@ -156,12 +153,12 @@ def __init__(self, data_dir, data_type, all_keys_list, max_cache_size=6000, prel for key in tqdm(all_keys_list, desc=' 加载', ncols=80, ascii=True, bar_format='{desc}: {percentage:3.0f}%|{bar}| {n:4d}/{total:4d} [{elapsed}<{remaining}]'): try: - self._cache[key] = load_func(self.data_path, key) + self._cache[key] = load_func(self.data_dir, key) except KeyError: pass def __contains__(self, key): - return key in self._all_keys + return key in self._all_keys_set def __getitem__(self, key): if key in self._cache: @@ -271,7 +268,7 @@ def _ensure_data_loaded(self): def _load_data(self): """加载股票当日数据并应用前复权""" if self._current_idx is None or self._stock_df is None: - raise ValueError("股票 {} 在 {} 数据加载失败".format(self.stock, self.current_date)) + raise ValueError(f"股票 {self.stock} 在 {self.current_date} 数据加载失败") row = self._stock_df.iloc[self._current_idx] data = { @@ -420,29 +417,13 @@ def __getitem__(self, stock): return stock_data - -# class Context: -# """模拟context对象""" -# def __init__(self, current_dt, bt_ctx=None): -# self.current_dt = current_dt -# self.previous_date = (current_dt - timedelta(days=1)).date() -# self.portfolio = Portfolio(bt_ctx, self) -# self.blotter = Blotter(current_dt, bt_ctx) -# # 回测配置 -# self.commission_ratio = 0.0003 -# self.min_commission = 5.0 -# self.commission_type = 'STOCK' -# self.slippage = 0.0 -# self.fixed_slippage = 0.0 -# self.limit_mode = 'LIMITED' -# self.volume_ratio = 0.25 -# self.benchmark = '000300.SS' - class Blotter: """模拟blotter对象""" def __init__(self, current_dt, bt_ctx=None): self.current_dt = current_dt self.open_orders = [] + self.all_orders = [] + self.filled_orders = [] self._order_id_counter = 0 self._bt_ctx = bt_ctx @@ -457,6 +438,7 @@ def create_order(self, stock, amount): limit=None ) self.open_orders.append(order) + self.all_orders.append(order) return order def cancel_order(self, order): @@ -467,125 +449,6 @@ def cancel_order(self, order): return True return False - def process_orders(self, portfolio, current_dt): - """处理未成交订单(使用当日收盘价成交)优化版:批量预加载""" - executed_orders = [] - - if not self.open_orders: - return executed_orders - - # 批量预加载:收集所有需要的股票数据 - stock_data_cache = {} - for order in self.open_orders: - if order.symbol not in stock_data_cache and self._bt_ctx and self._bt_ctx.stock_data_dict: - stock_df = self._bt_ctx.stock_data_dict.get(order.symbol) - if stock_df is None or not isinstance(stock_df, pd.DataFrame): - continue - - if self._bt_ctx.get_stock_date_index: - date_dict, _ = self._bt_ctx.get_stock_date_index(order.symbol) - idx = date_dict.get(current_dt) - else: - idx = stock_df.index.get_loc(current_dt) if current_dt in stock_df.index else None - - if idx is not None: - stock_data_cache[order.symbol] = { - 'df': stock_df, - 'idx': idx, - 'close': stock_df.iloc[idx]['close'], - 'volume': stock_df.iloc[idx]['volume'] - } - - # 处理订单 - for order in self.open_orders[:]: - # 使用缓存获取当日收盘价 - execution_price = None - if order.symbol in stock_data_cache: - execution_price = stock_data_cache[order.symbol]['close'] - - if execution_price is None or np.isnan(execution_price) or execution_price <= 0: - continue - - # 检查成交量限制(LIMIT模式) - actual_amount = order.amount - if config.trading.limit_mode == 'LIMIT': - if order.symbol in stock_data_cache: - daily_volume = stock_data_cache[order.symbol]['volume'] - # 应用成交比例限制 - volume_ratio = config.trading.volume_ratio - max_allowed = int(daily_volume * volume_ratio) - - if abs(order.amount) > max_allowed: - if max_allowed > 0: - # 部分成交 - actual_amount = max_allowed if order.amount > 0 else -max_allowed - if self._bt_ctx.log: - self._bt_ctx.log.warning( - f"【订单部分成交】{order.symbol} | 委托量:{abs(order.amount)}, 成交量:{abs(actual_amount)} (成交比例限制:{volume_ratio})" - ) - else: - if self._bt_ctx.log: - self._bt_ctx.log.warning( - f"【订单失败】{order.symbol} | 原因: 当日成交量为0或不足" - ) - self.open_orders.remove(order) - order.status = 'failed' - continue - - # 检查涨跌停限制 - if self._bt_ctx and self._bt_ctx.check_limit: - limit_status = self._bt_ctx.check_limit(order.symbol, current_dt)[order.symbol] - if order.amount > 0 and limit_status == 1: - if self._bt_ctx.log: - self._bt_ctx.log.warning(f"【订单失败】{order.symbol} | 原因: 涨停买不进") - self.open_orders.remove(order) - order.status = 'failed' - continue - elif order.amount < 0 and limit_status == -1: - if self._bt_ctx.log: - self._bt_ctx.log.warning(f"【订单失败】{order.symbol} | 原因: 跌停卖不出") - self.open_orders.remove(order) - order.status = 'failed' - continue - - # 执行订单 - if actual_amount > 0: - # 买入 - cost = actual_amount * execution_price - if cost <= portfolio._cash: - portfolio._cash -= cost - portfolio.add_position(order.symbol, actual_amount, execution_price, current_dt) - order.status = 'filled' - order.filled = actual_amount - executed_orders.append(order) - self.open_orders.remove(order) - elif actual_amount < 0: - # 卖出 - if order.symbol in portfolio.positions: - position = portfolio.positions[order.symbol] - sell_qty = position.amount - - # 减仓/清仓(含FIFO分红税调整) - portfolio.remove_position(order.symbol, sell_qty, current_dt) - - # 卖出收入到账 - sell_revenue = sell_qty * execution_price - portfolio._cash += sell_revenue - - # 更新价格(仅在未清仓时) - if order.symbol in portfolio.positions: - position = portfolio.positions[order.symbol] - position.last_sale_price = execution_price - position.market_value = position.amount * execution_price - - order.status = 'filled' - order.filled = actual_amount - executed_orders.append(order) - - self.open_orders.remove(order) - - return executed_orders - class Order(BaseModel): """订单对象""" id: int | str = Field(..., description="订单号(支持整数或UUID字符串)") @@ -653,7 +516,7 @@ def remove_position(self, stock, amount, sell_date): # 边界检查:卖出数量不能超过持仓 if amount > position.amount: raise ValueError( - '卖出数量 {} 超过持仓 {}: {}'.format(amount, position.amount, stock) + f'卖出数量 {amount} 超过持仓 {position.amount}: {stock}' ) # FIFO计算税务调整 @@ -681,47 +544,11 @@ def add_dividend(self, stock, dividend_per_share): lot['dividends_total'] = lot.get('dividends_total', 0.0) + lot_div def _calculate_dividend_tax(self, stock, amount, sell_date): - """计算分红税调整(FIFO)""" - if stock not in self._position_lots: - return 0.0 - - lots = self._position_lots[stock] - remaining = amount - tax_adjustment = 0.0 - i = 0 + """计算分红税调整(FIFO) - while i < len(lots) and remaining > 0: - lot = lots[i] - holding_days = (sell_date - lot['date']).days - - # 真实税率 - if holding_days <= 30: - actual_rate = 0.20 - elif holding_days <= 365: - actual_rate = 0.10 - else: - actual_rate = 0.0 - - # 本批次卖出数量 - sell_qty = min(remaining, lot['amount']) - ratio = sell_qty / lot['amount'] - - # 优先使用缓存总和 - lot_div_total = lot.get('dividends_total', sum(lot['dividends'])) - tax_adjustment += lot_div_total * ratio * (actual_rate - 0.20) - - # 扣减批次 - if lot['amount'] <= remaining: - remaining -= lot['amount'] - lots.pop(i) - else: - lot['amount'] -= remaining - # 更新剩余部分的分红总额 - lot['dividends_total'] = lot_div_total * (1.0 - ratio) - remaining = 0 - i += 1 - - return tax_adjustment + Ptrade行为:分红时预扣20%,卖出时不做税务调整 + """ + return 0.0 @property def cash(self): @@ -815,8 +642,3 @@ def __init__(self, stock: str, amount: float, cost_basis: float): self.market_value = amount * cost_basis - -class Global: - """模拟全局变量g(策略可用于存储自定义数据)""" - pass - diff --git a/src/simtradelab/ptrade/order_processor.py b/src/simtradelab/ptrade/order_processor.py index 66637bf..bca5b24 100644 --- a/src/simtradelab/ptrade/order_processor.py +++ b/src/simtradelab/ptrade/order_processor.py @@ -60,18 +60,34 @@ def get_execution_price(self, stock: str, limit_price: Optional[float] = None, i if limit_price is not None: base_price = limit_price else: - if stock not in self.data_context.stock_data_dict: + # 根据frequency选择数据源 + frequency = getattr(self.context, 'frequency', '1d') + if frequency == '1m' and self.data_context.stock_data_dict_1m is not None: + data_source = self.data_context.stock_data_dict_1m + else: + data_source = self.data_context.stock_data_dict + + if stock not in data_source: return None - stock_df = self.data_context.stock_data_dict[stock] + stock_df = data_source[stock] if not isinstance(stock_df, pd.DataFrame): return None try: - date_dict, _ = self.get_stock_date_index(stock) - idx = date_dict.get(self.context.current_dt) - if idx is None: - idx = stock_df.index.get_loc(self.context.current_dt) + current_dt = self.context.current_dt + if frequency == '1m': + # 分钟数据:使用searchsorted查找最近的时间点 + idx = stock_df.index.searchsorted(current_dt, side='right') - 1 + if idx < 0: + return None + else: + # 日线数据:使用date_dict查找 + date_dict, _ = self.get_stock_date_index(stock) + idx = date_dict.get(current_dt) + if idx is None: + idx = stock_df.index.get_loc(current_dt) + price = stock_df.iloc[idx]['close'] # 转换为标量值 @@ -86,8 +102,8 @@ def get_execution_price(self, stock: str, limit_price: Optional[float] = None, i return None # 获取滑点配置 - slippage = getattr(self.context, 'slippage', config.trading.slippage) - fixed_slippage = getattr(self.context, 'fixed_slippage', config.trading.fixed_slippage) + slippage = config.trading.slippage + fixed_slippage = config.trading.fixed_slippage # 计算滑点金额 if slippage > 0: @@ -122,10 +138,10 @@ def check_limit_status(self, stock: str, delta: int, limit_status: int) -> bool: 是否可交易 """ if delta > 0 and limit_status == 1: - self.log.warning("【订单失败】{} | 原因: 涨停买不进".format(stock)) + self.log.warning(f"【订单失败】{stock} | 原因: 涨停买不进") return False elif delta < 0 and limit_status == -1: - self.log.warning("【订单失败】{} | 原因: 跌停卖不出".format(stock)) + self.log.warning(f"【订单失败】{stock} | 原因: 跌停卖不出") return False return True @@ -161,8 +177,8 @@ def calculate_commission(self, amount: int, price: float, is_sell: bool = False) Returns: 手续费总额 """ - commission_ratio = getattr(self.context, 'commission_ratio', config.trading.commission_ratio) - min_commission = getattr(self.context, 'min_commission', config.trading.min_commission) + commission_ratio = config.trading.commission_ratio + min_commission = config.trading.min_commission # 如果手续费率为0,则完全不收手续费 if commission_ratio == 0: @@ -171,16 +187,14 @@ def calculate_commission(self, amount: int, price: float, is_sell: bool = False) value = amount * price # 佣金费 broker_fee = max(value * commission_ratio, min_commission) - # 经手费率:万分之0.487 - transfer_fee = value * 0.0000487 + # 经手费 + transfer_fee = value * config.trading.transfer_fee_rate commission = broker_fee + transfer_fee # 印花税(仅卖出时收取) if is_sell: - tax_rate = getattr(self.context, 'tax_rate', 0.001) - tax = value * tax_rate - commission += tax + commission += value * config.trading.stamp_tax_rate return commission @@ -200,8 +214,7 @@ def execute_buy(self, stock: str, amount: int, price: float) -> bool: total_cost = cost + commission if total_cost > self.context.portfolio._cash: - self.log.warning("【买入失败】{} | 原因: 现金不足 (需要{:.2f}, 可用{:.2f})".format( - stock, total_cost, self.context.portfolio._cash)) + self.log.warning(f"【买入失败】{stock} | 原因: 现金不足 (需要{total_cost:.2f}, 可用{self.context.portfolio._cash:.2f})") return False self.context.portfolio._cash -= total_cost @@ -214,6 +227,9 @@ def execute_buy(self, stock: str, amount: int, price: float) -> bool: # 建仓/加仓(含批次追踪) self.context.portfolio.add_position(stock, amount, price, self.context.current_dt) + # 累计当日买入金额(gross,不含手续费) + self.context._daily_buy_total += amount * price + return True def execute_sell(self, stock: str, amount: int, price: float) -> bool: @@ -228,14 +244,13 @@ def execute_sell(self, stock: str, amount: int, price: float) -> bool: 是否成功 """ if stock not in self.context.portfolio.positions: - self.log.warning("【卖出失败】{} | 原因: 无持仓".format(stock)) + self.log.warning(f"【卖出失败】{stock} | 原因: 无持仓") return False position = self.context.portfolio.positions[stock] if position.amount < amount: - self.log.warning("【卖出失败】{} | 原因: 持仓不足 (持有{}, 尝试卖出{})".format( - stock, position.amount, amount)) + self.log.warning(f"【卖出失败】{stock} | 原因: 持仓不足 (持有{position.amount}, 尝试卖出{amount})") return False # 计算手续费 @@ -263,11 +278,14 @@ def execute_sell(self, stock: str, amount: int, price: float) -> bool: # 入账 self.context.portfolio._cash += net_revenue + # 累计当日卖出金额(gross,不含手续费) + self.context._daily_sell_total += amount * price + # 日志 if tax_adjustment > 0: - self.log.info("📊分红税 | {} | 补税{:.2f}元".format(stock, tax_adjustment)) + self.log.info(f"📊分红税 | {stock} | 补税{tax_adjustment:.2f}元") elif tax_adjustment < 0: - self.log.info("📊分红税 | {} | 退税{:.2f}元".format(stock, -tax_adjustment)) + self.log.info(f"📊分红税 | {stock} | 退税{-tax_adjustment:.2f}元") return True @@ -287,7 +305,7 @@ def process_order(self, stock: str, target_amount: int, limit_price: Optional[fl # 1. 获取执行价格 price = self.get_execution_price(stock, limit_price) if price is None: - self.log.warning("【订单失败】{} | 原因: 无法获取价格".format(stock)) + self.log.warning(f"【订单失败】{stock} | 原因: 无法获取价格") return False # 2. 计算交易数量 diff --git a/src/simtradelab/ptrade/storage.py b/src/simtradelab/ptrade/storage.py index c18ac33..2da1773 100644 --- a/src/simtradelab/ptrade/storage.py +++ b/src/simtradelab/ptrade/storage.py @@ -204,3 +204,24 @@ def list_stocks(data_dir): parquet_files = list(stocks_dir.glob('*.parquet')) return [f.stem for f in parquet_files] + + +def load_stock_1m(data_dir, symbol): + """加载分钟线数据""" + parquet_file = Path(data_dir) / 'stocks_1m' / (symbol + '.parquet') + if parquet_file.exists(): + df = pd.read_parquet(parquet_file) + if not df.empty and 'datetime' in df.columns: + df.set_index('datetime', inplace=True) + return df + return pd.DataFrame() + + +def list_stocks_1m(data_dir): + """列出所有可用的分钟数据股票代码""" + stocks_dir = Path(data_dir) / 'stocks_1m' + if not stocks_dir.exists(): + return [] + + parquet_files = list(stocks_dir.glob('*.parquet')) + return [f.stem for f in parquet_files] diff --git a/src/simtradelab/ptrade/strategy_data_analyzer.py b/src/simtradelab/ptrade/strategy_data_analyzer.py index 08cd4f3..01eb03c 100644 --- a/src/simtradelab/ptrade/strategy_data_analyzer.py +++ b/src/simtradelab/ptrade/strategy_data_analyzer.py @@ -105,7 +105,7 @@ def analyze_strategy_data_requirements(strategy_path: str) -> DataDependencies: except Exception as e: # 分析失败时返回全量依赖 - print("策略分析失败: {}, 加载全部数据".format(e)) + print(f"策略分析失败: {e}, 加载全部数据") return DataDependencies( needs_price_data=True, needs_valuation=True, @@ -123,11 +123,11 @@ def print_dependencies(deps: DataDependencies): items.append("估值") if deps.needs_fundamentals: tables = ','.join(deps.fundamental_tables) if deps.fundamental_tables else '全部' - items.append("财务({})".format(tables)) + items.append(f"财务({tables})") if deps.needs_exrights: items.append("除权") if items: - print("策略数据依赖: {}".format(' | '.join(items))) + print(f"策略数据依赖: {' | '.join(items)}") else: print("策略数据依赖: 无") diff --git a/src/simtradelab/ptrade/strategy_engine.py b/src/simtradelab/ptrade/strategy_engine.py index 00f7a02..ee1a296 100644 --- a/src/simtradelab/ptrade/strategy_engine.py +++ b/src/simtradelab/ptrade/strategy_engine.py @@ -41,8 +41,8 @@ def __init__( context: Context, api: Any, stats_collector: Any, - g: Any, log: logging.Logger, + frequency: str = '1d', ): """ 初始化策略执行引擎 @@ -51,15 +51,15 @@ def __init__( context: PTrade Context对象 api: PtradeAPI对象 stats_collector: 统计收集器 - g: Global对象 log: 日志对象 + frequency: 回测频率 '1d'日线 '1m'分钟线 """ # 核心组件(外部注入) self.context = context self.api = api self.stats_collector = stats_collector - self.g = g self.log = log + self.frequency = frequency # 获取生命周期控制器 if self.context._lifecycle_controller is None: @@ -88,7 +88,7 @@ def load_strategy_from_file(self, strategy_path: str) -> None: strategy_namespace = { '__name__': '__main__', '__file__': strategy_path, - 'g': self.g, + 'g': self.context.g, 'log': self.log, 'context': self.context, } @@ -130,45 +130,38 @@ def set_strategy_name(self, strategy_name: str) -> None: def register_initialize(self, func: Callable[[Context], None]) -> None: """注册initialize函数""" self._strategy_functions["initialize"] = func - self.context.register_initialize(func) def register_handle_data(self, func: Callable[[Context, Any], None]) -> None: """注册handle_data函数""" self._strategy_functions["handle_data"] = func - self.context.register_handle_data(func) def register_before_trading_start( self, func: Callable[[Context, Any], None] ) -> None: """注册before_trading_start函数""" self._strategy_functions["before_trading_start"] = func - self.context.register_before_trading_start(func) def register_after_trading_end( self, func: Callable[[Context, Any], None] ) -> None: """注册after_trading_end函数""" self._strategy_functions["after_trading_end"] = func - self.context.register_after_trading_end(func) def register_tick_data(self, func: Callable[[Context, Any], None]) -> None: """注册tick_data函数""" self._strategy_functions["tick_data"] = func - self.context.register_tick_data(func) def register_on_order_response( self, func: Callable[[Context, Any], None] ) -> None: """注册on_order_response函数""" self._strategy_functions["on_order_response"] = func - self.context.register_on_order_response(func) def register_on_trade_response( self, func: Callable[[Context, Any], None] ) -> None: """注册on_trade_response函数""" self._strategy_functions["on_trade_response"] = func - self.context.register_on_trade_response(func) # ========================================== # PTrade API 代理接口 @@ -209,8 +202,11 @@ def run_backtest(self, date_range) -> bool: # 1. 执行初始化 self._execute_initialize() - # 2. 执行每日循环 - success = self._run_daily_loop(date_range) + # 2. 根据frequency选择循环模式 + if self.frequency == '1m': + success = self._run_minute_loop(date_range) + else: + success = self._run_daily_loop(date_range) if success: self.log.info("Strategy execution completed successfully") @@ -227,8 +223,12 @@ def run_backtest(self, date_range) -> bool: def _execute_initialize(self) -> None: """执行初始化阶段""" + from simtradelab.ptrade.lifecycle_controller import LifecyclePhase + self.log.info("Executing initialize phase") - self.context.execute_initialize() + self.lifecycle_controller.set_phase(LifecyclePhase.INITIALIZE) + self._strategy_functions["initialize"](self.context) + self.context.initialized = True def _run_daily_loop(self, date_range) -> bool: """执行每日回测循环 @@ -243,6 +243,9 @@ def _run_daily_loop(self, date_range) -> bool: from simtradelab.ptrade.object import Data from simtradelab.ptrade.cache_manager import cache_manager + # 跨日追踪:上一交易日收盘后的组合市值(用于计算真实日盈亏) + prev_day_end_value = None + for current_date in date_range: # 更新日期上下文 self.context.current_dt = current_date @@ -259,13 +262,12 @@ def _run_daily_loop(self, date_range) -> bool: # 清理全局缓存 cache_manager.clear_daily_cache(current_date) - # 记录交易前状态 - prev_portfolio_value = self.context.portfolio.portfolio_value - prev_cash = self.context.portfolio._cash - # 收集交易前统计 self.stats_collector.collect_pre_trading(self.context, current_date) + # 处理除权除息事件(在策略执行前) + self._process_dividend_events(current_date) + # 构造data对象 data = Data(current_date, self.context.portfolio._bt_ctx) @@ -273,18 +275,111 @@ def _run_daily_loop(self, date_range) -> bool: if not self._execute_lifecycle(data): return False - # 处理分红事件(在生命周期执行完、订单成交后) + # 收集交易金额(从OrderProcessor累计的gross金额) + self.stats_collector.collect_trading_amounts(self.context) + + # 收集交易后统计(用上一交易日收盘后的组合市值计算真实日盈亏) + current_end_value = self.context.portfolio.portfolio_value + if prev_day_end_value is None: + prev_day_end_value = current_end_value # 首日无盈亏 + self.stats_collector.collect_post_trading(self.context, prev_day_end_value) + prev_day_end_value = current_end_value + + return True + + def _run_minute_loop(self, date_range) -> bool: + """执行分钟级回测循环 + + Args: + date_range: 交易日序列 + + Returns: + 是否成功完成所有交易日 + """ + from datetime import timedelta + from simtradelab.ptrade.object import Data + from simtradelab.ptrade.cache_manager import cache_manager + from simtradelab.ptrade.lifecycle_controller import LifecyclePhase + + # 跨日追踪:上一交易日收盘后的组合市值 + prev_day_end_value = None + + for current_date in date_range: + # 更新日期上下文(设为开盘时间) + self.context.current_dt = current_date + self.context.blotter.current_dt = current_date + + # 使用API获取真正的前一交易日 + prev_trade_day = self.api.get_trading_day(-1) + if prev_trade_day: + self.context.previous_date = prev_trade_day + else: + self.context.previous_date = (current_date - timedelta(days=1)).date() + + # 清理全局缓存 + cache_manager.clear_daily_cache(current_date) + + # 收集交易前统计 + self.stats_collector.collect_pre_trading(self.context, current_date) + + # 处理除权除息事件(在策略执行前) self._process_dividend_events(current_date) - # 收集交易金额 - current_cash = self.context.portfolio._cash - self.stats_collector.collect_trading_amounts(prev_cash, current_cash) + # 构造data对象 + data = Data(current_date, self.context.portfolio._bt_ctx) + + # 1. before_trading_start(每日一次,开盘前) + if not self._safe_call('before_trading_start', LifecyclePhase.BEFORE_TRADING_START, data): + return False + + # 2. handle_data(分钟级调用) + minute_bars = self._get_minute_bars(current_date) + for minute_dt in minute_bars: + self.context.current_dt = minute_dt + data = Data(minute_dt, self.context.portfolio._bt_ctx) + if not self._safe_call('handle_data', LifecyclePhase.HANDLE_DATA, data): + return False + + # 3. after_trading_end(每日一次,收盘后) + self.context.current_dt = current_date.replace(hour=15, minute=0, second=0) + data = Data(self.context.current_dt, self.context.portfolio._bt_ctx) + self._safe_call('after_trading_end', LifecyclePhase.AFTER_TRADING_END, data, allow_fail=True) + + # 收集交易金额(从OrderProcessor累计的gross金额) + self.stats_collector.collect_trading_amounts(self.context) # 收集交易后统计 - self.stats_collector.collect_post_trading(self.context, prev_portfolio_value) + current_end_value = self.context.portfolio.portfolio_value + if prev_day_end_value is None: + prev_day_end_value = current_end_value + self.stats_collector.collect_post_trading(self.context, prev_day_end_value) + prev_day_end_value = current_end_value return True + def _get_minute_bars(self, trade_date): + """生成交易日分钟时间序列 + + Args: + trade_date: 交易日 + + Returns: + 分钟时间戳列表 + """ + import pandas as pd + + # A股交易时间: 9:30-11:30, 13:00-15:00 + morning_start = trade_date.replace(hour=9, minute=30, second=0, microsecond=0) + morning_end = trade_date.replace(hour=11, minute=30, second=0, microsecond=0) + afternoon_start = trade_date.replace(hour=13, minute=0, second=0, microsecond=0) + afternoon_end = trade_date.replace(hour=15, minute=0, second=0, microsecond=0) + + # 生成分钟序列 + morning_bars = pd.date_range(morning_start, morning_end, freq='1min') + afternoon_bars = pd.date_range(afternoon_start, afternoon_end, freq='1min') + + return list(morning_bars) + list(afternoon_bars) + def _execute_lifecycle(self, data) -> bool: """执行策略生命周期方法 @@ -348,25 +443,39 @@ def _safe_call( return allow_fail def _process_dividend_events(self, current_date): - """处理分红事件 + """处理除权除息事件 Args: current_date: 当前交易日 - 分红处理逻辑: - 1. 分红到账时全额到账(不扣税) - 2. 记录每批次的分红金额 - 3. 卖出时根据持股时间(FIFO)计算并扣除分红税 + 处理逻辑: + 1. 送股/配股: 调整持仓数量 + 2. 现金分红: 到账(预扣税20%) """ try: date_str = current_date.strftime('%Y%m%d') - # 遍历所有持仓股票 for stock_code, position in self.context.portfolio.positions.items(): if position.amount <= 0: continue - # 从缓存中查找分红 + # 分红和送股都基于登记日(前一天)的持股数 + original_amount = position.amount + + # 检查除权事件(送股/配股) + exrights_df = self.api.data_context.exrights_dict.get(stock_code) + if exrights_df is not None and not exrights_df.empty: + date_key = int(date_str) if exrights_df.index.dtype in ('int64', 'int32') else current_date + if date_key in exrights_df.index: + event = exrights_df.loc[date_key] + allotted = float(event.get('allotted_ps', 0) or 0) + if allotted > 0: + new_amount = int(original_amount * (1 + allotted)) + position.amount = new_amount + position.enable_amount = new_amount + self.context.portfolio._invalidate_cache() + + # 现金分红(按登记日股数计算) if stock_code not in self.api.data_context.dividend_cache: continue @@ -374,33 +483,18 @@ def _process_dividend_events(self, current_date): if date_str not in stock_dividends: continue - # 获取税前分红金额(每股) dividend_per_share_before_tax = stock_dividends[date_str] - - # 预扣税率20%(保守估计) pre_tax_rate = 0.20 dividend_per_share_after_tax = dividend_per_share_before_tax * (1 - pre_tax_rate) - total_dividend_after_tax = dividend_per_share_after_tax * position.amount + total_dividend_after_tax = dividend_per_share_after_tax * original_amount if total_dividend_after_tax > 0: - # 税后金额到账 - old_cash = self.context.portfolio._cash self.context.portfolio._cash += total_dividend_after_tax self.context.portfolio._invalidate_cache() - - # 记录分红到批次(用于卖出时税务调整) self.context.portfolio.add_dividend(stock_code, dividend_per_share_before_tax) - self.log.info( - f"💰分红 | {stock_code} | {position.amount}股 | " - f"税前{dividend_per_share_before_tax:.4f}元/股 | 预扣税率{pre_tax_rate:.0%} | " - f"到账{total_dividend_after_tax:.2f}元 | " - f"现金: {old_cash:.2f} → {self.context.portfolio._cash:.2f}" - ) - except Exception as e: - self.log.warning(f"分红处理失败: {e}") - import traceback + self.log.warning(f"除权除息处理失败: {e}") traceback.print_exc() # ========================================== diff --git a/src/simtradelab/ptrade/strategy_validator.py b/src/simtradelab/ptrade/strategy_validator.py index 5543e85..b9ae4b2 100644 --- a/src/simtradelab/ptrade/strategy_validator.py +++ b/src/simtradelab/ptrade/strategy_validator.py @@ -44,9 +44,9 @@ def __init__(self, strategy_code: str, check_py35_compat: bool = True): try: self.tree = ast.parse(strategy_code) except SyntaxError as e: - self.errors.append("语法错误: 行 {} - {}".format(e.lineno, e.msg)) + self.errors.append(f"语法错误: 行 {e.lineno} - {e.msg}") except Exception as e: - self.errors.append("解析失败: {}".format(str(e))) + self.errors.append(f"解析失败: {str(e)}") def validate(self) -> bool: """验证策略 @@ -74,8 +74,8 @@ def validate(self) -> bool: # 检查当前阶段是否被允许 if phase.value not in allowed_phase_names: self.errors.append( - "行 {}: API '{}' 不能在 '{}' 阶段调用。" - "允许的阶段: {}".format(lineno, api_name, phase.value, allowed_phase_names) + f"行 {lineno}: API '{api_name}' 不能在 '{phase.value}' 阶段调用。" + f"允许的阶段: {allowed_phase_names}" ) # Python 3.5兼容性检查 @@ -143,11 +143,11 @@ def validate_strategy_file(strategy_path: str, check_py35_compat: bool = True, a with open(strategy_path, 'r', encoding='utf-8') as f: strategy_code = f.read() except FileNotFoundError: - return False, ["文件不存在: {}".format(strategy_path)], None + return False, [f"文件不存在: {strategy_path}"], None except PermissionError: - return False, ["无权限读取文件: {}".format(strategy_path)], None + return False, [f"无权限读取文件: {strategy_path}"], None except Exception as e: - return False, ["读取文件失败: {}".format(str(e))], None + return False, [f"读取文件失败: {str(e)}"], None # 如果需要检查兼容性且启用自动修复,先尝试修复 fixed_code = None @@ -163,7 +163,7 @@ def validate_strategy_file(strategy_path: str, check_py35_compat: bool = True, a strategy_code = fixed fixed_code = fixed except Exception as e: - return False, ["写入修复后的代码失败: {}".format(str(e))], None + return False, [f"写入修复后的代码失败: {str(e)}"], None validator = StrategyValidator(strategy_code, check_py35_compat=check_py35_compat) is_valid = validator.validate() diff --git a/src/simtradelab/service/data_server.py b/src/simtradelab/service/data_server.py index 3449cac..1ba2691 100644 --- a/src/simtradelab/service/data_server.py +++ b/src/simtradelab/service/data_server.py @@ -31,13 +31,13 @@ def __new__(cls, *args, **kwargs): cls._instance = super().__new__(cls) return cls._instance - def __init__(self, required_data=None): + def __init__(self, required_data=None, frequency='1d'): # 只初始化一次基础结构 if DataServer._initialized: print("使用已加载的数据(常驻内存)") # 如果指定了新的数据需求,动态补充加载 if required_data is not None: - self._ensure_data_loaded(required_data) + self._ensure_data_loaded(required_data, frequency) return print("=" * 70) @@ -47,9 +47,10 @@ def __init__(self, required_data=None): # 读取配置 self.data_path = global_config.data_path - print("数据路径: {}".format(self.data_path)) + print(f"数据路径: {self.data_path}") self.stock_data_dict = None + self.stock_data_dict_1m = None self.valuation_dict = None self.fundamentals_dict = None self.exrights_dict = None @@ -64,15 +65,17 @@ def __init__(self, required_data=None): # 记录已加载的数据类型 self._loaded_data_types = set() + self._frequency = frequency # 缓存keys避免重复读取 self._stock_keys_cache = None + self._stock_1m_keys_cache = None self._valuation_keys_cache = None self._fundamentals_keys_cache = None self._exrights_keys_cache = None # 加载数据 - self._load_data(required_data) + self._load_data(required_data, frequency) # 注册清理函数:进程退出时自动关闭文件 atexit.register(self._cleanup_on_exit) @@ -82,7 +85,7 @@ def __init__(self, required_data=None): def _clear_all_caches(self): """清空所有缓存""" for cache in [self.valuation_dict, self.fundamentals_dict, - self.stock_data_dict, self.exrights_dict]: + self.stock_data_dict, self.exrights_dict, self.stock_data_dict_1m]: if cache is not None: cache.clear_cache() @@ -90,11 +93,12 @@ def _cleanup_on_exit(self): """进程退出时清理资源""" pass - def _load_data(self, required_data=None): + def _load_data(self, required_data=None, frequency='1d'): """加载数据 Args: required_data: 需要加载的数据集合,None表示全部加载 + frequency: 数据频率 '1d'日线 '1m'分钟线 """ # 默认加载全部 if required_data is None: @@ -102,12 +106,14 @@ def _load_data(self, required_data=None): # 记录需要加载的数据类型 self._loaded_data_types = required_data + self._frequency = frequency from ..ptrade import storage print("正在读取数据...") # 获取股票列表 self._stock_keys_cache = storage.list_stocks(self.data_path) + self._stock_1m_keys_cache = storage.list_stocks_1m(self.data_path) self._valuation_keys_cache = self._stock_keys_cache self._fundamentals_keys_cache = self._stock_keys_cache self._exrights_keys_cache = self._stock_keys_cache @@ -162,9 +168,9 @@ def _load_data_by_types(self, required_data): """加载数据类型""" from ..ptrade import storage - # 股票价格 + # 股票价格(日线) if 'price' in required_data: - print("\n[1] 股票价格({}只)...".format(len(self._stock_keys_cache))) + print(f"\n[1] 股票价格({len(self._stock_keys_cache)}只)...") self.stock_data_dict = LazyDataDict( self.data_path, 'stock', self._stock_keys_cache, preload=True @@ -180,9 +186,26 @@ def _load_data_by_types(self, required_data): print("\n[1] 股票价格(跳过)") self.stock_data_dict = LazyDataDict(self.data_path, 'stock', [], preload=False) + # 分钟数据(按需加载) + if 'price_1m' in required_data and self._stock_1m_keys_cache: + print(f"[1.1] 分钟数据({len(self._stock_1m_keys_cache)}只)...") + self.stock_data_dict_1m = LazyDataDict( + self.data_path, 'stock_1m', self._stock_1m_keys_cache, + preload=True + ) + else: + # 延迟加载模式 + if self._stock_1m_keys_cache: + self.stock_data_dict_1m = LazyDataDict( + self.data_path, 'stock_1m', self._stock_1m_keys_cache, + preload=False + ) + else: + self.stock_data_dict_1m = None + # 估值数据 if 'valuation' in required_data: - print("[2] 估值数据({}只)...".format(len(self._valuation_keys_cache))) + print(f"[2] 估值数据({len(self._valuation_keys_cache)}只)...") self.valuation_dict = LazyDataDict( self.data_path, 'valuation', self._valuation_keys_cache, preload=True @@ -193,7 +216,7 @@ def _load_data_by_types(self, required_data): # 财务数据 if 'fundamentals' in required_data: - print("[3] 财务数据({}只,延迟加载)...".format(len(self._fundamentals_keys_cache))) + print(f"[3] 财务数据({len(self._fundamentals_keys_cache)}只,延迟加载)...") from ..ptrade.config_manager import config self.fundamentals_dict = LazyDataDict( self.data_path, 'fundamentals', self._fundamentals_keys_cache, @@ -206,7 +229,7 @@ def _load_data_by_types(self, required_data): # 除权数据 if 'exrights' in required_data: - print("[4] 除权数据({}只,延迟加载)...".format(len(self._exrights_keys_cache))) + print(f"[4] 除权数据({len(self._exrights_keys_cache)}只,延迟加载)...") from ..ptrade.config_manager import config self.exrights_dict = LazyDataDict( self.data_path, 'exrights', self._exrights_keys_cache, @@ -217,7 +240,7 @@ def _load_data_by_types(self, required_data): print("[4] 除权数据(跳过)") self.exrights_dict = LazyDataDict(self.data_path, 'exrights', [], preload=False) - print("\n已加载: {}".format(' | '.join(sorted(required_data)))) + print(f"\n已加载: {' | '.join(sorted(required_data))}") # 动态获取所有指数代码 index_codes = set() @@ -235,7 +258,7 @@ def _load_data_by_types(self, required_data): self.benchmark_data['000300.SS'] = self.stock_data_dict['000300.SS'] keys_list = list(self.benchmark_data.keys()) - print("可用基准(共 {} 个): {} ...".format(len(keys_list), ', '.join(keys_list[:5]))) + print(f"可用基准(共 {len(keys_list)} 个): {', '.join(keys_list[:5])} ...") # 加载复权缓存 if 'price' in required_data or 'exrights' in required_data: @@ -260,37 +283,43 @@ def _load_data_by_types(self, required_data): print("✓ 数据加载完成\n") - def _ensure_data_loaded(self, required_data): + def _ensure_data_loaded(self, required_data, frequency='1d'): """确保所需数据已加载,动态补充缺失的数据 Args: required_data: 需要的数据集合 + frequency: 回测频率 '1d'日线 '1m'分钟线 """ if not hasattr(self, '_loaded_data_types'): self._loaded_data_types = set() # 计算缺失的数据类型 missing = set(required_data) - self._loaded_data_types + + # 分钟回测需要加载分钟数据 + if frequency == '1m' and self.stock_data_dict_1m is None: + missing.add('price_1m') + if not missing: return - print("补充加载缺失数据: {}".format(', '.join(sorted(missing)))) + print(f"补充加载缺失数据: {', '.join(sorted(missing))}") # 使用缓存的keys加载缺失数据 if 'price' in missing and self._stock_keys_cache is not None: - print(" 加载股票价格({}只)...".format(len(self._stock_keys_cache))) + print(f" 加载股票价格({len(self._stock_keys_cache)}只)...") self.stock_data_dict = LazyDataDict( self.data_path, 'stock', self._stock_keys_cache, preload=True ) if 'valuation' in missing and self._valuation_keys_cache is not None: - print(" 加载估值数据({}只)...".format(len(self._valuation_keys_cache))) + print(f" 加载估值数据({len(self._valuation_keys_cache)}只)...") self.valuation_dict = LazyDataDict( self.data_path, 'valuation', self._valuation_keys_cache, preload=True ) if 'fundamentals' in missing and self._fundamentals_keys_cache is not None: - print(" 加载财务数据({}只,延迟加载)...".format(len(self._fundamentals_keys_cache))) + print(f" 加载财务数据({len(self._fundamentals_keys_cache)}只,延迟加载)...") from ..ptrade.config_manager import config self.fundamentals_dict = LazyDataDict( self.data_path, 'fundamentals', self._fundamentals_keys_cache, @@ -299,7 +328,7 @@ def _ensure_data_loaded(self, required_data): ) if 'exrights' in missing and self._exrights_keys_cache is not None: - print(" 加载除权数据({}只,延迟加载)...".format(len(self._exrights_keys_cache))) + print(f" 加载除权数据({len(self._exrights_keys_cache)}只,延迟加载)...") from ..ptrade.config_manager import config self.exrights_dict = LazyDataDict( self.data_path, 'exrights', self._exrights_keys_cache, @@ -307,6 +336,13 @@ def _ensure_data_loaded(self, required_data): max_cache_size=config.cache.exrights_cache_size ) + if 'price_1m' in missing and self._stock_1m_keys_cache is not None: + print(f" 加载分钟数据({len(self._stock_1m_keys_cache)}只)...") + self.stock_data_dict_1m = LazyDataDict( + self.data_path, 'stock_1m', self._stock_1m_keys_cache, + preload=True + ) + # 更新已加载记录 self._loaded_data_types.update(missing) diff --git a/src/simtradelab/utils/py35_compat_checker.py b/src/simtradelab/utils/py35_compat_checker.py index ecb187b..503c90c 100644 --- a/src/simtradelab/utils/py35_compat_checker.py +++ b/src/simtradelab/utils/py35_compat_checker.py @@ -192,7 +192,7 @@ def _check_ast_features(self): ) -def check_python35_compatibility(code: str) -> Tuple[bool, List[str]]: +def check_python35_compatibility(code: str) -> tuple[bool, list[str]]: """检查代码是否兼容Python 3.5 Args: @@ -205,7 +205,7 @@ def check_python35_compatibility(code: str) -> Tuple[bool, List[str]]: return checker.check() -def check_file_python35_compatibility(filepath: str) -> Tuple[bool, List[str]]: +def check_file_python35_compatibility(filepath: str) -> tuple[bool, list[str]]: """检查文件是否兼容Python 3.5 Args: @@ -227,7 +227,7 @@ def check_file_python35_compatibility(filepath: str) -> Tuple[bool, List[str]]: return check_python35_compatibility(code) -def check_and_fix_file(filepath: str, auto_fix: bool = True) -> Tuple[bool, List[str], str]: +def check_and_fix_file(filepath: str, auto_fix: bool = True) -> tuple[bool, list[str], str]: """检查并自动修复文件的Python 3.5兼容性问题 Args: diff --git a/typings/builtins.pyi b/typings/builtins.pyi deleted file mode 100644 index 89c0214..0000000 --- a/typings/builtins.pyi +++ /dev/null @@ -1,408 +0,0 @@ -# -*- coding: utf-8 -*- -""" -PTrade 全局 API 类型定义 -扩展 Python builtins,将 ptrade API 注入全局作用域 -""" - -from __future__ import annotations -from typing import Any, Callable, Optional -import pandas as pd - -# ==================== 基础 API ==================== - -def get_research_path() -> str: - """返回研究目录路径""" - ... - -def get_Ashares(date: Optional[str] = ...) -> list[str]: - """返回A股代码列表,支持历史查询 - - Args: - date: 查询日期,None表示当前回测日期 - """ - ... - -def get_trade_days(start_date: Optional[str] = ..., end_date: Optional[str] = ..., count: Optional[int] = ...) -> list[str]: - """获取指定范围交易日列表 - - Args: - start_date: 开始日期(与count二选一) - end_date: 结束日期(默认当前回测日期) - count: 往前count个交易日(与start_date二选一) - """ - ... - -def get_all_trades_days(date: Optional[str] = ...) -> list[str]: - """获取某日期之前的所有交易日列表 - - Args: - date: 截止日期(默认当前回测日期) - """ - ... - -def get_trading_day(day: int = ...) -> Optional[str]: - """获取当前时间数天前或数天后的交易日期 - - Args: - day: 偏移天数(正数向后,负数向前,0表示当天或上一交易日) - - Returns: - 交易日期字符串,如 '2024-01-15' - """ - ... - -# ==================== 基本面 API ==================== - -def get_fundamentals(stocks: list[str], table: str, fields: list[str], date: Optional[str] = ...) -> pd.DataFrame: - """获取基本面数据 - - Args: - stocks: 股票代码列表 - table: 表名 (valuation/profit_ability/growth_ability/operating_ability/debt_paying_ability) - fields: 字段列表 - date: 查询日期(默认为回测当前日期) - - Returns: - 基本面数据 DataFrame,index 为股票代码 - """ - ... - -# ==================== 行情 API ==================== - -def get_price( - security: str | list[str], - start_date: Optional[str] = ..., - end_date: Optional[str] = ..., - frequency: str = ..., - fields: Optional[str | list[str]] = ..., - fq: Optional[str] = ..., - count: Optional[int] = ... -) -> pd.DataFrame | dict: - """获取历史行情数据 - - Args: - security: 股票代码或代码列表 - start_date: 开始日期 - end_date: 结束日期 - frequency: 频率,默认 '1d' - fields: 字段名或字段列表 - fq: 复权类型 ('pre'-前复权, None-不复权) - count: 获取 count 个数据点 - """ - ... - -def get_history( - count: int, - frequency: str = ..., - field: str | list[str] = ..., - security_list: Optional[str | list[str]] = ..., - fq: Optional[str] = ..., - include: bool = ..., - fill: str = ..., - is_dict: bool = ... -) -> pd.DataFrame | dict: - """获取历史数据 - - Args: - count: 获取多少个数据点 - frequency: 频率,默认 '1d' - field: 字段名或字段列表,默认 'close' - security_list: 股票代码或代码列表 - fq: 复权类型 ('pre'-前复权, None-不复权) - include: 是否包含当前bar,默认 False - fill: 填充方式,默认 'nan' - is_dict: 是否返回字典格式,默认 False - """ - ... - -# ==================== 股票信息 API ==================== - -def get_stock_blocks(stock: str) -> dict: - """获取股票所属板块 - - Args: - stock: 股票代码 - """ - ... - -def get_stock_info(stocks: str | list[str], field: Optional[str | list[str]] = ...) -> dict[str, dict]: - """获取股票基础信息 - - Args: - stocks: 股票代码或代码列表 - field: 字段名或字段列表,如 ['stock_name', 'listed_date'] - """ - ... - -def get_stock_name(stocks: str | list[str]) -> str | dict[str, str]: - """获取股票名称 - - Args: - stocks: 股票代码或代码列表 - - Returns: - 单个股票返回字符串,多个股票返回字典 - """ - ... - -def get_stock_status(stocks: str | list[str], query_type: str = ..., query_date: Optional[str] = ...) -> dict[str, bool]: - """获取股票状态 - - Args: - stocks: 股票代码或代码列表 - query_type: 查询类型 ('ST', 'HALT', 'DELISTING'),默认 'ST' - query_date: 查询日期 - """ - ... - -def get_stock_exrights(stock_code: str, date: Optional[str] = ...) -> Optional[pd.DataFrame]: - """获取股票除权除息信息 - - Args: - stock_code: 股票代码 - date: 查询日期 - """ - ... - -# ==================== 指数/行业 API ==================== - -def get_index_stocks(index_code: str, date: Optional[str] = ...) -> list[str]: - """获取指数成份股 - - Args: - index_code: 指数代码,如 '000300.SS' - date: 查询日期 - """ - ... - -def get_industry_stocks(industry_code: Optional[str] = ...) -> dict | list[str]: - """获取行业成份股 - - Args: - industry_code: 行业代码,None 返回所有行业 - """ - ... - -# ==================== 涨跌停 API ==================== - -def check_limit(security: str | list[str], query_date: Optional[str] = ...) -> dict[str, int]: - """检查涨跌停状态 - - Args: - security: 股票代码或代码列表 - query_date: 查询日期 - - Returns: - {股票代码: 状态} 字典,状态: 1=涨停, -1=跌停, 0=正常 - """ - ... - -# ==================== 交易 API ==================== - -def order(security: str, amount: int, limit_price: Optional[float] = ...) -> Optional[str]: - """买卖指定数量的股票 - - Args: - security: 股票代码 - amount: 交易数量,正数表示买入,负数表示卖出 - limit_price: 买卖限价 - - Returns: - 订单id或None - """ - ... - -def order_target(stock: str, amount: int, limit_price: Optional[float] = ...) -> Optional[str]: - """下单到目标数量 - - Args: - stock: 股票代码 - amount: 期望的最终数量 - limit_price: 买卖限价 - """ - ... - -def order_value(stock: str, value: float, limit_price: Optional[float] = ...) -> Optional[str]: - """按金额下单 - - Args: - stock: 股票代码 - value: 股票价值(元) - limit_price: 买卖限价 - """ - ... - -def order_target_value(stock: str, value: float, limit_price: Optional[float] = ...) -> Optional[str]: - """调整股票持仓市值到目标价值 - - Args: - stock: 股票代码 - value: 期望的股票最终价值(元) - limit_price: 买卖限价 - """ - ... - -def get_open_orders() -> list: - """获取未成交订单""" - ... - -def get_orders(security: Optional[str] = ...) -> list: - """获取当日全部订单 - - Args: - security: 股票代码,None表示获取所有订单 - """ - ... - -def get_order(order_id: str) -> Optional[Any]: - """获取指定订单 - - Args: - order_id: 订单id - """ - ... - -def get_trades() -> list: - """获取当日成交订单""" - ... - -def get_position(security: str) -> Optional[Any]: - """获取持仓信息 - - Args: - security: 股票代码 - """ - ... - -def cancel_order(order: Any) -> bool: - """取消订单 - - Args: - order: Order对象 - """ - ... - -# ==================== 配置 API ==================== - -def set_benchmark(benchmark: str) -> None: - """设置基准(支持指数和普通股票) - - Args: - benchmark: 基准代码,如 '000300.SS' - """ - ... - -def set_universe(stocks: str | list[str]) -> None: - """设置股票池并预加载数据 - - Args: - stocks: 股票代码或代码列表 - """ - ... - -def is_trade() -> bool: - """是否实盘 - - Returns: - 回测环境总是返回 False - """ - ... - -def set_commission(commission_ratio: float = ..., min_commission: float = ..., type: str = ...) -> None: - """设置交易佣金 - - Args: - commission_ratio: 佣金费率,默认万三 - min_commission: 最低佣金,默认5元 - type: 类型,默认 "STOCK" - """ - ... - -def set_slippage(slippage: float = ...) -> None: - """设置滑点 - - Args: - slippage: 滑点比例 - """ - ... - -def set_fixed_slippage(fixedslippage: float = ...) -> None: - """设置固定滑点 - - Args: - fixedslippage: 固定滑点比例,默认 0.001 - """ - ... - -def set_limit_mode(limit_mode: str = ...) -> None: - """设置下单限制模式 - - Args: - limit_mode: 限制模式 ('LIMIT', 'UNLIMITED'),默认 'LIMIT' - """ - ... - -def set_volume_ratio(volume_ratio: float = ...) -> None: - """设置成交比例 - - Args: - volume_ratio: 成交比例,默认0.25 - """ - ... - -def set_yesterday_position(poslist: list[dict]) -> None: - """设置底仓(回测用) - - Args: - poslist: 持仓列表,每个元素为字典 {'security': 股票代码, 'amount': 数量, 'cost_basis': 成本价} - """ - ... - -def run_interval(context: Any, func: Callable, seconds: int = ...) -> None: - """定时运行函数(秒级,仅实盘) - - Args: - context: Context对象 - func: 自定义函数 - seconds: 时间间隔(秒),默认10秒 - """ - ... - -def run_daily(context: Any, func: Callable, time: str = ...) -> None: - """定时运行函数 - - Args: - context: Context对象 - func: 自定义函数 - time: 触发时间,格式HH:MM,默认 '9:31' - """ - ... - -# ==================== 全局对象 ==================== - -class _Global: - """全局变量容器,可存储策略自定义变量""" - def __setattr__(self, name: str, value: Any) -> None: ... - def __getattr__(self, name: str) -> Any: ... - -class _Log: - """日志对象""" - def info(self, msg: str) -> None: ... - def warning(self, msg: str) -> None: ... - def error(self, msg: str) -> None: ... - def debug(self, msg: str) -> None: ... - -class _Context: - """策略上下文对象""" - current_dt: pd.Timestamp - portfolio: Any - benchmark: str - -class _Data: - """市场数据对象""" - def __getitem__(self, key: str) -> dict[str, Any]: ... - -g: _Global -log: _Log -context: _Context -data: _Data