diff --git a/adata/common/utils/sunrequests.py b/adata/common/utils/sunrequests.py
index 9fec0ed..fdc164b 100644
--- a/adata/common/utils/sunrequests.py
+++ b/adata/common/utils/sunrequests.py
@@ -10,6 +10,7 @@
 
 import threading
 import time
+from urllib.parse import urlparse
 
 import requests
 
@@ -41,10 +42,79 @@ def delete(cls, key):
         del cls._data[key]
 
 
+class RateLimiter:
+    """
+    Process-wide request rate limiter keyed by domain.
+
+    Requests are counted per domain in fixed one-minute windows. Limits and
+    counters are stored on the class, so every instance shares them.
+    """
+    DEFAULT_LIMIT = 30  # requests per minute when no limit was set
+    _instance_lock = threading.Lock()
+    # Makes "check the counter, then increment it" atomic; without it
+    # concurrent callers can lose updates and exceed the limit.
+    _state_lock = threading.Lock()
+    _domain_limits = {}
+    _domain_requests = {}
+
+    def __new__(cls, *args, **kwargs):
+        # Create the shared instance at most once, even under concurrent calls.
+        if not hasattr(RateLimiter, "_instance"):
+            with RateLimiter._instance_lock:
+                if not hasattr(RateLimiter, "_instance"):
+                    RateLimiter._instance = object.__new__(cls)
+        return RateLimiter._instance
+
+    @classmethod
+    def set_limit(cls, domain, limit=DEFAULT_LIMIT):
+        """
+        Set the request limit for a domain.
+        :param domain: domain name (URL netloc)
+        :param limit: maximum number of requests per minute, must be positive
+        :raises ValueError: if limit is zero or negative
+        """
+        # A non-positive limit can never be satisfied and would make
+        # check_and_wait block forever.
+        if limit <= 0:
+            raise ValueError('limit must be positive, got %r' % (limit,))
+        cls._domain_limits[domain] = limit
+
+    @classmethod
+    def get_limit(cls, domain):
+        """
+        Get the request limit for a domain.
+        :param domain: domain name (URL netloc)
+        :return: maximum number of requests per minute
+        """
+        return cls._domain_limits.get(domain, cls.DEFAULT_LIMIT)
+
+    @classmethod
+    def check_and_wait(cls, domain):
+        """
+        Block until a request to the domain is allowed, then count it.
+        :param domain: domain name (URL netloc)
+        """
+        while True:
+            window = int(time.time() // 60)  # index of the current minute
+            with cls._state_lock:
+                counts = cls._domain_requests.setdefault(domain, {})
+                if counts.get(window, 0) < cls.get_limit(domain):
+                    counts[window] = counts.get(window, 0) + 1
+                    # Drop counters older than the previous minute.
+                    for old in [w for w in counts if w < window - 1]:
+                        del counts[old]
+                    return
+            # The window is full: sleep outside the lock until the next minute
+            # starts instead of polling every 100 ms.
+            time.sleep(max((window + 1) * 60 - time.time(), 0.01))
+
+
 class SunRequests(object):
     def __init__(self, sun_proxy: SunProxy = None) -> None:
         super().__init__()
         self.sun_proxy = sun_proxy
+        # RateLimiter() always returns the same shared instance.
+        self.rate_limiter = RateLimiter()
 
     def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies=None, wait_time=None, **kwargs):
         """
@@ -58,9 +128,16 @@ def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies
         :param kwargs: 其它 requests 参数,用法相同
         :return: res
         """
-        # 1. 获取设置代理
+        # 1. Resolve the domain that rate limiting is keyed on
+        domain = urlparse(url).netloc if url else None
+
+        # 2. Resolve the proxy settings
         proxies = self.__get_proxies(proxies)
-        # 2. 请求数据结果
+        # 3. Send the request
         res = None
         for i in range(times):
+            # Every attempt reaches the server, so each one (not only the
+            # first) is counted against the domain's limit.
+            if domain is not None:
+                self.rate_limiter.check_and_wait(domain)
             if wait_time:
@@ -73,6 +150,14 @@ def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies
                 return res
         return res
 
+    def set_rate_limit(self, domain, limit=30):
+        """
+        Set the request rate limit for a domain.
+        :param domain: domain name (URL netloc)
+        :param limit: maximum number of requests per minute
+        """
+        self.rate_limiter.set_limit(domain, limit)
+
     def __get_proxies(self, proxies):
         """
         获取代理配置
@@ -83,6 +168,10 @@ def __get_proxies(self, proxies):
         ip = SunProxy.get('ip')
         proxy_url = SunProxy.get('proxy_url')
         if not ip and is_proxy and proxy_url:
+            # Fetching a proxy IP is an outbound request too, so it is rate
+            # limited under the proxy provider's domain.
+            self.rate_limiter.check_and_wait(urlparse(proxy_url).netloc)
+
             ip = requests.get(url=proxy_url).text.replace('\r\n', '') \
                 .replace('\r', '').replace('\n', '').replace('\t', '')
         if is_proxy and ip:
diff --git a/test_basic_function.py b/test_basic_function.py
new file mode 100644
index 0000000..ad275db
--- /dev/null
+++ b/test_basic_function.py
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+"""
+Manual smoke test: send one real request through sun_requests.
+"""
+
+import os
+import sys
+
+# Put the utils directory on sys.path so sunrequests imports as a top-level module.
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'adata', 'common', 'utils'))
+from sunrequests import sun_requests as requests
+
+
+def test_basic_request():
+    """Send a single GET request and print the outcome."""
+    print("测试基本请求功能...")
+    try:
+        response = requests.request('get', 'https://www.baidu.com')
+        print(f"请求成功,状态码: {response.status_code}")
+        print("测试完成")
+    except Exception as e:
+        # Report the failure instead of letting the script crash.
+        print(f"测试失败: {e}")
+
+
+if __name__ == "__main__":
+    test_basic_request()
diff --git a/test_rate_limit.py b/test_rate_limit.py
new file mode 100644
index 0000000..cff5592
--- /dev/null
+++ b/test_rate_limit.py
@@ -0,0 +1,37 @@
+# -*- coding: utf-8 -*-
+"""
+Manual check of the per-domain rate limiter.
+"""
+
+import os
+import sys
+import time
+
+# Put the utils directory on sys.path so sunrequests imports as a top-level module.
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'adata', 'common', 'utils'))
+from sunrequests import RateLimiter, sun_requests as requests
+
+
+def test_rate_limit():
+    """Take 11 slots from a 10-per-minute limit and print how long each one waits."""
+    print("测试频率限制功能...")
+    domain = "api.example.com"
+
+    # Limit the test domain to 10 requests per minute.
+    requests.set_rate_limit(domain, 10)
+
+    # The 11th call should block until the next minute starts (unless the
+    # minute already rolled over during the first ten calls).
+    for i in range(1, 12):
+        print(f"发送第{i}个请求...")
+        start_time = time.time()
+        # Exercise the limiter directly: no HTTP request is sent, so the timing
+        # shows rate limiting only and no network error has to be swallowed.
+        RateLimiter.check_and_wait(domain)
+        print(f"请求{i}完成,耗时: {time.time() - start_time:.2f}秒")
+
+    print("测试完成")
+
+
+if __name__ == "__main__":
+    test_rate_limit()