import threading
import time
from collections import defaultdict
from urllib.parse import urlparse


class RateLimiter:
    """Process-wide, per-domain request rate limiter.

    Every call to ``RateLimiter()`` returns the same shared instance. For each
    domain the limiter remembers when permits were granted during the last
    60 seconds and blocks ``acquire`` while the domain is at its limit
    (30 permits per window unless overridden with ``set_rate_limit``).

    Intended use, as wired up by the patch this class was introduced in:
    ``SunRequests.__init__`` stores ``RateLimiter()`` and ``SunRequests.request``
    calls ``acquire(url)`` before sending each request that has a URL.
    """

    _WINDOW_SECONDS = 60
    _DEFAULT_LIMIT = 30

    _instance = None
    # Guards singleton creation only; per-instance state has its own lock.
    _creation_lock = threading.Lock()

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._creation_lock:
                if cls._instance is None:
                    # Initialise fully BEFORE publishing, so a concurrent
                    # caller on the unlocked path never sees a bare object.
                    instance = super().__new__(cls)
                    instance._init()
                    cls._instance = instance
        return cls._instance

    def _init(self):
        """Set up the per-domain bookkeeping for a freshly created instance."""
        # {domain: [monotonic timestamp of each permit granted in the window]}
        self._domain_requests = defaultdict(list)
        # {domain: max permits per window}; unset domains use the default.
        self._domain_limits = defaultdict(lambda: self._DEFAULT_LIMIT)
        # Protects both dictionaries above.
        self._lock = threading.Lock()

    def set_rate_limit(self, domain: str, requests_per_minute: int):
        """Set the maximum number of requests per minute for a domain.

        :param domain: bare domain such as 'finance.sina.com.cn'; a full URL
            is also accepted and reduced to its host part
        :param requests_per_minute: maximum permits per 60-second window
        :raises ValueError: if ``requests_per_minute`` is smaller than 1
            (such a limit could never be satisfied by ``acquire``)
        """
        if requests_per_minute < 1:
            raise ValueError(
                f"requests_per_minute must be >= 1, got {requests_per_minute!r}"
            )
        key = self._extract_domain(domain)
        with self._lock:
            self._domain_limits[key] = requests_per_minute

    def get_rate_limit(self, domain: str) -> int:
        """Return the limit configured for a domain, or the default (30).

        Looking up an unconfigured domain does not create an entry for it.
        """
        key = self._extract_domain(domain)
        with self._lock:
            return self._domain_limits.get(key, self._DEFAULT_LIMIT)

    def acquire(self, url: str):
        """Block until a request to ``url``'s domain is allowed, then record it.

        :param url: URL about to be requested
        """
        domain = self._extract_domain(url)
        window = self._WINDOW_SECONDS

        while True:
            with self._lock:
                # Monotonic clock: wall-clock adjustments must not stretch or
                # shrink the window.
                now = time.monotonic()
                # Re-read on every attempt so a limit changed while this
                # thread was sleeping takes effect.
                limit = self._domain_limits.get(domain, self._DEFAULT_LIMIT)

                # Forget permits that have slid out of the window.
                recent = [
                    ts for ts in self._domain_requests[domain]
                    if now - ts < window
                ]
                self._domain_requests[domain] = recent

                if len(recent) < limit:
                    recent.append(now)
                    return

                # A slot frees up when the oldest permit leaves the window.
                wait_time = window - (now - recent[0])

            # Sleep outside the lock so other domains are not held up; the
            # floor avoids a busy loop if the computed wait is ~0.
            time.sleep(max(wait_time, 0.01))

    def _extract_domain(self, url: str) -> str:
        """Reduce a URL (or bare domain) to the lower-cased key used internally.

        The key is the URL's network location (host, plus port/userinfo when
        present). Input without a ``://`` separator is treated as starting
        with the host, so 'example.com/path' and 'example.com' share a key.
        If no network location can be found the lower-cased input is
        returned; if the URL cannot be parsed at all it is returned unchanged.
        """
        if isinstance(url, (bytes, bytearray)):
            url = bytes(url).decode("utf-8", errors="replace")
        elif not isinstance(url, str):
            url = str(url)

        # urlparse only recognises a netloc after '//'.
        candidate = url if "://" in url else "//" + url
        try:
            netloc = urlparse(candidate).netloc
        except ValueError:
            # e.g. a malformed IPv6 literal; fall back to the raw input.
            return url
        return (netloc or url).lower()