Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions adata/common/utils/sunrequests.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import threading
import time
from collections import defaultdict
from urllib.parse import urlparse

import requests

Expand Down Expand Up @@ -41,10 +43,101 @@ def delete(cls, key):
del cls._data[key]


class RateLimiter:
    """
    Process-wide, per-domain request rate limiter.

    Every call to ``RateLimiter()`` returns the same shared instance. For each
    domain the limiter remembers the times of the requests it admitted during
    the last ``_WINDOW_SECONDS`` seconds and blocks a caller while the number
    of remembered requests has reached that domain's limit.
    Domains without an explicit limit use ``_DEFAULT_LIMIT`` requests per window.
    """
    _instance = None
    # Class-level lock guarding creation of the shared instance. Once created,
    # the instance carries its own ``_lock`` attribute (set in ``_init``) that
    # shadows this one and guards the per-domain bookkeeping.
    _lock = threading.Lock()

    # Length of the sliding window, in seconds.
    _WINDOW_SECONDS = 60
    # Requests allowed per window for domains with no explicit limit.
    _DEFAULT_LIMIT = 30

    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    # Initialise fully BEFORE publishing: a thread that passes
                    # the unlocked check above must never receive an instance
                    # whose state dictionaries do not exist yet.
                    instance = super().__new__(cls)
                    instance._init()
                    cls._instance = instance
        return cls._instance

    def _init(self):
        """Create the per-domain state; called once for the shared instance."""
        # Admission times per domain: {domain: [t1, t2, ...]}, oldest first.
        self._domain_requests = defaultdict(list)
        # Requests allowed per window: {domain: max_requests_per_minute}.
        self._domain_limits = defaultdict(lambda: self._DEFAULT_LIMIT)
        # Guards both dictionaries above.
        self._lock = threading.Lock()

    def set_rate_limit(self, domain: str, requests_per_minute: int):
        """
        Set the rate limit for one domain.

        :param domain: domain key, e.g. 'finance.sina.com.cn'
        :param requests_per_minute: maximum requests per minute, must be > 0
        :raises ValueError: if ``requests_per_minute`` is not positive; such a
            limit could never admit a request and would break ``acquire``
        """
        if requests_per_minute <= 0:
            raise ValueError(
                'requests_per_minute must be positive, got %r' % (requests_per_minute,)
            )
        with self._lock:
            self._domain_limits[domain] = requests_per_minute

    def get_rate_limit(self, domain: str) -> int:
        """Return the rate limit of ``domain`` (the default if none was set)."""
        # Reading a defaultdict inserts missing keys, so it needs the lock too.
        with self._lock:
            return self._domain_limits[domain]

    def acquire(self, url: str):
        """
        Obtain permission to send a request, blocking while the domain of
        ``url`` is at its limit.

        :param url: URL about to be requested
        """
        domain = self._extract_domain(url)

        while True:
            with self._lock:
                # Re-read on every pass so a limit changed while waiting applies.
                limit = self._domain_limits[domain]
                # Monotonic clock: immune to wall-clock adjustments, which
                # would otherwise distort the window or produce huge waits.
                now = time.monotonic()
                # Forget requests that have left the window.
                recent = [
                    ts for ts in self._domain_requests[domain]
                    if now - ts < self._WINDOW_SECONDS
                ]
                self._domain_requests[domain] = recent

                if len(recent) < limit:
                    # Below the limit: record this request and let it through.
                    recent.append(now)
                    return

                # Time until the oldest remembered request leaves the window.
                wait_time = self._WINDOW_SECONDS - (now - recent[0])

            # Sleep outside the lock so other threads are not blocked; fall
            # back to a short pause so the loop never spins.
            time.sleep(wait_time if wait_time > 0 else 0.01)

    def _extract_domain(self, url: str) -> str:
        """
        Return the network location (host[:port]) of ``url``.

        Falls back to returning ``url`` unchanged when it has no network
        location (e.g. a bare domain without scheme) or cannot be parsed.
        """
        try:
            netloc = urlparse(url).netloc
        except Exception:
            # Best effort: an unparsable value is used as its own key.
            return url
        return netloc if netloc else url


class SunRequests(object):
def __init__(self, sun_proxy: SunProxy = None) -> None:
    """
    Store the optional proxy provider and attach the shared rate limiter.

    :param sun_proxy: proxy provider kept on the instance as ``sun_proxy``; may be None
    """
    super().__init__()
    # RateLimiter() is the limiter used by request() before each call.
    self._rate_limiter = RateLimiter()
    self.sun_proxy = sun_proxy

def set_rate_limit(self, domain: str, requests_per_minute: int = 30):
    """
    Configure the request rate limit for one domain.

    :param domain: a domain such as 'finance.sina.com.cn', or a full URL such
        as 'https://finance.sina.com.cn'
    :param requests_per_minute: maximum requests per minute, 30 by default
    """
    limiter = self._rate_limiter
    # A value that starts with an http(s) scheme is reduced to its domain;
    # anything else is passed through as the key unchanged.
    if domain.startswith(('http://', 'https://')):
        domain = limiter._extract_domain(domain)
    limiter.set_rate_limit(domain, requests_per_minute)

def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies=None, wait_time=None, **kwargs):
"""
Expand All @@ -58,6 +151,10 @@ def request(self, method='get', url=None, times=3, retry_wait_time=1588, proxies
:param kwargs: 其它 requests 参数,用法相同
:return: res
"""
# 0. 频率限制检查
if url:
self._rate_limiter.acquire(url)

# 1. 获取设置代理
proxies = self.__get_proxies(proxies)
# 2. 请求数据结果
Expand Down