diff --git a/src/spiders/engine.rs b/src/spiders/engine.rs index 6fbab41..2e3c250 100644 --- a/src/spiders/engine.rs +++ b/src/spiders/engine.rs @@ -1,6 +1,7 @@ +use std::collections::HashMap; use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use std::sync::Arc; -use std::time::Instant; +use std::time::{Duration, Instant}; use tokio::sync::{Mutex, Semaphore}; use url::Url; @@ -22,6 +23,9 @@ pub struct CrawlerEngine { stats: Arc>, items: Arc>, global_limiter: Arc, + /// Per-domain semaphores, lazily created on first request to each host. + /// Used to enforce `Spider::concurrent_requests_per_domain`. + domain_limiters: Arc>>>, robots_manager: Option>>, cache: Option>, checkpoint: Option>, @@ -74,6 +78,7 @@ impl CrawlerEngine { stats: Arc::new(Mutex::new(CrawlStats::default())), items: Arc::new(Mutex::new(ItemList::new())), global_limiter: Arc::new(Semaphore::new(concurrent as usize)), + domain_limiters: Arc::new(Mutex::new(HashMap::new())), robots_manager, cache, checkpoint, @@ -164,6 +169,23 @@ impl CrawlerEngine { } let permit = permit.unwrap(); + // Acquire a per-domain permit when a per-domain cap is + // configured, so a single host cannot exceed it. + let per_domain = self.spider.concurrent_requests_per_domain(); + let domain_permit = if per_domain > 0 { + let domain = req.domain().unwrap_or_default(); + let sem = { + let mut limiters = self.domain_limiters.lock().await; + limiters + .entry(domain) + .or_insert_with(|| Arc::new(Semaphore::new(per_domain as usize))) + .clone() + }; + sem.acquire_owned().await.ok() + } else { + None + }; + self.active_tasks.fetch_add(1, Ordering::SeqCst); let spider = self.spider.clone(); @@ -189,6 +211,7 @@ impl CrawlerEngine { .await; active_tasks.fetch_sub(1, Ordering::SeqCst); + drop(domain_permit); drop(permit); }); } @@ -203,9 +226,13 @@ impl CrawlerEngine { } } - // On pause, save checkpoint + // On pause, save checkpoint — but first wait for in-flight tasks to + // finish so any URLs they enqueue are included in the persisted state. let was_paused = self.paused.load(Ordering::SeqCst); if was_paused { + while self.active_tasks.load(Ordering::SeqCst) > 0 { + tokio::time::sleep(Duration::from_millis(50)).await; + } if let Some(ref cp) = self.checkpoint { let sched = self.scheduler.lock().await; let pending_urls: Vec = sched