class RateLimiter:
    """
    Process-wide request rate limiter using a per-host sliding window.

    The class is a singleton: every ``RateLimiter()`` call returns the same
    object. For each host it remembers the timestamps of the requests granted
    during the last ``_window_size`` seconds and makes ``acquire`` block until
    another request fits under the limit configured for that host.
    """
    _instance = None
    # Guards singleton creation only; runtime state uses ``_state_lock``.
    _lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    # Initialise BEFORE publishing, so a thread that passes
                    # the unlocked outer check never sees a half-built object.
                    instance._init()
                    cls._instance = instance
        return cls._instance

    def _init(self):
        """Set up the state of the singleton (runs exactly once)."""
        # Fallback limit for hosts without an explicit entry: 30 per window.
        self._default_limit = 30
        # Length of the sliding window, in seconds.
        self._window_size = 60

        # Per-domain overrides: {domain: max requests per window}.
        self._domain_limits = {}

        # Timestamps (time.monotonic) of granted requests: {host: [ts, ...]}.
        self._request_history = defaultdict(list)

        # One lock for all bookkeeping. It is only held for short,
        # non-blocking sections; sleeping always happens outside of it.
        self._state_lock = threading.Lock()

        # Global on/off switch.
        self._enabled = True

    @staticmethod
    def _check_limit(requests_per_minute):
        """Return the limit unchanged, or raise ValueError if it is not a positive number."""
        # ``not x > 0`` (rather than ``x <= 0``) also rejects NaN.
        if not isinstance(requests_per_minute, (int, float)) or not requests_per_minute > 0:
            raise ValueError(
                "requests_per_minute must be a positive number, got {!r}".format(requests_per_minute)
            )
        return requests_per_minute

    def set_limit(self, domain: str, requests_per_minute: int):
        """
        Set the limit for one domain (and, via ``get_limit``, its sub-domains).

        :param domain: domain name, e.g. 'eastmoney.com' (case-insensitive)
        :param requests_per_minute: maximum number of requests per minute
        :raises ValueError: if the domain is empty or the limit is not positive
        """
        # Lookups are lower-cased, so the key has to be stored lower-cased too.
        key = domain.strip().lower() if isinstance(domain, str) else ""
        if not key:
            raise ValueError("domain must be a non-empty string, got {!r}".format(domain))
        self._domain_limits[key] = self._check_limit(requests_per_minute)

    def set_default_limit(self, requests_per_minute: int):
        """
        Set the limit used for hosts that have no explicit entry.

        :param requests_per_minute: maximum number of requests per minute
        :raises ValueError: if the limit is not positive
        """
        self._default_limit = self._check_limit(requests_per_minute)

    def get_limit(self, domain: str) -> int:
        """
        Return the limit that applies to ``domain``.

        An exact match wins; otherwise the closest configured parent domain is
        used ('a.b.com' -> 'b.com' -> 'com'); otherwise the default limit.

        :param domain: host name
        :return: maximum number of requests per minute for that host
        """
        labels = domain.lower().split('.')
        # start == 0 is the exact match, larger offsets are parent domains.
        for start in range(len(labels)):
            candidate = '.'.join(labels[start:])
            if candidate in self._domain_limits:
                return self._domain_limits[candidate]
        return self._default_limit

    def extract_domain(self, url: str) -> str:
        """
        Extract the lower-cased host name from a URL.

        Credentials and port are stripped, so 'https://u:p@Host.com:8080/x'
        yields 'host.com' and matches limits configured for 'host.com'.

        :param url: request URL
        :return: host name, or '' when the URL has none or cannot be parsed
        """
        try:
            host = urlparse(url).hostname
        except (ValueError, TypeError, AttributeError):
            return ""
        return host.lower() if isinstance(host, str) else ""

    def acquire(self, url: str):
        """
        Block until a request to ``url`` is allowed, then record it.

        Returns immediately when limiting is disabled or no host can be
        extracted from the URL.

        :param url: request URL
        """
        domain = self.extract_domain(url)
        if not domain:
            return

        # Re-checked on every pass so disable() also releases waiting threads.
        while self._enabled:
            with self._state_lock:
                # Monotonic clock: immune to wall-clock adjustments.
                now = time.monotonic()
                cutoff = now - self._window_size

                # Drop timestamps that slid out of the window.
                recent = [ts for ts in self._request_history[domain] if ts > cutoff]
                self._request_history[domain] = recent

                # Limit is read on every pass so changes apply to waiters too.
                if len(recent) < self.get_limit(domain):
                    recent.append(now)
                    return

                # The setters guarantee a positive limit, so ``recent`` is
                # non-empty here; a slot frees up when its oldest entry expires.
                wait_time = recent[0] + self._window_size - now

            # Sleep outside the lock so other threads are not blocked.
            time.sleep(max(wait_time, 0.0))

    def enable(self):
        """Turn rate limiting on."""
        self._enabled = True

    def disable(self):
        """Turn rate limiting off; ``acquire`` then returns immediately."""
        self._enabled = False

    def is_enabled(self) -> bool:
        """Return True when rate limiting is active."""
        return self._enabled

    def reset(self, domain: str = None):
        """
        Forget recorded requests.

        :param domain: host whose history is dropped; None drops all hosts
        """
        with self._state_lock:
            if domain:
                # pop() avoids creating an entry for a host never seen.
                self._request_history.pop(domain.lower(), None)
            else:
                self._request_history.clear()


# Global rate limiter instance
rate_limiter = RateLimiter()


def set_rate_limit(domain: str = None, requests_per_minute: int = None, enable: bool = None):
    """
    Configure the global request rate limiter.

    :param domain: domain such as 'eastmoney.com'; when None (or empty) the
        default limit is changed instead of a per-domain one
    :param requests_per_minute: maximum requests per minute; None leaves the
        limit untouched (the built-in default is 30 per minute)
    :param enable: True enables limiting, False disables it, None leaves the
        current on/off state unchanged
    :raises ValueError: if ``requests_per_minute`` is not a positive number

    Examples:
        # default limit of 30 requests per minute
        adata.set_rate_limit(requests_per_minute=30)

        # limit for one domain
        adata.set_rate_limit(domain='eastmoney.com', requests_per_minute=20)

        # switch rate limiting off / on
        adata.set_rate_limit(enable=False)
        adata.set_rate_limit(enable=True)
    """
    if enable is not None:
        if enable:
            rate_limiter.enable()
        else:
            rate_limiter.disable()

    if requests_per_minute is not None:
        if domain:
            rate_limiter.set_limit(domain, requests_per_minute)
        else:
            rate_limiter.set_default_limit(requests_per_minute)
rate_limit: 是否启用频率限制,默认True :param kwargs: 其它 requests 参数,用法相同 :return: res """ - # 1. 获取设置代理 + # 1. 频率限制检查 + if rate_limit and url: + rate_limiter.acquire(url) + + # 2. 获取设置代理 proxies = self.__get_proxies(proxies) - # 2. 请求数据结果 + + # 3. 请求数据结果 res = None for i in range(times): if wait_time: