From 83d089bff90af711a97f2149ea2fbbd540c93484 Mon Sep 17 00:00:00 2001 From: z1369810535-sketch Date: Mon, 11 May 2026 11:23:27 +0800 Subject: [PATCH] feat(collector): add AKShare A-share data collector New collector under scripts/data_collector/akshare/ for fetching China A-share daily OHLCV data via AKShare. Covers all SH/SZ stocks (excl. B-shares), with auto-discovery of the full instrument universe, trading calendar alignment via shared utils.get_calendar_list("ALL"), and 28 unit tests (mock-only, no network dependency). --- scripts/data_collector/akshare/README.md | 90 ++++++ scripts/data_collector/akshare/collector.py | 188 +++++++++++ .../data_collector/akshare/requirements.txt | 5 + .../data_collector/akshare/test_collector.py | 292 ++++++++++++++++++ 4 files changed, 575 insertions(+) create mode 100644 scripts/data_collector/akshare/README.md create mode 100644 scripts/data_collector/akshare/collector.py create mode 100644 scripts/data_collector/akshare/requirements.txt create mode 100644 scripts/data_collector/akshare/test_collector.py diff --git a/scripts/data_collector/akshare/README.md b/scripts/data_collector/akshare/README.md new file mode 100644 index 0000000000..058591dd14 --- /dev/null +++ b/scripts/data_collector/akshare/README.md @@ -0,0 +1,90 @@ + +- [Collector Data](#collector-data) + - [Collector *AKShare* data to qlib](#collector-akshare-data-to-qlib) +- [Using qlib data](#using-qlib-data) + + +# Collect Data From AKShare (A-share) + +> This collector fetches China A-share daily OHLCV data via [AKShare](https://github.com/akfamily/akshare). It covers all Shanghai and Shenzhen stocks (excluding B-shares). + +## Requirements + +```bash +pip install -r requirements.txt +``` + +## Collector Data + +### Collector *AKShare* data to qlib + +> Collect A-share daily data and dump into `qlib` format. + +1. 
Download data to csv: `python scripts/data_collector/akshare/collector.py download_data`

   This downloads raw OHLCV data from AKShare to a local directory (one CSV per symbol).

   - parameters:
     - `source_dir`: save directory
     - `start`: start datetime, by default *"2000-01-01"*
     - `end`: end datetime, by default today
     - `delay`: `time.sleep(delay)`, by default *0.5*
     - `max_workers`: number of concurrent workers, by default *1*
     - `max_collector_count`: number of retries for failed symbols, by default *2*
     - `check_data_length`: minimum row count per symbol, by default `None`
     - `limit_nums`: limit number of symbols (for debugging), by default `None`
     - `symbols`: comma-separated stock codes, e.g. `"600519,000001"`. If omitted, auto-discovers all A-shares
     - `symbol_file`: path to a text file with one code per line
     - `adjust`: price adjustment, value from [`qfq`, `hfq`, `""`], by default `qfq`
   - examples:
     ```bash
     # all A-shares, daily, forward-adjusted
     python collector.py download_data --source_dir ~/.qlib/stock_data/source/akshare_data --start 2020-01-01 --end 2024-12-31 --delay 0.5

     # specific symbols only
     python collector.py download_data --source_dir ~/.qlib/stock_data/source/akshare_data --symbols "600519,000001,300750" --start 2024-01-01

     # from a symbol file, with no adjustment
     python collector.py download_data --source_dir ~/.qlib/stock_data/source/akshare_data --symbol_file symbols.txt --adjust ""
     ```

2. Normalize data: `python scripts/data_collector/akshare/collector.py normalize_data`

   This deduplicates, sorts by date, and aligns to the A-share trading calendar. 
+ + - parameters: + - `source_dir`: csv directory + - `normalize_dir`: result directory + - `max_workers`: number of concurrent workers, by default *1* + - `date_field_name`: date column name, by default `date` + - `symbol_field_name`: symbol column name, by default `symbol` + - `end_date`: last date to include (inclusive), by default `None` + - examples: + ```bash + python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/akshare_data --normalize_dir ~/.qlib/stock_data/source/akshare_1d_nor + ``` + +3. Dump data: `python scripts/dump_bin.py dump_all` + + Convert normalized CSV to qlib binary format. + + - parameters: + - `data_path`: normalize result directory + - `qlib_dir`: qlib data directory + - `freq`: transaction frequency, by default `day` + - `max_workers`: number of threads, by default *16* + - `exclude_fields`: fields not dumped, by default `""` + - examples: + ```bash + python scripts/dump_bin.py dump_all --data_path ~/.qlib/stock_data/source/akshare_1d_nor --qlib_dir ~/.qlib/qlib_data/akshare_data --freq day --exclude_fields date,symbol + ``` + +## Using qlib data + + ```python + import qlib + from qlib.data import D + + qlib.init(provider_uri="~/.qlib/qlib_data/akshare_data", region="cn") + df = D.features(D.instruments("all"), ["$close"], freq="day") + ``` diff --git a/scripts/data_collector/akshare/collector.py b/scripts/data_collector/akshare/collector.py new file mode 100644 index 0000000000..9171f002ab --- /dev/null +++ b/scripts/data_collector/akshare/collector.py @@ -0,0 +1,188 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import sys +from pathlib import Path + +import fire +import pandas as pd +from loguru import logger + +CUR_DIR = Path(__file__).resolve().parent +sys.path.append(str(CUR_DIR.parent.parent)) + +from data_collector.base import BaseCollector, BaseNormalize, BaseRun, Normalize +from data_collector.utils import get_calendar_list + +try: + import akshare as ak +except ImportError: + raise ImportError("Please install akshare: pip install akshare") + +# A-share symbol suffix rules: SH for 6xx, SZ for others +SH_PREFIX = ("6",) +SZ_PREFIX = ("0", "3") + +# B-share prefixes to exclude (USD/HKD denominated, not suitable for A-share workflows) +_B_PREFIXES = ("900", "200") + + +def get_all_symbols() -> list: + """Auto-discover all A-share stock codes via akshare. + + Returns a sorted list of plain 6-digit codes (e.g. ['000001', '600519']). + B-shares (900xxx, 200xxx) are excluded. + """ + df = ak.stock_info_a_code_name() + codes = df["code"].astype(str).str.zfill(6) + codes = [c for c in codes if not c.startswith(_B_PREFIXES)] + return sorted(set(codes)) + + +def symbol_to_qlib(symbol: str) -> str: + """Convert plain code like '000001' to Qlib format 'SZ000001'.""" + symbol = symbol.strip() + if symbol.startswith(("sh", "sz", "SH", "SZ")): + return symbol[:2].upper() + symbol[2:] + if symbol.startswith(SH_PREFIX): + return f"SH{symbol}" + return f"SZ{symbol}" + + +def qlib_to_raw(symbol: str) -> str: + """Convert 'SH600519' or 'SZ000001' to plain '600519' / '000001'.""" + symbol = symbol.strip().upper() + for prefix in ("SH", "SZ"): + if symbol.startswith(prefix): + return symbol[len(prefix):] + return symbol + + +FIELD_MAP = { + "日期": "date", + "股票代码": "symbol", + "开盘": "open", + "收盘": "close", + "最高": "high", + "最低": "low", + "成交量": "volume", + "成交额": "money", +} + + +class AKShareCollector(BaseCollector): + def __init__( + self, + save_dir, + start=None, + end=None, + interval="1d", + max_workers=1, + max_collector_count=2, + delay=0.5, + check_data_length=None, + 
limit_nums=None, + symbols=None, + symbol_file=None, + adjust="qfq", + ): + self.requested_symbols = self._parse_symbols(symbols, symbol_file) + self.adjust = adjust + super().__init__( + save_dir=save_dir, + start=start, + end=end, + interval=interval, + max_workers=max_workers, + max_collector_count=max_collector_count, + delay=delay, + check_data_length=check_data_length, + limit_nums=limit_nums, + ) + + @staticmethod + def _parse_symbols(symbols=None, symbol_file=None): + result = [] + if symbols: + if isinstance(symbols, str): + result.extend(s.strip() for s in symbols.split(",") if s.strip()) + else: + result.extend(str(s).strip() for s in symbols if str(s).strip()) + if symbol_file: + path = Path(symbol_file).expanduser() + result.extend(line.strip() for line in path.read_text(encoding="utf-8").splitlines() if line.strip()) + if not result: + logger.info("No symbols provided, auto-discovering A-share universe via akshare...") + result = get_all_symbols() + logger.info(f"Discovered {len(result)} symbols") + return sorted(set(result)) + + def get_instrument_list(self): + return self.requested_symbols + + def normalize_symbol(self, symbol: str): + return symbol_to_qlib(symbol) + + def get_data(self, symbol, interval, start_datetime, end_datetime): + raw = qlib_to_raw(symbol) + start_str = pd.Timestamp(start_datetime).strftime("%Y%m%d") + end_str = pd.Timestamp(end_datetime).strftime("%Y%m%d") + + try: + df = ak.stock_zh_a_hist( + symbol=raw, + period="daily", + start_date=start_str, + end_date=end_str, + adjust=self.adjust, + ) + except Exception as e: + logger.warning(f"AKShare fetch failed for {symbol}: {e}") + return pd.DataFrame() + + if df is None or df.empty: + return pd.DataFrame() + + df = df.rename(columns=FIELD_MAP) + for col in ("open", "close", "high", "low", "volume", "money"): + if col in df.columns: + df[col] = pd.to_numeric(df[col], errors="coerce") + df["date"] = pd.to_datetime(df["date"]) + df["symbol"] = self.normalize_symbol(symbol) + + 
fields = ["date", "symbol", "open", "close", "high", "low", "volume", "money"] + return df.loc[:, [f for f in fields if f in df.columns]] + + +class AKShareNormalize(BaseNormalize): + def normalize(self, df): + if df.empty: + return df + df = df.copy() + df[self._date_field_name] = pd.to_datetime(df[self._date_field_name]) + df = df.drop_duplicates([self._date_field_name]).sort_values(self._date_field_name) + return df + + def _get_calendar_list(self): + return get_calendar_list("ALL") + + +class Run(BaseRun): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d"): + super().__init__(source_dir=source_dir, normalize_dir=normalize_dir, max_workers=max_workers, interval=interval) + + @property + def collector_class_name(self): + return "AKShareCollector" + + @property + def normalize_class_name(self): + return "AKShareNormalize" + + @property + def default_base_dir(self): + return CUR_DIR + + +if __name__ == "__main__": + fire.Fire(Run) diff --git a/scripts/data_collector/akshare/requirements.txt b/scripts/data_collector/akshare/requirements.txt new file mode 100644 index 0000000000..d6e5bf4eb3 --- /dev/null +++ b/scripts/data_collector/akshare/requirements.txt @@ -0,0 +1,5 @@ +loguru +fire +pandas +tqdm +akshare>=1.16 diff --git a/scripts/data_collector/akshare/test_collector.py b/scripts/data_collector/akshare/test_collector.py new file mode 100644 index 0000000000..3aadbbda6f --- /dev/null +++ b/scripts/data_collector/akshare/test_collector.py @@ -0,0 +1,292 @@ +"""Unit tests for the AKShare collector. 
+ +Run from the repo root: + + python -m pytest scripts/data_collector/akshare/test_collector.py -v +""" + +import sys +from pathlib import Path + +import pandas as pd +import pytest + +CUR_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str(CUR_DIR.parent.parent)) + +from data_collector.akshare import collector as akshare_collector +from data_collector.akshare.collector import ( + AKShareCollector, + AKShareNormalize, + get_all_symbols, + qlib_to_raw, + symbol_to_qlib, +) + + +@pytest.mark.parametrize( + "raw, expected", + [ + ("600519", "SH600519"), + ("000001", "SZ000001"), + ("300750", "SZ300750"), + ("sh600519", "SH600519"), + ("sz000001", "SZ000001"), + ("SH600519", "SH600519"), + (" 000001 ", "SZ000001"), + ], +) +def test_symbol_to_qlib(raw, expected): + assert symbol_to_qlib(raw) == expected + + +@pytest.mark.parametrize( + "qlib_symbol, expected", + [ + ("SH600519", "600519"), + ("SZ000001", "000001"), + ("sh600519", "600519"), + ("600519", "600519"), + ], +) +def test_qlib_to_raw(qlib_symbol, expected): + assert qlib_to_raw(qlib_symbol) == expected + + +def _make_collector(tmp_path, **kwargs): + defaults = dict( + save_dir=tmp_path, + start="2024-01-01", + end="2024-01-10", + symbols="600519,000001", + ) + defaults.update(kwargs) + return AKShareCollector(**defaults) + + +def test_parse_symbols_string(tmp_path): + c = _make_collector(tmp_path, symbols="600519, 000001 ,000001") + assert c.requested_symbols == ["000001", "600519"] + + +def test_parse_symbols_list(tmp_path): + c = _make_collector(tmp_path, symbols=["600519", "000001"]) + assert c.requested_symbols == ["000001", "600519"] + + +def test_parse_symbols_from_file(tmp_path): + sym_file = tmp_path / "symbols.txt" + sym_file.write_text("600519\n000001\n\n300750\n", encoding="utf-8") + c = _make_collector(tmp_path, symbols=None, symbol_file=str(sym_file)) + assert c.requested_symbols == ["000001", "300750", "600519"] + + +def test_parse_symbols_combined(tmp_path): + sym_file = tmp_path / 
"symbols.txt" + sym_file.write_text("600519\n", encoding="utf-8") + c = _make_collector(tmp_path, symbols="000001", symbol_file=str(sym_file)) + assert c.requested_symbols == ["000001", "600519"] + + +def test_parse_symbols_no_args_triggers_discover(tmp_path, monkeypatch): + """When no --symbols or --symbol_file given, _parse_symbols auto-discovers.""" + fake_df = pd.DataFrame({"code": ["600519", "000001"], "name": ["a", "b"]}) + monkeypatch.setattr(akshare_collector.ak, "stock_info_a_code_name", lambda: fake_df) + c = AKShareCollector(save_dir=tmp_path, start="2024-01-01", end="2024-01-10") + assert c.requested_symbols == ["000001", "600519"] + + +def test_normalize_symbol_routes_to_helper(tmp_path): + c = _make_collector(tmp_path) + assert c.normalize_symbol("600519") == "SH600519" + assert c.normalize_symbol("000001") == "SZ000001" + + +def test_get_instrument_list_returns_requested(tmp_path): + c = _make_collector(tmp_path, symbols="600519,000001") + assert c.get_instrument_list() == ["000001", "600519"] + + +def _fake_akshare_df(): + return pd.DataFrame( + { + "日期": ["2024-01-02", "2024-01-03"], + "股票代码": ["600519", "600519"], + "开盘": ["1700.0", "1710.0"], + "收盘": ["1705.0", "1715.0"], + "最高": ["1720.0", "1725.0"], + "最低": ["1690.0", "1700.0"], + "成交量": ["100000", "120000"], + "成交额": ["170500000", "172500000"], + "振幅": ["1.0", "1.1"], + } + ) + + +def test_get_data_success(tmp_path, monkeypatch): + captured = {} + + def fake_hist(symbol, period, start_date, end_date, adjust): + captured.update( + symbol=symbol, + period=period, + start_date=start_date, + end_date=end_date, + adjust=adjust, + ) + return _fake_akshare_df() + + monkeypatch.setattr(akshare_collector.ak, "stock_zh_a_hist", fake_hist) + + c = _make_collector(tmp_path, symbols="600519") + df = c.get_data( + "SH600519", + interval="1d", + start_datetime=pd.Timestamp("2024-01-01"), + end_datetime=pd.Timestamp("2024-01-10"), + ) + + assert captured == { + "symbol": "600519", + "period": "daily", + 
"start_date": "20240101", + "end_date": "20240110", + "adjust": "qfq", + } + assert list(df.columns) == ["date", "symbol", "open", "close", "high", "low", "volume", "money"] + assert df["symbol"].unique().tolist() == ["SH600519"] + assert df["open"].dtype.kind == "f" + assert df["volume"].dtype.kind in ("f", "i") + assert pd.api.types.is_datetime64_any_dtype(df["date"]) + assert len(df) == 2 + + +def test_get_data_handles_empty_response(tmp_path, monkeypatch): + monkeypatch.setattr( + akshare_collector.ak, "stock_zh_a_hist", lambda **_: pd.DataFrame() + ) + c = _make_collector(tmp_path, symbols="600519") + df = c.get_data( + "SH600519", + interval="1d", + start_datetime=pd.Timestamp("2024-01-01"), + end_datetime=pd.Timestamp("2024-01-10"), + ) + assert df.empty + + +def test_get_data_swallows_akshare_exception(tmp_path, monkeypatch): + def boom(**_): + raise RuntimeError("upstream HTTP 500") + + monkeypatch.setattr(akshare_collector.ak, "stock_zh_a_hist", boom) + c = _make_collector(tmp_path, symbols="600519") + df = c.get_data( + "SH600519", + interval="1d", + start_datetime=pd.Timestamp("2024-01-01"), + end_datetime=pd.Timestamp("2024-01-10"), + ) + assert df.empty + + +def test_get_data_passes_through_adjust(tmp_path, monkeypatch): + captured = {} + + def fake_hist(**kwargs): + captured.update(kwargs) + return _fake_akshare_df() + + monkeypatch.setattr(akshare_collector.ak, "stock_zh_a_hist", fake_hist) + c = _make_collector(tmp_path, symbols="600519", adjust="hfq") + c.get_data( + "SH600519", + interval="1d", + start_datetime=pd.Timestamp("2024-01-01"), + end_datetime=pd.Timestamp("2024-01-10"), + ) + assert captured["adjust"] == "hfq" + + +class _NormalizeNoCalendar(AKShareNormalize): + def __init__(self): + self._date_field_name = "date" + self._symbol_field_name = "symbol" + self._calendar_list = [] + + +def test_normalize_dedups_and_sorts(): + raw = pd.DataFrame( + { + "date": ["2024-01-03", "2024-01-02", "2024-01-02"], + "symbol": ["SH600519"] * 3, + 
"close": [1715.0, 1705.0, 1705.0], + } + ) + out = _NormalizeNoCalendar().normalize(raw) + assert len(out) == 2 + assert out["date"].is_monotonic_increasing + assert pd.api.types.is_datetime64_any_dtype(out["date"]) + + +def test_normalize_passthrough_empty(): + out = _NormalizeNoCalendar().normalize(pd.DataFrame()) + assert out.empty + + +# --- get_all_symbols tests --- + + +def test_get_all_symbols_filters_bshares(monkeypatch): + fake_df = pd.DataFrame( + { + "code": ["600519", "000001", "300750", "900901", "200002", "688001"], + "name": ["贵州茅台", "平安银行", "宁德时代", "dummy", "dummy", "dummy"], + } + ) + monkeypatch.setattr(akshare_collector.ak, "stock_info_a_code_name", lambda: fake_df) + result = get_all_symbols() + assert "600519" in result + assert "000001" in result + assert "300750" in result + assert "688001" in result # STAR board included + assert "900901" not in result # B-share excluded + assert "200002" not in result # B-share excluded + + +def test_get_all_symbols_deduplicates(monkeypatch): + fake_df = pd.DataFrame({"code": ["600519", "600519", "000001"], "name": ["a", "b", "c"]}) + monkeypatch.setattr(akshare_collector.ak, "stock_info_a_code_name", lambda: fake_df) + result = get_all_symbols() + assert result == ["000001", "600519"] + + +# --- _parse_symbols auto-discover fallback --- + + +def test_parse_symbols_auto_discover(tmp_path, monkeypatch): + fake_df = pd.DataFrame( + {"code": ["600519", "000001"], "name": ["贵州茅台", "平安银行"]} + ) + monkeypatch.setattr(akshare_collector.ak, "stock_info_a_code_name", lambda: fake_df) + c = AKShareCollector(save_dir=tmp_path, start="2024-01-01", end="2024-01-10") + assert "000001" in c.requested_symbols + assert "600519" in c.requested_symbols + + +# --- _get_calendar_list test --- + + +def test_get_calendar_list_uses_shared_utils(monkeypatch): + call_args = {} + + def fake_get_calendar_list(bench_code): + call_args["bench_code"] = bench_code + return [pd.Timestamp("2024-01-02"), pd.Timestamp("2024-01-03")] + + 
monkeypatch.setattr( + "data_collector.akshare.collector.get_calendar_list", fake_get_calendar_list + ) + norm = AKShareNormalize() + assert call_args["bench_code"] == "ALL" + assert len(norm._calendar_list) == 2