diff --git a/.githooks/pre-push b/.githooks/pre-push new file mode 100755 index 0000000..1149c93 --- /dev/null +++ b/.githooks/pre-push @@ -0,0 +1,7 @@ +#!/usr/bin/env bash +# Runs the local CI checks before every push. Enable once with: +# git config core.hooksPath .githooks +# Bypass in an emergency with: git push --no-verify +set -euo pipefail + +exec "$(git rev-parse --show-toplevel)/scripts/ci.sh" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 21ec5f0..3cee273 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -48,7 +48,7 @@ jobs: with: components: clippy - name: Run clippy - run: cargo clippy -- -W clippy::all -D warnings + run: cargo clippy --all-targets -- -W clippy::all -D warnings fmt: name: Format diff --git a/scripts/ci.sh b/scripts/ci.sh new file mode 100755 index 0000000..5e406f8 --- /dev/null +++ b/scripts/ci.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +# Local CI: mirrors .github/workflows/ci.yml so the same checks can run +# without GitHub Actions (locally, in a git hook, or any environment). +set -euo pipefail + +cd "$(git rev-parse --show-toplevel)" + +echo "==> cargo fmt -- --check" +cargo fmt -- --check + +echo "==> cargo clippy --all-targets -- -W clippy::all -D warnings" +cargo clippy --all-targets -- -W clippy::all -D warnings + +echo "==> cargo build --verbose" +cargo build --verbose + +echo "==> cargo test --verbose" +cargo test --verbose + +echo "All checks passed." diff --git a/src/core/attributes_handler.rs b/src/core/attributes_handler.rs index 41cd2f6..d3789a2 100644 --- a/src/core/attributes_handler.rs +++ b/src/core/attributes_handler.rs @@ -17,33 +17,63 @@ impl AttributesHandler { Self { inner } } - pub fn get(&self, key: &str) -> Option<&TextHandler> { self.inner.get(key) } - pub fn contains_key(&self, key: &str) -> bool { self.inner.contains_key(key) } - pub fn len(&self) -> usize { self.inner.len() } - pub fn is_empty(&self) -> bool { self.inner.is_empty() } - pub fn keys(&self) -> impl Iterator { self.inner.keys().map(|k| k.as_str()) } - pub fn values(&self) -> impl Iterator { self.inner.values() } + pub fn get(&self, key: &str) -> Option<&TextHandler> { + self.inner.get(key) + } + pub fn contains_key(&self, key: &str) -> bool { + self.inner.contains_key(key) + } + pub fn len(&self) -> usize { + self.inner.len() + } + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + pub fn keys(&self) -> impl Iterator { + self.inner.keys().map(|k| k.as_str()) + } + pub fn values(&self) -> impl Iterator { + self.inner.values() + } pub fn iter(&self) -> impl Iterator { self.inner.iter().map(|(k, v)| (k.as_str(), v)) } /// Search for attributes whose values match a keyword (exact or partial). - pub fn search_values<'a>(&'a self, keyword: &'a str, partial: bool) -> impl Iterator { + pub fn search_values<'a>( + &'a self, + keyword: &'a str, + partial: bool, + ) -> impl Iterator { self.inner.iter().filter_map(move |(k, v)| { - let matches = if partial { v.as_str().contains(keyword) } else { v.as_str() == keyword }; - if matches { Some((k.as_str(), v)) } else { None } + let matches = if partial { + v.as_str().contains(keyword) + } else { + v.as_str() == keyword + }; + if matches { + Some((k.as_str(), v)) + } else { + None + } }) } /// Serialize attributes to JSON string. pub fn json_string(&self) -> String { - let map: IndexMap<&str, &str> = self.inner.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect(); + let map: IndexMap<&str, &str> = self + .inner + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect(); serde_json::to_string(&map).unwrap_or_default() } } impl std::ops::Index<&str> for AttributesHandler { type Output = TextHandler; - fn index(&self, key: &str) -> &Self::Output { &self.inner[key] } + fn index(&self, key: &str) -> &Self::Output { + &self.inner[key] + } } diff --git a/src/core/mod.rs b/src/core/mod.rs index ef602c4..d9008b1 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -1,8 +1,8 @@ -pub mod text_handler; -pub mod text_handlers; pub mod attributes_handler; pub mod storage; +pub mod text_handler; +pub mod text_handlers; +pub use attributes_handler::AttributesHandler; pub use text_handler::TextHandler; pub use text_handlers::TextHandlers; -pub use attributes_handler::AttributesHandler; diff --git a/src/core/storage.rs b/src/core/storage.rs index 74fc925..cf51381 100644 --- a/src/core/storage.rs +++ b/src/core/storage.rs @@ -53,9 +53,8 @@ impl SqliteStorage { ) -> Result>, StorageError> { let hash = Self::get_hash(identifier); let conn = self.conn.lock().unwrap(); - let mut stmt = conn.prepare( - "SELECT element_data FROM storage WHERE url = ?1 AND identifier = ?2", - )?; + let mut stmt = + conn.prepare("SELECT element_data FROM storage WHERE url = ?1 AND identifier = ?2")?; let result: Option = stmt .query_row(params![self.url, hash], |row| row.get(0)) .ok(); diff --git a/src/fetchers/client.rs b/src/fetchers/client.rs index 02e1ea4..41398bd 100644 --- a/src/fetchers/client.rs +++ b/src/fetchers/client.rs @@ -1,11 +1,15 @@ use crate::fetchers::config::FetcherConfig; +use crate::fetchers::proxy::ProxyRotator; use crate::fetchers::response::Response; use std::collections::HashMap; use std::time::Duration; pub struct Fetcher { config: FetcherConfig, - client: reqwest::Client, + /// One client per rotating proxy when rotation is enabled, otherwise a + /// single client. Indexed by `rotator` when present. + clients: Vec, + rotator: Option, } #[derive(Debug, thiserror::Error)] @@ -18,6 +22,32 @@ pub enum FetcherError { impl Fetcher { pub fn new(config: FetcherConfig) -> Self { + let (clients, rotator) = if config.proxy_list.is_empty() { + // No rotation: a single client honouring `proxy` and the + // per-protocol `proxies` map. + (vec![Self::build_client(&config, None)], None) + } else { + // Rotation: one client bound to each proxy, selected round-robin. + let clients = config + .proxy_list + .iter() + .map(|p| Self::build_client(&config, Some(p))) + .collect(); + let rotator = ProxyRotator::new(config.proxy_list.clone()); + (clients, rotator) + }; + + Self { + config, + clients, + rotator, + } + } + + /// Build a single reqwest client. When `proxy_override` is `Some`, that + /// proxy is applied for all protocols; otherwise the config's `proxy` and + /// per-protocol `proxies` map are applied. + fn build_client(config: &FetcherConfig, proxy_override: Option<&str>) -> reqwest::Client { let mut builder = reqwest::Client::builder() .timeout(Duration::from_secs(config.timeout_secs)) .danger_accept_invalid_certs(!config.verify_ssl); @@ -32,17 +62,54 @@ impl Fetcher { } // Configure proxy - if let Some(ref proxy_url) = config.proxy { + if let Some(proxy_url) = proxy_override { if let Ok(proxy) = reqwest::Proxy::all(proxy_url) { builder = builder.proxy(proxy); } + } else { + // Apply scheme-specific proxies before any wildcard so the specific + // ones win (reqwest uses the first matching proxy). `proxies` is a + // HashMap, so iterate in a deterministic order. + if let Some(proxy_url) = config.proxies.get("http") { + if let Ok(proxy) = reqwest::Proxy::http(proxy_url) { + builder = builder.proxy(proxy); + } + } + if let Some(proxy_url) = config.proxies.get("https") { + if let Ok(proxy) = reqwest::Proxy::https(proxy_url) { + builder = builder.proxy(proxy); + } + } + let mut wildcard_keys: Vec<&String> = config + .proxies + .keys() + .filter(|k| k.as_str() != "http" && k.as_str() != "https") + .collect(); + wildcard_keys.sort(); + for key in wildcard_keys { + if let Ok(proxy) = reqwest::Proxy::all(&config.proxies[key]) { + builder = builder.proxy(proxy); + } + } + // Single wildcard proxy applied last as a general fallback. + if let Some(ref proxy_url) = config.proxy { + if let Ok(proxy) = reqwest::Proxy::all(proxy_url) { + builder = builder.proxy(proxy); + } + } } - let client = builder - .build() - .expect("Failed to build reqwest client"); + builder.build().expect("Failed to build reqwest client") + } - Self { config, client } + /// Select the client to use for the next request attempt. With rotation + /// enabled this advances the round-robin cursor so a failing proxy is + /// swapped on retry. + fn next_client(&self) -> &reqwest::Client { + match &self.rotator { + Some(rotator) => &self.clients[rotator.next_index()], + None => &self.clients[0], + } } pub async fn get(&self, url: &str) -> Result { @@ -84,7 +151,7 @@ impl Fetcher { let mut last_error = String::new(); for attempt in 0..=self.config.retries { - let mut req = self.client.request(method.clone(), url); + let mut req = self.next_client().request(method.clone(), url); // Set headers for (key, value) in &headers { @@ -130,8 +197,7 @@ impl Fetcher { Err(e) => { last_error = e.to_string(); if attempt < self.config.retries { - tokio::time::sleep(Duration::from_secs(self.config.retry_delay_secs)) - .await; + tokio::time::sleep(Duration::from_secs(self.config.retry_delay_secs)).await; } } } diff --git a/src/fetchers/config.rs b/src/fetchers/config.rs index 8a8924a..8230188 100644 --- a/src/fetchers/config.rs +++ b/src/fetchers/config.rs @@ -12,6 +12,10 @@ pub struct FetcherConfig { pub verify_ssl: bool, pub proxy: Option, pub proxies: HashMap, + /// Proxy URLs to rotate through, one HTTP client is built per entry and + /// selected round-robin per request. Takes precedence over `proxy` / + /// `proxies` when non-empty. + pub proxy_list: Vec, pub headers: HashMap, pub stealthy_headers: bool, pub user_agent: Option, @@ -28,6 +32,7 @@ impl Default for FetcherConfig { verify_ssl: true, proxy: None, proxies: HashMap::new(), + proxy_list: Vec::new(), headers: HashMap::new(), stealthy_headers: true, user_agent: None, @@ -74,11 +79,10 @@ impl FetcherConfig { .entry("accept-encoding".to_string()) .or_insert_with(|| constants::ACCEPT_ENCODING.to_string()); - headers - .entry("sec-ch-ua".to_string()) - .or_insert_with(|| { - "\"Google Chrome\";v=\"131\", \"Chromium\";v=\"131\", \"Not_A Brand\";v=\"24\"".to_string() - }); + headers.entry("sec-ch-ua".to_string()).or_insert_with(|| { + "\"Google Chrome\";v=\"131\", \"Chromium\";v=\"131\", \"Not_A Brand\";v=\"24\"" + .to_string() + }); headers .entry("sec-ch-ua-mobile".to_string()) @@ -147,9 +151,36 @@ impl FetcherConfigBuilder { self } + /// Set a per-protocol proxy override. The scheme is lowercased; `"http"` + /// and `"https"` are routed to their respective protocols, any other key + /// (e.g. `"all"`) applies to all protocols. + pub fn protocol_proxy( + mut self, + scheme: impl Into, + proxy_url: impl Into, + ) -> Self { + self.inner + .proxies + .insert(scheme.into().to_lowercase(), proxy_url.into()); + self + } + + /// Set the list of proxies to rotate through. When non-empty, requests are + /// distributed round-robin across one HTTP client per proxy. + pub fn rotating_proxies(mut self, proxies: I) -> Self + where + I: IntoIterator, + S: Into, + { + self.inner.proxy_list = proxies.into_iter().map(Into::into).collect(); + self + } + /// Add a per-header override. Key is lowercased automatically. pub fn header(mut self, key: impl Into, value: impl Into) -> Self { - self.inner.headers.insert(key.into().to_lowercase(), value.into()); + self.inner + .headers + .insert(key.into().to_lowercase(), value.into()); self } diff --git a/src/fetchers/constants.rs b/src/fetchers/constants.rs index 1b97229..eb248ec 100644 --- a/src/fetchers/constants.rs +++ b/src/fetchers/constants.rs @@ -1,6 +1,14 @@ pub const BLOCKED_RESOURCE_TYPES: &[&str] = &[ - "font", "image", "media", "beacon", "object", "imageset", - "texttrack", "websocket", "csp_report", "stylesheet", + "font", + "image", + "media", + "beacon", + "object", + "imageset", + "texttrack", + "websocket", + "csp_report", + "stylesheet", ]; pub const USER_AGENTS: &[&str] = &[ diff --git a/src/fetchers/mod.rs b/src/fetchers/mod.rs index f05d301..58603ae 100644 --- a/src/fetchers/mod.rs +++ b/src/fetchers/mod.rs @@ -1,5 +1,5 @@ +pub mod client; pub mod config; pub mod constants; -pub mod client; -pub mod response; pub mod proxy; +pub mod response; diff --git a/src/fetchers/proxy.rs b/src/fetchers/proxy.rs index 8fe3f50..46952c8 100644 --- a/src/fetchers/proxy.rs +++ b/src/fetchers/proxy.rs @@ -28,8 +28,14 @@ impl ProxyRotator { /// Return the next proxy in round-robin order. pub fn next(&self) -> &str { - let idx = self.cursor.fetch_add(1, Ordering::Relaxed) % self.proxies.len(); - &self.proxies[idx] + &self.proxies[self.next_index()] + } + + /// Advance the cursor and return the index of the next proxy in + /// round-robin order. Useful for indexing a parallel collection (e.g. a + /// pool of pre-built HTTP clients) that shares the rotator's ordering. + pub fn next_index(&self) -> usize { + self.cursor.fetch_add(1, Ordering::Relaxed) % self.proxies.len() } /// Return a pseudo-random proxy based on the current cursor position. diff --git a/src/lib.rs b/src/lib.rs index 92c8598..333165a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,16 +1,16 @@ //! RUSTScrapling - A Rust port of the Scrapling web scraping framework. pub mod core; -pub mod parser; pub mod fetchers; +pub mod parser; pub mod spiders; // Re-export primary types at crate root -pub use parser::{Selector, Selectors}; pub use fetchers::client::Fetcher; pub use fetchers::config::FetcherConfig; pub use fetchers::response::Response; -pub use spiders::spider::Spider; +pub use parser::{Selector, Selectors}; +pub use spiders::engine::CrawlerEngine; pub use spiders::request::SpiderRequest; pub use spiders::result::{CrawlResult, CrawlStats, ItemList}; -pub use spiders::engine::CrawlerEngine; +pub use spiders::spider::Spider; diff --git a/src/main.rs b/src/main.rs index 8c8127f..baff1fa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -36,7 +36,12 @@ async fn main() { let cli = Cli::parse(); match cli.command { - Commands::Fetch { url, selector, format, no_stealth } => { + Commands::Fetch { + url, + selector, + format, + no_stealth, + } => { let config = FetcherConfig::builder().stealth(!no_stealth).build(); let fetcher = Fetcher::new(config); match fetcher.get(&url).await { @@ -64,12 +69,18 @@ async fn main() { "html" => println!("{}", response.text()), _ => { let sel = response.selector(); - println!("{}", sel.get_all_text("\n", true, &["script", "style"], None)); + println!( + "{}", + sel.get_all_text("\n", true, &["script", "style"], None) + ); } } } } - Err(e) => { eprintln!("Error: {}", e); std::process::exit(1); } + Err(e) => { + eprintln!("Error: {}", e); + std::process::exit(1); + } } } Commands::Extract { url, selector } => { @@ -79,12 +90,20 @@ async fn main() { let sel = response.selector(); if let Some(css) = selector { let results = sel.css(&css); - for item in &results { println!("{}", item.text()); } + for item in &results { + println!("{}", item.text()); + } } else { - println!("{}", sel.get_all_text("\n", true, &["script", "style"], None)); + println!( + "{}", + sel.get_all_text("\n", true, &["script", "style"], None) + ); } } - Err(e) => { eprintln!("Error: {}", e); std::process::exit(1); } + Err(e) => { + eprintln!("Error: {}", e); + std::process::exit(1); + } } } } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 89043b5..b5b8452 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1,6 +1,6 @@ pub mod selector; -pub mod selectors; pub mod selector_generation; +pub mod selectors; pub mod translator; pub use selector::Selector; diff --git a/src/parser/selector.rs b/src/parser/selector.rs index 0c550c9..53da318 100644 --- a/src/parser/selector.rs +++ b/src/parser/selector.rs @@ -96,7 +96,10 @@ impl Selector { /// Return the element's attributes as an AttributesHandler. pub fn attrib(&self) -> AttributesHandler { if let Some(el) = self.element_ref() { - let attrs = el.value().attrs().map(|(k, v)| (k.to_string(), v.to_string())); + let attrs = el + .value() + .attrs() + .map(|(k, v)| (k.to_string(), v.to_string())); AttributesHandler::new(attrs) } else { AttributesHandler::new(std::iter::empty::<(String, String)>()) @@ -171,10 +174,8 @@ impl Selector { parts.push(s.to_string()); } } - Node::Element(ref el) => { - if !ignore_tags.contains(el.name()) { - self.collect_text_recursive(child, ignore_tags, valid_values, parts); - } + Node::Element(ref el) if !ignore_tags.contains(el.name()) => { + self.collect_text_recursive(child, ignore_tags, valid_values, parts); } _ => {} } diff --git a/src/spiders/cache.rs b/src/spiders/cache.rs index fb3e277..6f98ce9 100644 --- a/src/spiders/cache.rs +++ b/src/spiders/cache.rs @@ -30,8 +30,7 @@ impl ResponseCache { pub fn put(&self, url: &str, response: &CachedResponse) -> Result<(), std::io::Error> { let file_path = self.cache_path(url); - let data = serde_json::to_string_pretty(response) - .map_err(std::io::Error::other)?; + let data = serde_json::to_string_pretty(response).map_err(std::io::Error::other)?; std::fs::write(&file_path, data) } diff --git a/src/spiders/checkpoint.rs b/src/spiders/checkpoint.rs index 120d50b..7e1e9d3 100644 --- a/src/spiders/checkpoint.rs +++ b/src/spiders/checkpoint.rs @@ -15,13 +15,14 @@ impl CheckpointManager { pub fn new(dir: &str) -> Result { let path = PathBuf::from(dir); std::fs::create_dir_all(&path)?; - Ok(Self { checkpoint_dir: path }) + Ok(Self { + checkpoint_dir: path, + }) } pub fn save(&self, data: &CheckpointData) -> Result<(), std::io::Error> { let file_path = self.checkpoint_dir.join("checkpoint.json"); - let json = serde_json::to_string_pretty(data) - .map_err(std::io::Error::other)?; + let json = serde_json::to_string_pretty(data).map_err(std::io::Error::other)?; std::fs::write(file_path, json) } diff --git a/src/spiders/engine.rs b/src/spiders/engine.rs index 28e9eaa..6fbab41 100644 --- a/src/spiders/engine.rs +++ b/src/spiders/engine.rs @@ -1,5 +1,5 @@ -use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; +use std::sync::Arc; use std::time::Instant; use tokio::sync::{Mutex, Semaphore}; use url::Url; @@ -125,7 +125,10 @@ impl CrawlerEngine { let start_urls = self.spider.start_urls(); let mut domains_seen = std::collections::HashSet::new(); for url in &start_urls { - if let Some(domain) = Url::parse(url).ok().and_then(|u| u.host_str().map(|h| h.to_string())) { + if let Some(domain) = Url::parse(url) + .ok() + .and_then(|u| u.host_str().map(|h| h.to_string())) + { if domains_seen.insert(domain.clone()) { robots.lock().await.fetch_robots(&domain).await; } diff --git a/src/spiders/mod.rs b/src/spiders/mod.rs index e8c2733..5d10652 100644 --- a/src/spiders/mod.rs +++ b/src/spiders/mod.rs @@ -1,10 +1,10 @@ -pub mod spider; +pub mod cache; +pub mod checkpoint; pub mod engine; pub mod request; pub mod response; pub mod result; +pub mod robots; pub mod scheduler; pub mod session; -pub mod checkpoint; -pub mod cache; -pub mod robots; +pub mod spider; diff --git a/src/spiders/request.rs b/src/spiders/request.rs index 67839f8..2452b13 100644 --- a/src/spiders/request.rs +++ b/src/spiders/request.rs @@ -1,6 +1,6 @@ use sha2::{Digest, Sha256}; -use std::collections::HashMap; use std::cmp::Ordering; +use std::collections::HashMap; use url::Url; #[derive(Debug, Clone)] @@ -87,7 +87,9 @@ impl SpiderRequest { } pub fn domain(&self) -> Option { - Url::parse(&self.url).ok().and_then(|u| u.host_str().map(|h| h.to_string())) + Url::parse(&self.url) + .ok() + .and_then(|u| u.host_str().map(|h| h.to_string())) } // Setters @@ -113,7 +115,8 @@ impl SpiderRequest { include_headers: bool, keep_fragments: bool, ) { - self.fingerprint = self.compute_fingerprint(include_kwargs, include_headers, keep_fragments); + self.fingerprint = + self.compute_fingerprint(include_kwargs, include_headers, keep_fragments); } fn compute_fingerprint( diff --git a/src/spiders/result.rs b/src/spiders/result.rs index 4024db3..2315531 100644 --- a/src/spiders/result.rs +++ b/src/spiders/result.rs @@ -30,11 +30,9 @@ impl ItemList { /// Write items as a JSON array to a file. pub fn to_json(&self, path: &Path, indent: usize) -> io::Result<()> { let buf = if indent > 0 { - serde_json::to_vec_pretty(&self.items) - .map_err(io::Error::other)? + serde_json::to_vec_pretty(&self.items).map_err(io::Error::other)? } else { - serde_json::to_vec(&self.items) - .map_err(io::Error::other)? + serde_json::to_vec(&self.items).map_err(io::Error::other)? }; fs::write(path, buf) } @@ -43,8 +41,7 @@ impl ItemList { pub fn to_jsonl(&self, path: &Path) -> io::Result<()> { let mut file = fs::File::create(path)?; for item in &self.items { - let line = serde_json::to_string(item) - .map_err(io::Error::other)?; + let line = serde_json::to_string(item).map_err(io::Error::other)?; writeln!(file, "{}", line)?; } Ok(()) diff --git a/src/spiders/robots.rs b/src/spiders/robots.rs index f24cb9d..02c3f96 100644 --- a/src/spiders/robots.rs +++ b/src/spiders/robots.rs @@ -52,7 +52,10 @@ impl RobotsTxtManager { } pub fn is_allowed(&self, url: &str) -> bool { - let domain = match Url::parse(url).ok().and_then(|u| u.host_str().map(|h| h.to_string())) { + let domain = match Url::parse(url) + .ok() + .and_then(|u| u.host_str().map(|h| h.to_string())) + { Some(d) => d, None => return true, }; @@ -95,7 +98,10 @@ impl RobotsTxtManager { continue; } - if let Some(rest) = line.strip_prefix("User-agent:").or_else(|| line.strip_prefix("user-agent:")) { + if let Some(rest) = line + .strip_prefix("User-agent:") + .or_else(|| line.strip_prefix("user-agent:")) + { let agent = rest.trim().to_lowercase(); if agent == ua_lower || agent == "*" { // Prefer specific match over wildcard @@ -114,12 +120,18 @@ impl RobotsTxtManager { in_matching_section = false; } } else if in_matching_section { - if let Some(path) = line.strip_prefix("Disallow:").or_else(|| line.strip_prefix("disallow:")) { + if let Some(path) = line + .strip_prefix("Disallow:") + .or_else(|| line.strip_prefix("disallow:")) + { let path = path.trim(); if !path.is_empty() { disallowed.push(path.to_string()); } - } else if let Some(delay) = line.strip_prefix("Crawl-delay:").or_else(|| line.strip_prefix("crawl-delay:")) { + } else if let Some(delay) = line + .strip_prefix("Crawl-delay:") + .or_else(|| line.strip_prefix("crawl-delay:")) + { if let Ok(d) = delay.trim().parse::() { crawl_delay = Some(d); } diff --git a/src/spiders/scheduler.rs b/src/spiders/scheduler.rs index 7fa79e7..5a77453 100644 --- a/src/spiders/scheduler.rs +++ b/src/spiders/scheduler.rs @@ -1,5 +1,5 @@ -use std::collections::{BinaryHeap, HashSet}; use super::request::SpiderRequest; +use std::collections::{BinaryHeap, HashSet}; pub struct Scheduler { queue: BinaryHeap, @@ -22,7 +22,11 @@ impl Scheduler { /// Enqueue a request. Returns true if accepted, false if duplicate. pub fn enqueue(&mut self, mut request: SpiderRequest) -> bool { - request.update_fingerprint(self.include_kwargs, self.include_headers, self.keep_fragments); + request.update_fingerprint( + self.include_kwargs, + self.include_headers, + self.keep_fragments, + ); if !request.dont_filter() { let fp = request.fingerprint().to_string(); diff --git a/src/spiders/session.rs b/src/spiders/session.rs index dbbceda..74e3beb 100644 --- a/src/spiders/session.rs +++ b/src/spiders/session.rs @@ -11,7 +11,10 @@ pub struct SessionManager { impl SessionManager { pub fn new(default_config: FetcherConfig) -> Self { - Self { sessions: HashMap::new(), default_config } + Self { + sessions: HashMap::new(), + default_config, + } } pub fn add_session(&mut self, name: &str, config: FetcherConfig) { @@ -20,20 +23,39 @@ impl SessionManager { pub fn ensure_default(&mut self) { if !self.sessions.contains_key("default") { - self.sessions.insert("default".to_string(), Fetcher::new(self.default_config.clone())); + self.sessions.insert( + "default".to_string(), + Fetcher::new(self.default_config.clone()), + ); } } /// Fetch using the session specified in the request (or "default"). pub async fn fetch(&self, request: &SpiderRequest) -> Result { - let session_id = if request.session_id().is_empty() { "default" } else { request.session_id() }; - let fetcher = self.sessions.get(session_id).ok_or_else(|| format!("Session '{}' not found", session_id))?; + let session_id = if request.session_id().is_empty() { + "default" + } else { + request.session_id() + }; + let fetcher = self + .sessions + .get(session_id) + .ok_or_else(|| format!("Session '{}' not found", session_id))?; match request.method() { "GET" => fetcher.get(request.url()).await.map_err(|e| e.to_string()), - "POST" => fetcher.post(request.url(), request.body(), None).await.map_err(|e| e.to_string()), - "PUT" => fetcher.put(request.url(), request.body(), None).await.map_err(|e| e.to_string()), - "DELETE" => fetcher.delete(request.url()).await.map_err(|e| e.to_string()), + "POST" => fetcher + .post(request.url(), request.body(), None) + .await + .map_err(|e| e.to_string()), + "PUT" => fetcher + .put(request.url(), request.body(), None) + .await + .map_err(|e| e.to_string()), + "DELETE" => fetcher + .delete(request.url()) + .await + .map_err(|e| e.to_string()), m => Err(format!("Unsupported HTTP method: {}", m)), } } diff --git a/src/spiders/spider.rs b/src/spiders/spider.rs index 2b6d7d5..af2904d 100644 --- a/src/spiders/spider.rs +++ b/src/spiders/spider.rs @@ -11,30 +11,60 @@ pub trait Spider: Send + Sync + 'static { fn start_urls(&self) -> Vec; // Config with defaults - fn allowed_domains(&self) -> HashSet { HashSet::new() } - fn robots_txt_obey(&self) -> bool { false } - fn concurrent_requests(&self) -> u32 { 4 } - fn concurrent_requests_per_domain(&self) -> u32 { 0 } - fn download_delay(&self) -> f64 { 0.0 } - fn max_blocked_retries(&self) -> u32 { 3 } - fn fp_include_kwargs(&self) -> bool { false } - fn fp_keep_fragments(&self) -> bool { false } - fn fp_include_headers(&self) -> bool { false } - fn development_mode(&self) -> bool { false } + fn allowed_domains(&self) -> HashSet { + HashSet::new() + } + fn robots_txt_obey(&self) -> bool { + false + } + fn concurrent_requests(&self) -> u32 { + 4 + } + fn concurrent_requests_per_domain(&self) -> u32 { + 0 + } + fn download_delay(&self) -> f64 { + 0.0 + } + fn max_blocked_retries(&self) -> u32 { + 3 + } + fn fp_include_kwargs(&self) -> bool { + false + } + fn fp_keep_fragments(&self) -> bool { + false + } + fn fp_include_headers(&self) -> bool { + false + } + fn development_mode(&self) -> bool { + false + } fn start_requests(&self) -> Vec { - self.start_urls().into_iter().map(|url| SpiderRequest::new(&url)).collect() + self.start_urls() + .into_iter() + .map(|url| SpiderRequest::new(&url)) + .collect() } /// Main callback - returns (items, follow_requests) - async fn parse(&self, response: SpiderResponse) -> (Vec, Vec); + async fn parse(&self, response: SpiderResponse) + -> (Vec, Vec); // Hooks with default no-op implementations async fn on_start(&self, _resuming: bool) {} async fn on_close(&self) {} async fn on_error(&self, _request: &SpiderRequest, _error: &str) {} - async fn on_scraped_item(&self, item: serde_json::Value) -> Option { Some(item) } - async fn is_blocked(&self, response: &SpiderResponse) -> bool { response.is_blocked() } + async fn on_scraped_item(&self, item: serde_json::Value) -> Option { + Some(item) + } + async fn is_blocked(&self, response: &SpiderResponse) -> bool { + response.is_blocked() + } - fn fetcher_config(&self) -> FetcherConfig { FetcherConfig::default() } + fn fetcher_config(&self) -> FetcherConfig { + FetcherConfig::default() + } } diff --git a/tests/core_attributes_handler.rs b/tests/core_attributes_handler.rs index b6151d1..3d11fd5 100644 --- a/tests/core_attributes_handler.rs +++ b/tests/core_attributes_handler.rs @@ -64,7 +64,10 @@ fn test_keys_iteration() { fn test_values_iteration() { let attrs = make_attrs(); let values: Vec<&str> = attrs.values().map(|v| v.as_str()).collect(); - assert_eq!(values, vec!["btn primary", "submit-btn", "https://example.com", "42"]); + assert_eq!( + values, + vec!["btn primary", "submit-btn", "https://example.com", "42"] + ); } #[test] diff --git a/tests/core_storage.rs b/tests/core_storage.rs index d4621c1..a5ed969 100644 --- a/tests/core_storage.rs +++ b/tests/core_storage.rs @@ -20,7 +20,10 @@ fn test_save_and_retrieve_roundtrip() { let result = storage.retrieve("elem1").unwrap(); assert!(result.is_some()); let retrieved = result.unwrap(); - assert_eq!(retrieved.get("tag"), Some(&serde_json::Value::String("div".to_string()))); + assert_eq!( + retrieved.get("tag"), + Some(&serde_json::Value::String("div".to_string())) + ); } #[test] @@ -46,7 +49,10 @@ fn test_update_existing_second_value_wins() { storage.save("elem1", &second).unwrap(); let result = storage.retrieve("elem1").unwrap().unwrap(); - assert_eq!(result.get("class"), Some(&serde_json::Value::String("new-class".to_string()))); + assert_eq!( + result.get("class"), + Some(&serde_json::Value::String("new-class".to_string())) + ); } #[test] @@ -67,7 +73,10 @@ fn test_different_urls_isolate_data() { // site-a should still see its own data let result_a = storage_a.retrieve("elem1").unwrap(); assert!(result_a.is_some()); - assert_eq!(result_a.unwrap().get("src"), Some(&serde_json::Value::String("site-a-value".to_string()))); + assert_eq!( + result_a.unwrap().get("src"), + Some(&serde_json::Value::String("site-a-value".to_string())) + ); } #[test] @@ -84,6 +93,12 @@ fn test_different_identifiers_dont_collide() { let result1 = storage.retrieve("identifier_one").unwrap().unwrap(); let result2 = storage.retrieve("identifier_two").unwrap().unwrap(); - assert_eq!(result1.get("id"), Some(&serde_json::Value::String("first".to_string()))); - assert_eq!(result2.get("id"), Some(&serde_json::Value::String("second".to_string()))); + assert_eq!( + result1.get("id"), + Some(&serde_json::Value::String("first".to_string())) + ); + assert_eq!( + result2.get("id"), + Some(&serde_json::Value::String("second".to_string())) + ); } diff --git a/tests/fetchers_client.rs b/tests/fetchers_client.rs index 870380d..68e457e 100644 --- a/tests/fetchers_client.rs +++ b/tests/fetchers_client.rs @@ -137,7 +137,9 @@ async fn test_fetcher_get() { async fn test_fetcher_post_json() { let fetcher = Fetcher::new(FetcherConfig::default()); let body = serde_json::json!({"key": "value"}); - let response = fetcher.post("https://httpbin.org/post", None, Some(&body)).await; + let response = fetcher + .post("https://httpbin.org/post", None, Some(&body)) + .await; assert!(response.is_ok()); let resp = response.unwrap(); assert_eq!(resp.status(), 200); diff --git a/tests/fetchers_config.rs b/tests/fetchers_config.rs index ab97f1f..ad496a7 100644 --- a/tests/fetchers_config.rs +++ b/tests/fetchers_config.rs @@ -47,7 +47,10 @@ fn builder_overrides_defaults() { assert!(!cfg.stealthy_headers); assert!(!cfg.follow_redirects); assert!(!cfg.verify_ssl); - assert_eq!(cfg.headers.get("x-custom").map(|s| s.as_str()), Some("value")); + assert_eq!( + cfg.headers.get("x-custom").map(|s| s.as_str()), + Some("value") + ); } #[test] @@ -68,20 +71,40 @@ fn stealth_headers_contain_required_fields() { let cfg = FetcherConfig::default(); let headers = cfg.build_headers("https://example.com", true); - assert!(headers.contains_key("user-agent"), "user-agent must be present"); + assert!( + headers.contains_key("user-agent"), + "user-agent must be present" + ); assert!(headers.contains_key("accept"), "accept must be present"); - assert!(headers.contains_key("accept-language"), "accept-language must be present"); - assert!(headers.contains_key("accept-encoding"), "accept-encoding must be present"); - assert!(headers.contains_key("sec-fetch-dest"), "sec-fetch-dest must be present"); - assert!(headers.contains_key("sec-fetch-mode"), "sec-fetch-mode must be present"); - assert!(headers.contains_key("sec-fetch-site"), "sec-fetch-site must be present"); + assert!( + headers.contains_key("accept-language"), + "accept-language must be present" + ); + assert!( + headers.contains_key("accept-encoding"), + "accept-encoding must be present" + ); + assert!( + headers.contains_key("sec-fetch-dest"), + "sec-fetch-dest must be present" + ); + assert!( + headers.contains_key("sec-fetch-mode"), + "sec-fetch-mode must be present" + ); + assert!( + headers.contains_key("sec-fetch-site"), + "sec-fetch-site must be present" + ); } #[test] fn stealth_user_agent_is_non_empty() { let cfg = FetcherConfig::default(); let headers = cfg.build_headers("https://example.com", true); - let ua = headers.get("user-agent").expect("user-agent header missing"); + let ua = headers + .get("user-agent") + .expect("user-agent header missing"); assert!(!ua.is_empty()); } @@ -156,7 +179,9 @@ fn proxy_rotator_round_robin() { #[test] fn proxy_rotator_random_stays_in_bounds() { - let proxies: Vec = (0..5).map(|i| format!("http://proxy{}.example.com", i)).collect(); + let proxies: Vec = (0..5) + .map(|i| format!("http://proxy{}.example.com", i)) + .collect(); let rotator = ProxyRotator::new(proxies).expect("should create rotator"); for _ in 0..20 { @@ -167,8 +192,61 @@ fn proxy_rotator_random_stays_in_bounds() { #[test] fn proxy_rotator_len() { - let proxies = vec!["http://a.example.com".to_string(), "http://b.example.com".to_string()]; + let proxies = vec![ + "http://a.example.com".to_string(), + "http://b.example.com".to_string(), + ]; let rotator = ProxyRotator::new(proxies).unwrap(); assert_eq!(rotator.len(), 2); assert!(!rotator.is_empty()); } + +#[test] +fn proxy_rotator_next_index_round_robin() { + let proxies = vec![ + "http://a.example.com".to_string(), + "http://b.example.com".to_string(), + "http://c.example.com".to_string(), + ]; + let rotator = ProxyRotator::new(proxies).unwrap(); + assert_eq!(rotator.next_index(), 0); + assert_eq!(rotator.next_index(), 1); + assert_eq!(rotator.next_index(), 2); + assert_eq!(rotator.next_index(), 0); +} + +// --------------------------------------------------------------------------- +// FetcherConfig – proxy configuration +// --------------------------------------------------------------------------- + +#[test] +fn builder_protocol_proxy_populates_map() { + let cfg = FetcherConfig::builder() + .protocol_proxy("HTTP", "http://p1.example.com") + .protocol_proxy("https", "http://p2.example.com") + .build(); + assert_eq!( + cfg.proxies.get("http").map(|s| s.as_str()), + Some("http://p1.example.com") + ); + assert_eq!( + cfg.proxies.get("https").map(|s| s.as_str()), + Some("http://p2.example.com") + ); +} + +#[test] +fn builder_rotating_proxies_populates_list() { + let cfg = FetcherConfig::builder() + .rotating_proxies(vec!["http://p1.example.com", "http://p2.example.com"]) + .build(); + assert_eq!( + cfg.proxy_list, + vec!["http://p1.example.com", "http://p2.example.com"] + ); +} + +#[test] +fn default_proxy_list_is_empty() { + assert!(FetcherConfig::default().proxy_list.is_empty()); +} diff --git a/tests/integration_test.rs b/tests/integration_test.rs index c31f28d..92b3b0f 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -1,6 +1,6 @@ -use rust_scrapling::parser::Selector; use rust_scrapling::core::text_handler::TextHandler; use rust_scrapling::fetchers::response::Response; +use rust_scrapling::parser::Selector; use std::collections::HashMap; const ECOMMERCE_HTML: &str = r#" @@ -254,7 +254,11 @@ fn test_response_struct_integration() { fn test_selectors_getall() { let sel = Selector::from_html(ECOMMERCE_HTML); let names = sel.css("h2.product-name"); - let all_text: Vec = names.getall().iter().map(|t| t.as_str().to_string()).collect(); + let all_text: Vec = names + .getall() + .iter() + .map(|t| t.as_str().to_string()) + .collect(); assert_eq!(all_text, vec!["Laptop Pro", "Wireless Mouse", "USB-C Hub"]); } diff --git a/tests/parser_selector.rs b/tests/parser_selector.rs index 3cae9e5..03e1782 100644 --- a/tests/parser_selector.rs +++ b/tests/parser_selector.rs @@ -291,8 +291,7 @@ fn test_selectors_get_first() { assert_eq!(first.unwrap().as_str(), "Item 1"); let empty = sel.css("div.nonexistent"); - let fallback = - empty.get_first(Some(rust_scrapling::core::TextHandler::new("default"))); + let fallback = empty.get_first(Some(rust_scrapling::core::TextHandler::new("default"))); assert_eq!(fallback.unwrap().as_str(), "default"); } diff --git a/tests/parser_selector_generation.rs b/tests/parser_selector_generation.rs index 9bf658b..efd2bc2 100644 --- a/tests/parser_selector_generation.rs +++ b/tests/parser_selector_generation.rs @@ -1,5 +1,5 @@ -use rust_scrapling::parser::Selector; use rust_scrapling::parser::selector_generation::{generate_css_selector, generate_xpath_selector}; +use rust_scrapling::parser::Selector; const HTML: &str = r#"
  • First
  • Second

Nested

"#; @@ -19,7 +19,11 @@ fn test_css_selector_for_li_includes_nth_of_type() { assert!(li.len() >= 2); // The second li should get nth-of-type since there are multiple siblings let css = generate_css_selector(&li[1], true); - assert!(css.contains("nth-of-type"), "Expected 'nth-of-type' in: {}", css); + assert!( + css.contains("nth-of-type"), + "Expected 'nth-of-type' in: {}", + css + ); } #[test] @@ -28,7 +32,11 @@ fn test_xpath_selector_with_id() { let div = sel.css("#main"); assert_eq!(div.len(), 1); let xpath = generate_xpath_selector(&div[0], false); - assert!(xpath.contains("@id='main'"), "Expected \"@id='main'\" in: {}", xpath); + assert!( + xpath.contains("@id='main'"), + "Expected \"@id='main'\" in: {}", + xpath + ); } #[test] @@ -37,5 +45,9 @@ fn test_full_css_selector_includes_body() { let div = sel.css("#main"); assert_eq!(div.len(), 1); let css = generate_css_selector(&div[0], true); - assert!(css.contains("body"), "Expected 'body' in full path: {}", css); + assert!( + css.contains("body"), + "Expected 'body' in full path: {}", + css + ); } diff --git a/tests/spiders_request.rs b/tests/spiders_request.rs index 23a59c1..0d01856 100644 --- a/tests/spiders_request.rs +++ b/tests/spiders_request.rs @@ -68,8 +68,12 @@ fn test_domain() { #[test] fn test_ordering() { - let low = SpiderRequest::builder("https://example.com/low").priority(1).build(); - let high = SpiderRequest::builder("https://example.com/high").priority(10).build(); + let low = SpiderRequest::builder("https://example.com/low") + .priority(1) + .build(); + let high = SpiderRequest::builder("https://example.com/high") + .priority(10) + .build(); assert!(high > low); } diff --git a/tests/spiders_result.rs b/tests/spiders_result.rs index 5232ff9..f54fc67 100644 --- a/tests/spiders_result.rs +++ b/tests/spiders_result.rs @@ -1,4 +1,4 @@ -use rust_scrapling::spiders::result::{CrawlStats, CrawlResult, ItemList}; +use rust_scrapling::spiders::result::{CrawlResult, CrawlStats, ItemList}; use std::time::Instant; #[test] diff --git a/tests/spiders_scheduler.rs b/tests/spiders_scheduler.rs index d21b16a..c767064 100644 --- a/tests/spiders_scheduler.rs +++ b/tests/spiders_scheduler.rs @@ -38,9 +38,15 @@ fn test_dont_filter_bypasses_dedup() { #[test] fn test_priority_ordering() { let mut sched = Scheduler::new(false, false, false); - let low = SpiderRequest::builder("https://example.com/low").priority(1).build(); - let high = SpiderRequest::builder("https://example.com/high").priority(10).build(); - let mid = SpiderRequest::builder("https://example.com/mid").priority(5).build(); + let low = SpiderRequest::builder("https://example.com/low") + .priority(1) + .build(); + let high = SpiderRequest::builder("https://example.com/high") + .priority(10) + .build(); + let mid = SpiderRequest::builder("https://example.com/mid") + .priority(5) + .build(); sched.enqueue(low); sched.enqueue(high);