diff --git a/adata/common/utils/sunrequests.py b/adata/common/utils/sunrequests.py index 9fec0ed..c574757 100644 --- a/adata/common/utils/sunrequests.py +++ b/adata/common/utils/sunrequests.py @@ -10,6 +10,10 @@ import threading import time +import urllib.parse +from collections import deque +from typing import Dict, Deque, Optional +import warnings import requests @@ -58,6 +62,10 @@ def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies :param kwargs: 其它 requests 参数,用法相同 :return: res """ + # 0. 频率限制检查(仅在url不为None时生效) + if url: + rate_limiter.acquire(url) + # 1. 获取设置代理 proxies = self.__get_proxies(proxies) # 2. 请求数据结果 @@ -90,4 +98,110 @@ def __get_proxies(self, proxies): return proxies +class RateLimiter: + """ + 频率限制器,基于滑动时间窗口实现域名维度的请求频率控制 + """ + + def __init__(self, default_limit: int = 30, window_seconds: int = 60): + """ + 初始化频率限制器 + + :param default_limit: 默认每分钟请求限制次数 + :param window_seconds: 时间窗口大小(秒) + """ + self.default_limit = default_limit + self.window_seconds = window_seconds + self.domain_limits: Dict[str, int] = {} # 域名 -> 限制次数 + self.request_history: Dict[str, Deque[float]] = {} # 域名 -> 时间戳队列 + self.lock = threading.RLock() + + def set_domain_limit(self, domain: str, limit: int) -> None: + """设置特定域名的请求限制""" + with self.lock: + self.domain_limits[domain] = limit + + def get_domain_limit(self, domain: str) -> int: + """获取特定域名的请求限制,如果没有设置则返回默认限制""" + with self.lock: + return self.domain_limits.get(domain, self.default_limit) + + def extract_domain(self, url: str) -> str: + """从URL中提取域名""" + try: + parsed = urllib.parse.urlparse(url) + domain = parsed.netloc + # 移除端口号 + if ':' in domain: + domain = domain.split(':')[0] + return domain + except Exception: + # 如果URL解析失败,返回一个默认域名 + return 'unknown' + + def _get_domain(self, url: str) -> str: + """内部方法:获取域名(用于测试)""" + return self.extract_domain(url) + + def acquire(self, url: str) -> bool: + """ + 尝试获取请求许可,如果超过频率限制则等待 + + :return: True表示可以继续请求,False表示需要等待 + """ + domain = self.extract_domain(url) + limit = self.get_domain_limit(domain) + + with self.lock: + current_time = time.time() + + # 初始化或清理过期的请求记录 + if domain not in self.request_history: + self.request_history[domain] = deque() + + # 移除时间窗口之外的记录 + while (self.request_history[domain] and + current_time - self.request_history[domain][0] > self.window_seconds): + self.request_history[domain].popleft() + + # 检查是否超过限制 + if len(self.request_history[domain]) >= limit: + # 计算需要等待的时间 + oldest_time = self.request_history[domain][0] + wait_time = self.window_seconds - (current_time - oldest_time) + if wait_time > 0: + time.sleep(wait_time) + # 睡眠后重新清理过期记录 + current_time = time.time() + while (self.request_history[domain] and + current_time - self.request_history[domain][0] > self.window_seconds): + self.request_history[domain].popleft() + + # 添加当前请求时间戳 + self.request_history[domain].append(current_time) + return True + + +# 全局频率限制器实例 +rate_limiter = RateLimiter() + + +def set_rate_limit(domain: Optional[str] = None, limit: int = 30) -> None: + """ + 设置频率限制 + + :param domain: 域名,如果为None则设置默认限制 + :param limit: 每分钟请求次数限制 + """ + if domain is None: + rate_limiter.default_limit = limit + else: + rate_limiter.set_domain_limit(domain, limit) + + +def get_rate_limit(domain: str) -> int: + """获取特定域名的请求限制""" + return rate_limiter.get_domain_limit(domain) + + sun_requests = SunRequests()