Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions adata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from adata.__version__ import __version__
from adata.bond import bond
from adata.common.utils.rate_limiter import rate_limiter
from adata.common.utils.sunrequests import SunProxy
from adata.fund import fund
from adata.sentiment import sentiment
Expand All @@ -33,6 +34,40 @@ def proxy(is_proxy=False, ip: str = None, proxy_url: str = None):
return


def set_rate_limit(domain: str = None, requests_per_minute: int = None, enable: bool = None):
    """
    Configure the shared request rate limiter.

    Each argument is optional; an argument left as ``None`` changes nothing.

    :param domain: domain the limit applies to, e.g. 'eastmoney.com'; when
        empty or None, ``requests_per_minute`` becomes the default limit
    :param requests_per_minute: maximum number of requests per minute
    :param enable: True switches rate limiting on, False switches it off

    Examples:
        # set the default limit to 30 requests per minute
        adata.set_rate_limit(requests_per_minute=30)

        # set a limit for one specific domain
        adata.set_rate_limit(domain='eastmoney.com', requests_per_minute=20)

        # turn rate limiting off
        adata.set_rate_limit(enable=False)

        # turn rate limiting on
        adata.set_rate_limit(enable=True)
    """
    # The on/off switch is applied first, independently of any limit change.
    if enable is not None:
        toggle = rate_limiter.enable if enable else rate_limiter.disable
        toggle()

    # No limit supplied: nothing more to configure.
    if requests_per_minute is None:
        return

    if domain:
        rate_limiter.set_limit(domain, requests_per_minute)
        return

    rate_limiter.set_default_limit(requests_per_minute)


# set up logging
logger = logging.getLogger("adata")

Expand Down
173 changes: 173 additions & 0 deletions adata/common/utils/rate_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# -*- coding: utf-8 -*-
"""
@desc: 请求频率限制器
@author: adata
@time: 2024
@log: 基于域名级别的请求频率控制
"""

import threading
import time
from collections import defaultdict
from urllib.parse import urlparse


class RateLimiter:
    """
    Sliding-window request rate limiter with per-domain limits.

    The class is a singleton: every ``RateLimiter()`` call returns the same
    object, so configuration and request history are shared by all callers
    in the process.

    For each host, the timestamps of the requests made during the last
    ``_window_size`` seconds are kept. ``acquire`` blocks the calling thread
    until the number of timestamps inside the window is below the limit that
    applies to the host, then records a new timestamp.
    """
    _instance = None
    _lock = threading.Lock()  # guards creation of the singleton only

    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    # Initialise fully BEFORE publishing the instance, so a
                    # thread taking the unlocked fast path above can never
                    # observe an object that is missing its attributes.
                    instance = super().__new__(cls)
                    instance._init()
                    cls._instance = instance
        return cls._instance

    def _init(self):
        """Set up the initial state; called once, when the singleton is created."""
        # Limit used for hosts without a specific entry: 30 requests per window.
        self._default_limit = 30
        # Length of the sliding window, in seconds.
        self._window_size = 60

        # Per-domain limits: {domain: max requests per window}
        self._domain_limits = {}

        # Request history: {host: [monotonic timestamp, ...]}, oldest first.
        self._request_history = defaultdict(list)

        # One lock guards every read/write of the request history. It is only
        # held for the short bookkeeping step, never while sleeping.
        self._history_lock = threading.Lock()

        # Whether rate limiting is active.
        self._enabled = True

    @staticmethod
    def _check_limit(requests_per_minute):
        """Reject limits that could never admit a request."""
        if requests_per_minute <= 0:
            raise ValueError(
                "requests_per_minute must be greater than 0, got %r" % (requests_per_minute,)
            )

    def set_limit(self, domain: str, requests_per_minute: int):
        """
        Set the limit for one domain (it also applies to its sub-domains).

        :param domain: domain name, e.g. 'eastmoney.com'; stored lower-cased
            because extracted hosts are lower-cased as well
        :param requests_per_minute: maximum number of requests per window
        :raises ValueError: if ``requests_per_minute`` is not positive
        """
        self._check_limit(requests_per_minute)
        self._domain_limits[domain.lower()] = requests_per_minute

    def set_default_limit(self, requests_per_minute: int):
        """
        Set the limit used for domains without a specific entry.

        :param requests_per_minute: maximum number of requests per window
        :raises ValueError: if ``requests_per_minute`` is not positive
        """
        self._check_limit(requests_per_minute)
        self._default_limit = requests_per_minute

    def get_limit(self, domain: str) -> int:
        """
        Return the limit that applies to ``domain``.

        The exact name is tried first, then each parent domain from the most
        specific to the least specific ('a.b.com' -> 'b.com' -> 'com'); the
        default limit is returned when nothing matches.

        :param domain: domain name
        :return: the limit for that domain
        """
        labels = domain.lower().split('.')
        # start == 0 is the exact match; larger values strip leading labels.
        for start in range(len(labels)):
            candidate = '.'.join(labels[start:])
            if candidate in self._domain_limits:
                return self._domain_limits[candidate]
        return self._default_limit

    def extract_domain(self, url: str) -> str:
        """
        Extract the lower-cased host name from a URL.

        Credentials and the port are not part of the result, so
        'https://user:pw@Api.Example.com:8443/x' yields 'api.example.com'.

        :param url: request URL
        :return: the host name, or "" when the URL has none or cannot be parsed
        """
        if not isinstance(url, str):
            return ""
        try:
            # ``hostname`` (unlike ``netloc``) drops 'user:pw@' and ':port'
            # and is already lower-cased.
            host = urlparse(url).hostname
        except ValueError:
            # urlparse rejects some malformed URLs (e.g. unbalanced IPv6 brackets).
            return ""
        return host or ""

    def acquire(self, url: str):
        """
        Wait until a request to ``url`` is allowed, then record it.

        Returns immediately when rate limiting is disabled or no host can be
        extracted from ``url``. Otherwise the calling thread sleeps until the
        number of requests recorded for the host within the window is below
        the applicable limit.

        :param url: request URL
        """
        if not self._enabled:
            return

        domain = self.extract_domain(url)
        if not domain:
            return

        limit = self.get_limit(domain)

        while True:
            with self._history_lock:
                # Monotonic clock: wall-clock adjustments must not stretch or
                # shrink the window.
                now = time.monotonic()
                cutoff = now - self._window_size

                # Drop timestamps that have left the window.
                recent = [ts for ts in self._request_history[domain] if ts > cutoff]
                self._request_history[domain] = recent

                if len(recent) < limit:
                    # Capacity available: record this request and proceed.
                    recent.append(now)
                    return

                # Window is full: a slot frees up when the oldest entry expires.
                wait_time = recent[0] + self._window_size - now

            # Sleep with the lock released so other threads are not blocked,
            # then loop to re-check (another thread may have taken the slot).
            time.sleep(max(wait_time, 0))

    def enable(self):
        """Turn rate limiting on."""
        self._enabled = True

    def disable(self):
        """Turn rate limiting off; ``acquire`` then returns immediately."""
        self._enabled = False

    def is_enabled(self) -> bool:
        """Return whether rate limiting is currently active."""
        return self._enabled

    def reset(self, domain: str = None):
        """
        Forget recorded requests.

        :param domain: host whose history is cleared; when empty or None the
            history of every host is cleared
        """
        # Same lock as ``acquire`` so the history cannot change underneath a
        # thread that is in the middle of its bookkeeping.
        with self._history_lock:
            if domain:
                self._request_history.pop(domain.lower(), None)
            else:
                self._request_history.clear()


# Global rate limiter instance, created at import time.
rate_limiter = RateLimiter()
15 changes: 12 additions & 3 deletions adata/common/utils/sunrequests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

import requests

from adata.common.utils.rate_limiter import rate_limiter


class SunProxy(object):
_data = {}
Expand Down Expand Up @@ -46,7 +48,8 @@ def __init__(self, sun_proxy: SunProxy = None) -> None:
super().__init__()
self.sun_proxy = sun_proxy

def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies=None, wait_time=None, **kwargs):
def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies=None, wait_time=None,
rate_limit=True, **kwargs):
"""
简单封装的请求,参考requests,增加循环次数和次数之间的等待时间
:param proxies: 代理配置
Expand All @@ -55,12 +58,18 @@ def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies
:param times: 次数,int
:param retry_wait_time: 重试等待时间,毫秒
:param wait_time: 等待时间:毫秒;表示每个请求的间隔时间,在请求之前等待sleep,主要用于防止请求太频繁的限制。
:param rate_limit: 是否启用频率限制,默认True
:param kwargs: 其它 requests 参数,用法相同
:return: res
"""
# 1. 获取设置代理
# 1. 频率限制检查
if rate_limit and url:
rate_limiter.acquire(url)

# 2. 获取设置代理
proxies = self.__get_proxies(proxies)
# 2. 请求数据结果

# 3. 请求数据结果
res = None
for i in range(times):
if wait_time:
Expand Down