Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions scripts/data_collector/akshare/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@

- [Collector Data](#collector-data)
- [Collector *AKShare* data to qlib](#collector-akshare-data-to-qlib)
- [Using qlib data](#using-qlib-data)


# Collect Data From AKShare (A-share)

> This collector fetches China A-share daily OHLCV data via [AKShare](https://github.com/akfamily/akshare). It covers all Shanghai and Shenzhen stocks (excluding B-shares).

## Requirements

```bash
pip install -r requirements.txt
```

## Collector Data

### Collector *AKShare* data to qlib

> Collect A-share daily data and dump into `qlib` format.

1. Download data to csv: `python scripts/data_collector/akshare/collector.py download_data`

This downloads raw OHLCV data from AKShare to a local directory (one CSV per symbol).

- parameters:
- `source_dir`: save directory
- `start`: start datetime, by default *"2000-01-01"*
- `end`: end datetime, by default today
- `delay`: `time.sleep(delay)`, by default *0.5*
- `max_workers`: number of concurrent workers, by default *1*
- `max_collector_count`: number of retries for failed symbols, by default *2*
- `check_data_length`: minimum row count per symbol, by default `None`
- `limit_nums`: limit number of symbols (for debugging), by default `None`
- `symbols`: comma-separated stock codes, e.g. `"600519,000001"`. If omitted, auto-discovers all A-shares
- `symbol_file`: path to a text file with one code per line
- `adjust`: price adjustment, one of `qfq` (forward-adjusted), `hfq` (backward-adjusted) or `""` (no adjustment), by default `qfq`
- examples:
```bash
# all A-shares, daily, forward-adjusted
python collector.py download_data --source_dir ~/.qlib/stock_data/source/akshare_data --start 2020-01-01 --end 2024-12-31 --delay 0.5

# specific symbols only
python collector.py download_data --source_dir ~/.qlib/stock_data/source/akshare_data --symbols "600519,000001,300750" --start 2024-01-01

# from a symbol file, with no adjustment
python collector.py download_data --source_dir ~/.qlib/stock_data/source/akshare_data --symbol_file symbols.txt --adjust ""
```

2. Normalize data: `python scripts/data_collector/akshare/collector.py normalize_data`

This deduplicates, sorts by date, and aligns to the A-share trading calendar.

- parameters:
- `source_dir`: csv directory
- `normalize_dir`: result directory
- `max_workers`: number of concurrent workers, by default *1*
- `date_field_name`: date column name, by default `date`
- `symbol_field_name`: symbol column name, by default `symbol`
- `end_date`: last date to include (inclusive), by default `None`
- examples:
```bash
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/akshare_data --normalize_dir ~/.qlib/stock_data/source/akshare_1d_nor
```

3. Dump data: `python scripts/dump_bin.py dump_all`

Convert normalized CSV to qlib binary format.

- parameters:
- `data_path`: normalize result directory
- `qlib_dir`: qlib data directory
- `freq`: transaction frequency, by default `day`
- `max_workers`: number of threads, by default *16*
- `exclude_fields`: fields not dumped, by default `""`
- examples:
```bash
python scripts/dump_bin.py dump_all --data_path ~/.qlib/stock_data/source/akshare_1d_nor --qlib_dir ~/.qlib/qlib_data/akshare_data --freq day --exclude_fields date,symbol
```

## Using qlib data

```python
import qlib
from qlib.data import D

qlib.init(provider_uri="~/.qlib/qlib_data/akshare_data", region="cn")
df = D.features(D.instruments("all"), ["$close"], freq="day")
```
188 changes: 188 additions & 0 deletions scripts/data_collector/akshare/collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import sys
from pathlib import Path

import fire
import pandas as pd
from loguru import logger

CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))

from data_collector.base import BaseCollector, BaseNormalize, BaseRun, Normalize
from data_collector.utils import get_calendar_list

try:
import akshare as ak
except ImportError:
raise ImportError("Please install akshare: pip install akshare")

# Leading digits used to pick the exchange tag for a bare 6-digit code:
# symbol_to_qlib() tags codes starting with SH_PREFIX as "SH" and every other
# bare code as "SZ". SZ_PREFIX lists the leading digits expected for Shenzhen codes.
SH_PREFIX = ("6",)
SZ_PREFIX = ("0", "3")

# B-share prefixes to exclude (USD/HKD denominated, not suitable for A-share workflows)
_B_PREFIXES = ("900", "200")


def get_all_symbols() -> list:
    """Auto-discover Shanghai/Shenzhen A-share stock codes via akshare.

    Per the AKShare documentation, ``ak.stock_info_a_code_name()`` lists
    Shanghai, Shenzhen *and* Beijing Stock Exchange codes. Only codes whose
    leading digit is in ``SH_PREFIX`` or ``SZ_PREFIX`` are kept, because
    ``symbol_to_qlib`` can only tag a code as ``SH`` or ``SZ`` and would
    otherwise label every other code as Shenzhen.

    Returns:
        A sorted list of unique, zero-padded 6-digit codes
        (e.g. ``['000001', '600519']``). B-shares (900xxx, 200xxx) are excluded.
    """
    listing = ak.stock_info_a_code_name()
    # Codes may come back without leading zeros; pad so prefix tests are reliable.
    codes = listing["code"].astype(str).str.zfill(6)
    supported = SH_PREFIX + SZ_PREFIX
    return sorted(
        {
            code
            for code in codes
            if code.startswith(supported) and not code.startswith(_B_PREFIXES)
        }
    )


def symbol_to_qlib(symbol: str) -> str:
    """Convert a stock code to Qlib format, e.g. ``'000001'`` -> ``'SZ000001'``.

    Args:
        symbol: Either a bare code (``'600519'``) or a code that already
            carries an ``SH``/``SZ`` exchange prefix in any letter case
            (``'sh600519'``, ``'Sz000001'``). Surrounding whitespace is ignored.

    Returns:
        The code with an upper-case exchange prefix. Bare codes starting with
        ``SH_PREFIX`` get ``SH``; every other bare code gets ``SZ``.
    """
    symbol = symbol.strip()
    # Compare case-insensitively so mixed-case prefixes such as "Sh" are
    # recognised instead of being treated as a bare code and prefixed again.
    exchange = symbol[:2].upper()
    if exchange in ("SH", "SZ"):
        return exchange + symbol[2:]
    if symbol.startswith(SH_PREFIX):
        return f"SH{symbol}"
    return f"SZ{symbol}"


def qlib_to_raw(symbol: str) -> str:
    """Strip a leading ``SH``/``SZ`` exchange tag from a symbol.

    The input is trimmed and upper-cased first, so ``'sz000001'`` yields
    ``'000001'``. A symbol without one of the two tags is returned trimmed
    and upper-cased but otherwise untouched.
    """
    code = symbol.strip().upper()
    if code[:2] in ("SH", "SZ"):
        return code[2:]
    return code


# Column renames applied to the frame returned by ak.stock_zh_a_hist() in
# AKShareCollector.get_data(): Chinese column header -> English field name.
FIELD_MAP = {
    "日期": "date",
    "股票代码": "symbol",
    "开盘": "open",
    "收盘": "close",
    "最高": "high",
    "最低": "low",
    "成交量": "volume",
    "成交额": "money",
}


class AKShareCollector(BaseCollector):
    """Collector that downloads A-share daily bars through AKShare."""

    def __init__(
        self,
        save_dir,
        start=None,
        end=None,
        interval="1d",
        max_workers=1,
        max_collector_count=2,
        delay=0.5,
        check_data_length=None,
        limit_nums=None,
        symbols=None,
        symbol_file=None,
        adjust="qfq",
    ):
        """Resolve the symbol universe, then initialise the base collector.

        Args:
            save_dir, start, end, interval, max_workers, max_collector_count,
                delay, check_data_length, limit_nums: forwarded unchanged to
                ``BaseCollector.__init__``.
            symbols: comma-separated string, a single integer code, or an
                iterable of codes. See ``_parse_symbols``.
            symbol_file: path to a UTF-8 text file with one code per line.
            adjust: value passed as ``adjust`` to ``ak.stock_zh_a_hist``
                (the README documents ``"qfq"``, ``"hfq"`` and ``""``).
                ``None`` is treated as ``""``.
        """
        # These attributes are set before delegating to the base initialiser,
        # presumably because it calls get_instrument_list() -- TODO confirm
        # against BaseCollector.
        self.requested_symbols = self._parse_symbols(symbols, symbol_file)
        # The README documents "" as the no-adjustment value; map None onto it
        # so a Python caller passing adjust=None gets unadjusted prices.
        self.adjust = "" if adjust is None else adjust
        super().__init__(
            save_dir=save_dir,
            start=start,
            end=end,
            interval=interval,
            max_workers=max_workers,
            max_collector_count=max_collector_count,
            delay=delay,
            check_data_length=check_data_length,
            limit_nums=limit_nums,
        )

    @staticmethod
    def _parse_symbols(symbols=None, symbol_file=None):
        """Build the sorted, de-duplicated list of requested stock codes.

        Args:
            symbols: ``None``, a comma-separated string (``"600519,000001"``),
                a single integer code, or an iterable of codes. Integer codes
                are zero-padded to 6 digits because a command-line parser may
                turn ``600519`` or ``600519,300750`` into ints before they
                reach this method.
            symbol_file: optional path (``~`` is expanded) to a UTF-8 text
                file with one code per line; blank lines are skipped.

        Returns:
            Sorted list of unique code strings. When neither argument yields
            any code, the universe is discovered with ``get_all_symbols()``.
        """

        def _as_code(value):
            # Integers have lost their leading zeros (1 -> "000001").
            # bool is an int subclass, so it is excluded explicitly.
            if isinstance(value, int) and not isinstance(value, bool):
                return str(value).zfill(6)
            return str(value).strip()

        result = []
        if symbols:
            if isinstance(symbols, str):
                result.extend(part.strip() for part in symbols.split(",") if part.strip())
            elif isinstance(symbols, int) and not isinstance(symbols, bool):
                # A lone numeric code is not iterable; handle it as a scalar.
                result.append(_as_code(symbols))
            else:
                result.extend(code for code in map(_as_code, symbols) if code)
        if symbol_file:
            path = Path(symbol_file).expanduser()
            result.extend(
                line.strip()
                for line in path.read_text(encoding="utf-8").splitlines()
                if line.strip()
            )
        if not result:
            logger.info("No symbols provided, auto-discovering A-share universe via akshare...")
            result = get_all_symbols()
            logger.info(f"Discovered {len(result)} symbols")
        return sorted(set(result))

    def get_instrument_list(self):
        """Return the symbol list resolved at construction time."""
        return self.requested_symbols

    def normalize_symbol(self, symbol: str):
        """Return ``symbol`` in Qlib format (``SH``/``SZ`` prefix + code)."""
        return symbol_to_qlib(symbol)

    def get_data(self, symbol, interval, start_datetime, end_datetime) -> pd.DataFrame:
        """Fetch daily bars for one symbol from ``ak.stock_zh_a_hist``.

        Args:
            symbol: bare or ``SH``/``SZ``-prefixed code.
            interval: accepted for interface compatibility; not used here,
                the request is always made with ``period="daily"``.
            start_datetime: first date requested (anything ``pd.Timestamp`` accepts).
            end_datetime: last date requested (anything ``pd.Timestamp`` accepts).

        Returns:
            A frame with whichever of ``date, symbol, open, close, high, low,
            volume, money`` are available, or an empty ``DataFrame`` when the
            fetch fails or returns no rows.
        """
        raw = qlib_to_raw(symbol)
        # Dates are sent to AKShare as compact YYYYMMDD strings.
        start_str = pd.Timestamp(start_datetime).strftime("%Y%m%d")
        end_str = pd.Timestamp(end_datetime).strftime("%Y%m%d")

        try:
            df = ak.stock_zh_a_hist(
                symbol=raw,
                period="daily",
                start_date=start_str,
                end_date=end_str,
                adjust=self.adjust,
            )
        except Exception as e:
            # Deliberate best-effort boundary around a third-party network
            # call: log and report "no data" so one bad symbol does not abort
            # the whole run.
            logger.warning(f"AKShare fetch failed for {symbol}: {e}")
            return pd.DataFrame()

        if df is None or df.empty:
            return pd.DataFrame()

        df = df.rename(columns=FIELD_MAP)
        for col in ("open", "close", "high", "low", "volume", "money"):
            if col in df.columns:
                # Unparseable values become NaN instead of raising.
                df[col] = pd.to_numeric(df[col], errors="coerce")
        df["date"] = pd.to_datetime(df["date"])
        # Overwrite any upstream code column with the Qlib-format symbol.
        df["symbol"] = self.normalize_symbol(symbol)

        fields = ["date", "symbol", "open", "close", "high", "low", "volume", "money"]
        return df.loc[:, [f for f in fields if f in df.columns]]


class AKShareNormalize(BaseNormalize):
    """Normalizer for the CSV files written by the AKShare collector."""

    def _get_calendar_list(self):
        """Return the result of ``get_calendar_list("ALL")``."""
        return get_calendar_list("ALL")

    def normalize(self, df):
        """Parse the date column, drop repeated dates and sort by date.

        An empty frame is returned as-is. Otherwise the work is done on a
        copy, and for duplicated dates the first row in input order is kept.
        """
        if df.empty:
            return df
        date_col = self._date_field_name
        frame = df.copy()
        frame[date_col] = pd.to_datetime(frame[date_col])
        unique_rows = frame.drop_duplicates(subset=[date_col])
        return unique_rows.sort_values(by=date_col)


class Run(BaseRun):
    """Entry point that binds the AKShare collector and normalizer to ``BaseRun``."""

    def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d"):
        """Forward all arguments unchanged to ``BaseRun.__init__``."""
        super().__init__(
            source_dir=source_dir,
            normalize_dir=normalize_dir,
            max_workers=max_workers,
            interval=interval,
        )

    @property
    def default_base_dir(self):
        """Directory containing this script."""
        return CUR_DIR

    @property
    def collector_class_name(self):
        """Name of the collector class defined in this module."""
        return "AKShareCollector"

    @property
    def normalize_class_name(self):
        """Name of the normalizer class defined in this module."""
        return "AKShareNormalize"


if __name__ == "__main__":
    # Hand the Run class to python-fire so it can be driven from the command line.
    fire.Fire(Run)
5 changes: 5 additions & 0 deletions scripts/data_collector/akshare/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
loguru
fire
pandas
tqdm
akshare>=1.16
Loading