diff --git a/Cargo.lock b/Cargo.lock index 81eb49753..65d5db8a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -884,7 +884,7 @@ dependencies = [ "async-trait", "bdk-macros", "bitcoin 0.29.2", - "electrum-client 0.12.1", + "electrum-client", "getrandom 0.2.17", "js-sys", "log", @@ -931,16 +931,6 @@ dependencies = [ "serde", ] -[[package]] -name = "bdk_electrum" -version = "0.23.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b59a3f7fbe678874fa34354097644a171276e02a49934c13b3d61c54610ddf39" -dependencies = [ - "bdk_core", - "electrum-client 0.24.1", -] - [[package]] name = "bdk_wallet" version = "2.3.0" @@ -1110,12 +1100,12 @@ dependencies = [ "bdk", "bdk_chain", "bdk_core", - "bdk_electrum", "bdk_wallet", "bitcoin 0.32.8", "bitcoin-harness", "derive_builder", "electrum-pool", + "electrum_streaming_client", "futures", "moka", "proptest", @@ -3036,33 +3026,34 @@ dependencies = [ ] [[package]] -name = "electrum-client" -version = "0.24.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5059f13888a90486e7268bbce59b175f5f76b1c55e5b9c568ceaa42d2b8507c" +name = "electrum-pool" +version = "0.1.0" dependencies = [ + "anyhow", + "backoff", "bitcoin 0.32.8", - "byteorder", - "libc", - "log", - "rustls 0.23.37", - "serde", + "electrum_streaming_client", + "futures", + "once_cell", + "rustls-native-certs 0.8.3", "serde_json", - "webpki-roots 0.25.4", - "winapi", + "thiserror 1.0.69", + "tokio", + "tokio-rustls 0.26.4", + "tracing", ] [[package]] -name = "electrum-pool" -version = "0.1.0" +name = "electrum_streaming_client" +version = "0.4.0" +source = "git+https://github.com/bitcoindevkit/electrum_streaming_client?rev=ed94df0ae21be0f892415872368467872d7bac63#ed94df0ae21be0f892415872368467872d7bac63" dependencies = [ - "backoff", - "bdk_electrum", "bitcoin 0.32.8", "futures", - "once_cell", + "serde", + "serde_json", "tokio", - "tracing", + "tokio-util", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index d28aa725c..73e315bde 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,7 @@ bdk_core = "0.6.0" bdk_electrum = { version = "0.23.0", default-features = false } bdk_wallet = "2.0.0" bitcoin = { version = "0.32", features = ["rand", "serde"] } +electrum_streaming_client = { git = "https://github.com/bitcoindevkit/electrum_streaming_client", rev = "ed94df0ae21be0f892415872368467872d7bac63" } # monero-oxide monero-address = { git = "https://github.com/kayabaNerve/monero-oxide.git" } @@ -95,6 +96,10 @@ testcontainers = "0.15" tokio = { version = "1", features = ["rt-multi-thread", "time", "macros", "sync"] } tokio-util = { version = "0.7", features = ["io", "codec", "rt"] } +# Electrum transport TLS +tokio-rustls = { version = "0.26", default-features = false, features = ["ring", "tls12", "logging"] } +rustls-native-certs = "0.8" + # Tor/Arti crates arti-client = { git = "https://github.com/eigenwallet/arti", branch = "downgraded_rusqlite_arti_2_2_0", default-features = false } libp2p-tor = { path = "./libp2p-tor" } diff --git a/bitcoin-wallet/Cargo.toml b/bitcoin-wallet/Cargo.toml index f79793a75..ba446e690 100644 --- a/bitcoin-wallet/Cargo.toml +++ b/bitcoin-wallet/Cargo.toml @@ -10,11 +10,11 @@ backoff = { workspace = true } bdk = { workspace = true } bdk_chain = { workspace = true } bdk_core = { workspace = true } -bdk_electrum = { workspace = true, features = ["use-rustls-ring"] } bdk_wallet = { workspace = true, features = ["rusqlite", "test-utils"] } bitcoin = { workspace = true } derive_builder = "0.20.2" electrum-pool = { path = "../electrum-pool" } +electrum_streaming_client = { workspace = true } futures = { workspace = true } moka = { version = "0.12", features = ["sync", "future"] } proptest = "1" diff --git a/bitcoin-wallet/src/core.rs b/bitcoin-wallet/src/core.rs index 7164ac6a0..4baf09c9a 100644 --- a/bitcoin-wallet/src/core.rs +++ b/bitcoin-wallet/src/core.rs @@ -1,6 +1,4 @@ -use anyhow::Context; use anyhow::bail; -use bdk_electrum::electrum_client::HeaderNotification; use serde::{Deserialize, Serialize}; use std::ops::Add; @@ -25,19 +23,6 @@ impl From for BlockHeight { } } -impl TryFrom for BlockHeight { - type Error = anyhow::Error; - - fn try_from(value: HeaderNotification) -> Result { - Ok(Self( - value - .height - .try_into() - .context("Failed to fit usize into u32")?, - )) - } -} - impl Add for BlockHeight { type Output = BlockHeight; fn add(self, rhs: u32) -> Self::Output { @@ -171,85 +156,68 @@ impl From for i64 { } } -pub fn parse_rpc_error_code(error: &anyhow::Error) -> anyhow::Result { - // First try to extract an Electrum error from a MultiError if present - for error in error.chain() { - if let Some(multi_error) = error.downcast_ref::() { - // Try to find the first Electrum error in the MultiError - for single_error in multi_error.iter() { - if let bdk_electrum::electrum_client::Error::Protocol(serde_json::Value::String( - string, - )) = single_error - { - let json = serde_json::from_str( - &string - .replace("sendrawtransaction RPC error:", "") - .replace("daemon error:", ""), - )?; +/// Extract a Bitcoin Core RPC error code from a server error JSON payload. +/// +/// The payload may be the error object itself (`{ "code": -26, ... }`) or wrap the relevant code +/// inside a `message` string (e.g. `sendrawtransaction RPC error: { "code": -26, ... }`). +pub(crate) fn extract_rpc_error_code(payload: &str) -> Option { + fn code_from_value(value: &serde_json::Value) -> Option { + match value { + serde_json::Value::Object(map) => map.get("code").and_then(serde_json::Value::as_i64), + serde_json::Value::String(string) => code_from_str(string), + _ => None, + } + } - let json_map = match json { - serde_json::Value::Object(map) => map, - _ => continue, // Try next error if this one isn't a JSON object - }; + fn code_from_str(raw: &str) -> Option { + let cleaned = raw + .replace("sendrawtransaction RPC error:", "") + .replace("daemon error:", ""); + let value: serde_json::Value = serde_json::from_str(cleaned.trim()).ok()?; + code_from_value(&value) + } - let error_code_value = match json_map.get("code") { - Some(val) => val, - None => continue, // Try next error if no error code field - }; + let value: serde_json::Value = match serde_json::from_str(payload) { + Ok(value) => value, + // The payload was not valid JSON on its own; treat it as a raw (possibly prefixed) string. + Err(_) => return code_from_str(payload), + }; - let error_code_number = match error_code_value { - serde_json::Value::Number(num) => num, - _ => continue, // Try next error if error code isn't a number - }; + // A direct code, or one nested inside the `message` field of the error object. + code_from_value(&value).or_else(|| { + value + .get("message") + .and_then(serde_json::Value::as_str) + .and_then(code_from_str) + }) +} - if let Some(int) = error_code_number.as_i64() { - return Ok(int); +pub fn parse_rpc_error_code(error: &anyhow::Error) -> anyhow::Result { + for error in error.chain() { + if let Some(multi_error) = error.downcast_ref::() { + for single_error in multi_error.iter() { + if let Some(json) = single_error.response_json() { + if let Some(code) = extract_rpc_error_code(json) { + return Ok(code); } } } - // If we couldn't extract an RPC error code from any error in the MultiError bail!( - "Error is of incorrect variant. We expected an Electrum error, but got: {}", + "Error is of incorrect variant. We expected an Electrum server error, but got: {}", error ); } - // Original logic for direct Electrum errors - let string = match error.downcast_ref::() { - Some(bdk_electrum::electrum_client::Error::Protocol(serde_json::Value::String( - string, - ))) => string, - _ => bail!( - "Error is of incorrect variant. We expected an Electrum error, but got: {}", + if let Some(single_error) = error.downcast_ref::() { + if let Some(json) = single_error.response_json() { + if let Some(code) = extract_rpc_error_code(json) { + return Ok(code); + } + } + bail!( + "Error is of incorrect variant. We expected an Electrum server error, but got: {}", error - ), - }; - - let json = serde_json::from_str( - &string - .replace("sendrawtransaction RPC error:", "") - .replace("daemon error:", ""), - )?; - - let json_map = match json { - serde_json::Value::Object(map) => map, - _ => bail!("Json error is not json object "), - }; - - let error_code_value = match json_map.get("code") { - Some(val) => val, - None => bail!("No error code field"), - }; - - let error_code_number = match error_code_value { - serde_json::Value::Number(num) => num, - _ => bail!("Error code is not a number"), - }; - - if let Some(int) = error_code_number.as_i64() { - return Ok(int); - } else { - bail!("Error code is not an unsigned integer") + ); } } diff --git a/bitcoin-wallet/src/electrum.rs b/bitcoin-wallet/src/electrum.rs new file mode 100644 index 000000000..cd1320c1e --- /dev/null +++ b/bitcoin-wallet/src/electrum.rs @@ -0,0 +1,1113 @@ +//! Electrum backend for the Bitcoin wallet, built on [`electrum_streaming_client`] via the async +//! [`ElectrumBalancer`]. +//! +//! [`Client`] mirrors the previous `bdk_electrum`-backed client: it tracks watched script +//! histories and the chain tip (for [`Watchable`] status), broadcasts to every server, and +//! estimates fees. [`SyncGlue`] re-ports the `bdk_electrum` chain-sync logic (full scan / sync, +//! transaction & anchor caching, Merkle-proof-validated confirmation anchors and re-org-aware +//! checkpoint construction) so that the produced [`FullScanResponse`]/[`SyncResponse`] can be +//! applied to a `bdk_wallet::Wallet` exactly as before. + +use std::collections::{BTreeMap, HashMap, HashSet}; +use std::sync::{Arc, Mutex as SyncMutex}; +use std::time::{Duration, Instant}; + +use anyhow::{Context, Result, anyhow}; +use bdk_core::spk_client::{ + FullScanRequest, FullScanResponse, SpkWithExpectedTxids, SyncRequest, SyncResponse, +}; +use bdk_core::{BlockId, CheckPoint, ConfirmationBlockTime, TxUpdate}; +use bdk_wallet::KeychainKind; +use bitcoin::{BlockHash, FeeRate, OutPoint, ScriptBuf, Transaction, Txid, block::Header}; +use electrum_pool::{Connection, ElectrumBalancer, Error}; +use electrum_streaming_client::request::{ + EstimateFee, GetFeeHistogram, GetHistory, GetTx, GetTxMerkle, Header as HeaderReq, Headers, + HeadersSubscribe, RelayFee, +}; +use electrum_streaming_client::response; +use tokio::sync::{Mutex as TokioMutex, RwLock as TokioRwLock}; + +use crate::primitives::{Confirmed, EstimateFeeRate, ScriptStatus, Watchable}; +use crate::{BlockHeight, RpcErrorCode, extract_rpc_error_code}; + +/// We include a chain suffix of a certain length for the purpose of robustness. +const CHAIN_SUFFIX_LENGTH: u32 = 8; + +/// One Electrum history entry for a script: a transaction id and its Electrum height +/// (`> 0` confirmed at that block height, `0`/`-1` unconfirmed). +#[derive(Debug, Clone, Copy)] +struct HistoryEntry { + txid: Txid, + height: i64, +} + +impl From<&response::Tx> for HistoryEntry { + fn from(tx: &response::Tx) -> Self { + Self { + txid: tx.txid(), + height: tx.electrum_height(), + } + } +} + +/// In-memory caches shared across all server connections, mirroring `BdkElectrumClient`. +#[derive(Default)] +pub(crate) struct Caches { + txs: SyncMutex>>, + headers: SyncMutex>, + anchors: SyncMutex>, +} + +/// Electrum client wrapping the load balancer plus watched-script state. +#[derive(Clone)] +pub struct Client { + /// The underlying load balancer over all configured Electrum servers. + pub(crate) inner: Arc, + /// Transaction/header/anchor caches used by the chain-sync glue. + pub(crate) caches: Arc, + /// Last-known merged history for each watched script. + script_history: Arc>>>, + /// Active subscriptions, deduplicated by `(txid, script)`. + pub(crate) subscriptions: Arc>>, + /// Time of the last `update_state`. + last_sync: Arc>, + /// How often `update_state` actually refreshes. + sync_interval: Duration, + /// Monotonic latest known block height. + latest_block_height: Arc>, +} + +impl Client { + /// Create a new client over the given Electrum servers. + pub async fn new(electrum_rpc_urls: &[String], sync_interval: Duration) -> Result { + let balancer = ElectrumBalancer::new(electrum_rpc_urls.to_vec()) + .map_err(|e| anyhow!("Failed to create Electrum balancer: {e}"))?; + let initial_last_sync = Instant::now() + .checked_sub(sync_interval) + .ok_or_else(|| anyhow!("failed to set last sync time"))?; + + Ok(Self { + inner: Arc::new(balancer), + caches: Arc::new(Caches::default()), + script_history: Arc::new(TokioRwLock::new(BTreeMap::new())), + subscriptions: Arc::new(TokioMutex::new(HashMap::new())), + last_sync: Arc::new(SyncMutex::new(initial_last_sync)), + sync_interval, + latest_block_height: Arc::new(SyncMutex::new(BlockHeight::from(0))), + }) + } + + /// Refresh watched-script histories and the chain tip if the sync interval has elapsed (or + /// `force`). + pub async fn update_state(&self, force: bool) -> Result<()> { + if !force { + let last_sync = *self.last_sync.lock().expect("last_sync mutex poisoned"); + if Instant::now().duration_since(last_sync) < self.sync_interval { + return Ok(()); + } + } + + self.update_script_histories().await?; + self.update_block_height().await?; + + *self.last_sync.lock().expect("last_sync mutex poisoned") = Instant::now(); + + Ok(()) + } + + /// Refresh a single script's history and the chain tip, ignoring the sync-interval throttle. + pub async fn update_state_single(&self, script: &dyn Watchable) -> Result<()> { + self.update_script_history_for(script.script()).await?; + self.update_block_height().await?; + Ok(()) + } + + async fn update_block_height(&self) -> Result<()> { + let latest = self + .inner + .request("block_headers_subscribe", HeadersSubscribe) + .await + .context("Failed to fetch latest block header")?; + let latest_block_height = BlockHeight::from(latest.height); + + let mut current = self + .latest_block_height + .lock() + .expect("latest_block_height mutex poisoned"); + if latest_block_height > *current { + tracing::trace!( + block_height = u32::from(latest_block_height), + "Got notification for new block" + ); + *current = latest_block_height; + } + + Ok(()) + } + + async fn update_script_histories(&self) -> Result<()> { + let scripts: Vec = self.script_history.read().await.keys().cloned().collect(); + if scripts.is_empty() { + return Ok(()); + } + + let mut any_success = false; + let mut last_error = None; + for script in scripts { + match self.update_script_history_for(script).await { + Ok(()) => any_success = true, + Err(e) => last_error = Some(e), + } + } + + if !any_success { + if let Some(e) = last_error { + return Err(e); + } + } + + Ok(()) + } + + /// Refresh a single script's history by merging the responses of all servers (highest height + /// wins per txid). Succeeds if at least one server responds. + pub async fn update_script_history(&self, script: &dyn Watchable) -> Result<()> { + self.update_script_history_for(script.script()).await + } + + async fn update_script_history_for(&self, script: ScriptBuf) -> Result<()> { + let results = self.inner.script_get_history_all(script.clone()).await; + + let mut all_entries = Vec::new(); + let mut any_success = false; + let mut first_error = None; + for result in results { + match result { + Ok(history) => { + any_success = true; + all_entries.extend(history.iter().map(HistoryEntry::from)); + } + Err(e) => { + if first_error.is_none() { + first_error = Some(e); + } + } + } + } + + if !any_success { + if let Some(e) = first_error { + return Err(anyhow::Error::new(e)); + } + } + + self.script_history + .write() + .await + .insert(script, merge_history(all_entries)); + + Ok(()) + } + + /// Broadcast a transaction to all servers in parallel, caching it on first acceptance. + pub async fn transaction_broadcast_all( + &self, + transaction: &Transaction, + ) -> Result>> { + let results = self.inner.broadcast_all(transaction.clone()).await; + + if results.iter().any(|r| r.is_ok()) { + self.caches + .txs + .lock() + .expect("tx cache poisoned") + .insert(transaction.compute_txid(), Arc::new(transaction.clone())); + } + + Ok(results) + } + + /// Compute the [`ScriptStatus`] of the given watchable transaction. + pub async fn status_of_script( + &self, + script: &dyn Watchable, + force: bool, + ) -> Result { + let (script_buf, txid) = script.script_and_txid(); + + let is_first_time = { + let mut history = self.script_history.write().await; + if history.contains_key(&script_buf) { + false + } else { + history.insert(script_buf.clone(), vec![]); + true + } + }; + + if is_first_time || force { + self.update_state_single(script).await?; + } else { + self.update_state(false).await?; + } + + let history_guard = self.script_history.read().await; + let history = history_guard.get(&script_buf); + + let history_of_tx: Vec<&HistoryEntry> = history + .into_iter() + .flatten() + .filter(|entry| entry.txid == txid) + .collect(); + + let [rest @ .., last] = history_of_tx.as_slice() else { + return Ok(ScriptStatus::Unseen); + }; + + if !rest.is_empty() { + tracing::warn!(%txid, "Found multiple history entries for the same txid. Ignoring all but the last one."); + } + + let latest_block_height = *self + .latest_block_height + .lock() + .expect("latest_block_height mutex poisoned"); + + match last.height { + ..=0 => Ok(ScriptStatus::InMempool), + height => Ok(ScriptStatus::Confirmed( + Confirmed::from_inclusion_and_latest_block( + u32::try_from(height)?, + u32::from(latest_block_height), + ), + )), + } + } + + /// Fetch a transaction from any server. `Ok(None)` if the servers report it does not exist. + pub async fn get_tx(&self, txid: Txid) -> Result>> { + match self.inner.request("get_raw_transaction", GetTx { txid }).await { + Ok(full) => { + let tx = Arc::new(full.tx); + self.caches + .txs + .lock() + .expect("tx cache poisoned") + .insert(txid, tx.clone()); + Ok(Some(tx)) + } + Err(multi_error) => { + if multi_error.any(is_tx_not_found) { + tracing::trace!( + %txid, + error_count = multi_error.len(), + "Transaction not found indicated by one or more Electrum servers" + ); + Ok(None) + } else { + Err(anyhow!(multi_error) + .context("Failed to get transaction from the Electrum server")) + } + } + } + } + + /// Estimate the fee rate (sat/kwu) to be confirmed within `target_block` blocks via + /// `blockchain.estimatefee`. + pub async fn estimate_fee_rate(&self, target_block: u32) -> Result { + let resp = self + .inner + .request( + "estimate_fee", + EstimateFee { + number: target_block as usize, + }, + ) + .await?; + + resp.fee_rate + .filter(|rate| rate.to_sat_per_kwu() > 0) + .ok_or_else(|| anyhow!("Fee rate returned by Electrum server is less than 0")) + } + + /// Estimate a fee rate from the mempool fee histogram, adapting faster to mempool spikes. + async fn estimate_fee_rate_from_histogram(&self, target_block: u32) -> Result { + const HISTOGRAM_SAFETY_MARGIN: f32 = 0.8; + + let mut histogram = self + .inner + .request("get_fee_histogram", GetFeeHistogram) + .await?; + + if histogram.is_empty() { + return Err(anyhow!( + "The mempool seems to be empty therefore we cannot estimate the fee rate from the histogram" + )); + } + + histogram.sort_by(|a, b| a.fee_rate.cmp(&b.fee_rate)); + + let estimated_block_size = 1_000_000u64; + #[allow(clippy::cast_precision_loss)] + let target_distance_from_tip = + (estimated_block_size * target_block as u64) as f32 * HISTOGRAM_SAFETY_MARGIN; + + let mut cumulative_vsize = 0u64; + for pair in &histogram { + cumulative_vsize += pair.weight.to_vbytes_ceil(); + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + if cumulative_vsize >= target_distance_from_tip as u64 { + return Ok(pair.fee_rate); + } + } + + Ok(histogram + .first() + .expect("The histogram should not be empty") + .fee_rate) + } + + async fn min_relay_fee(&self) -> Result { + let resp = self.inner.request("relay_fee", RelayFee).await?; + + // The relay fee is reported per kvB; convert to sat / kwu (kwu = kB × 4). + let sat_per_kwu = resp.fee.to_sat() / 4; + Ok(FeeRate::from_sat_per_kwu(sat_per_kwu)) + } + + /// Full scan a `bdk_wallet` full-scan request against a single server with failover. + pub(crate) async fn full_scan( + &self, + build_request: F, + stop_gap: usize, + batch_size: usize, + ) -> Result> + where + F: Fn() -> FullScanRequest + Send + Sync, + { + let build_request = &build_request; + let caches = self.caches.clone(); + let response = self + .inner + .run("full_scan_wallet", move |conn| { + let request = build_request(); + let glue = SyncGlue::new(conn, caches.clone()); + async move { glue.full_scan(request, stop_gap, batch_size, true).await } + }) + .await?; + + Ok(response) + } + + /// Sync a `bdk_wallet` sync request against a single server with failover. + pub(crate) async fn sync(&self, build_request: F, batch_size: usize) -> Result + where + F: Fn() -> SyncRequest<(KeychainKind, u32)> + Send + Sync, + { + let build_request = &build_request; + let caches = self.caches.clone(); + let response = self + .inner + .run("sync_wallet", move |conn| { + let request = build_request(); + let glue = SyncGlue::new(conn, caches.clone()); + async move { glue.sync(request, batch_size, true).await } + }) + .await?; + + Ok(response) + } +} + +impl EstimateFeeRate for Client { + async fn estimate_feerate(&self, target_block: u32) -> Result { + let (conservative, histogram) = tokio::join!( + self.estimate_fee_rate(target_block), + self.estimate_fee_rate_from_histogram(target_block) + ); + + match (conservative, histogram) { + (Ok(conservative), Ok(histogram)) => { + tracing::debug!( + electrum_conservative_fee_rate_sat_vb = conservative.to_sat_per_vb_ceil(), + electrum_histogram_fee_rate_sat_vb = histogram.to_sat_per_vb_ceil(), + "Successfully fetched fee rates from both sources. We will use the higher one" + ); + Ok(conservative.max(histogram)) + } + (Err(conservative_error), Ok(histogram)) => { + tracing::warn!( + ?conservative_error, + electrum_histogram_fee_rate_sat_vb = histogram.to_sat_per_vb_ceil(), + "Failed to fetch conservative fee rate, using histogram fee rate" + ); + Ok(histogram) + } + (Ok(conservative), Err(histogram_error)) => { + tracing::warn!( + ?histogram_error, + electrum_conservative_fee_rate_sat_vb = conservative.to_sat_per_vb_ceil(), + "Failed to fetch histogram fee rate, using conservative fee rate" + ); + Ok(conservative) + } + (Err(conservative_error), Err(histogram_error)) => Err(conservative_error + .context(histogram_error) + .context( + "Failed to fetch both the conservative and histogram fee rates from Electrum", + )), + } + } + + async fn min_relay_fee(&self) -> Result { + Client::min_relay_fee(self).await + } +} + +/// Merge history entries by txid, keeping the highest-height entry for each. +fn merge_history(entries: Vec) -> Vec { + let mut best: BTreeMap = BTreeMap::new(); + for entry in entries { + best.entry(entry.txid) + .and_modify(|current| { + if entry.height > current.height { + *current = entry; + } + }) + .or_insert(entry); + } + best.into_values().collect() +} + +/// Whether a server error indicates the transaction does not exist. +fn is_tx_not_found(error: &Error) -> bool { + let Some(json) = error.response_json() else { + return false; + }; + + if json.contains("No such mempool or blockchain transaction") + || json.contains("missing transaction") + { + return true; + } + + if let Some(code) = extract_rpc_error_code(json) { + return code == i64::from(RpcErrorCode::RpcInvalidAddressOrKey); + } + + false +} + +/// Per-connection chain-sync engine: a faithful async re-port of `bdk_electrum`'s +/// `BdkElectrumClient` against the streaming client and our shared caches. +pub(crate) struct SyncGlue { + conn: Arc, + caches: Arc, +} + +struct SpkScanState { + unused_spk_count: usize, + last_active_index: Option, + stop_gap: usize, +} + +enum BatchOutcome { + Continue, + Stop, +} + +impl SyncGlue { + pub(crate) fn new(conn: Arc, caches: Arc) -> Self { + Self { conn, caches } + } + + async fn full_scan( + &self, + mut request: FullScanRequest, + stop_gap: usize, + batch_size: usize, + fetch_prev_txouts: bool, + ) -> Result, Error> { + let start_time = request.start_time(); + + let tip_and_latest_blocks = match request.chain_tip() { + Some(chain_tip) => Some(self.fetch_tip_and_latest_blocks(chain_tip).await?), + None => None, + }; + + let mut tx_update = TxUpdate::::default(); + let mut last_active_indices = BTreeMap::::new(); + let mut pending_anchors = Vec::new(); + + for keychain in request.keychains() { + let mut state = SpkScanState { + unused_spk_count: 0, + last_active_index: None, + stop_gap, + }; + + loop { + let batch: Vec<(u32, SpkWithExpectedTxids)> = { + let mut spks = request.iter_spks(keychain.clone()); + (0..batch_size) + .map_while(|_| spks.next()) + .map(|(i, spk)| (i, SpkWithExpectedTxids::from(spk))) + .collect() + }; + + if batch.is_empty() { + break; + } + + if let BatchOutcome::Stop = self + .process_spk_batch( + start_time, + &mut tx_update, + batch, + &mut pending_anchors, + &mut state, + ) + .await? + { + break; + } + } + + if let Some(last_active_index) = state.last_active_index { + last_active_indices.insert(keychain, last_active_index); + } + } + + if fetch_prev_txouts { + self.fetch_prev_txout(&mut tx_update).await?; + } + + self.apply_anchors(&mut tx_update, &pending_anchors).await?; + + let chain_update = match tip_and_latest_blocks { + Some((chain_tip, latest_blocks)) => Some(chain_update( + chain_tip, + &latest_blocks, + tx_update.anchors.iter().cloned(), + )), + None => None, + }; + + Ok(FullScanResponse { + tx_update, + chain_update, + last_active_indices, + }) + } + + async fn sync( + &self, + mut request: SyncRequest<(KeychainKind, u32)>, + batch_size: usize, + fetch_prev_txouts: bool, + ) -> Result { + let start_time = request.start_time(); + + let tip_and_latest_blocks = match request.chain_tip() { + Some(chain_tip) => Some(self.fetch_tip_and_latest_blocks(chain_tip).await?), + None => None, + }; + + let mut tx_update = TxUpdate::::default(); + let mut pending_anchors = Vec::new(); + + let mut state = SpkScanState { + unused_spk_count: 0, + last_active_index: None, + stop_gap: usize::MAX, + }; + let mut spk_index = 0u32; + loop { + let batch: Vec<(u32, SpkWithExpectedTxids)> = { + let mut spks = request.iter_spks_with_expected_txids(); + (0..batch_size) + .map_while(|_| spks.next()) + .map(|spk| { + let indexed = (spk_index, spk); + spk_index += 1; + indexed + }) + .collect() + }; + + if batch.is_empty() { + break; + } + + self.process_spk_batch( + start_time, + &mut tx_update, + batch, + &mut pending_anchors, + &mut state, + ) + .await?; + } + + let txids: Vec = request.iter_txids().collect(); + self.populate_with_txids(start_time, &mut tx_update, txids, &mut pending_anchors) + .await?; + + let outpoints: Vec = request.iter_outpoints().collect(); + self.populate_with_outpoints(start_time, &mut tx_update, outpoints, &mut pending_anchors) + .await?; + + if fetch_prev_txouts { + self.fetch_prev_txout(&mut tx_update).await?; + } + + self.apply_anchors(&mut tx_update, &pending_anchors).await?; + + let chain_update = match tip_and_latest_blocks { + Some((chain_tip, latest_blocks)) => Some(chain_update( + chain_tip, + &latest_blocks, + tx_update.anchors.iter().cloned(), + )), + None => None, + }; + + Ok(SyncResponse { + tx_update, + chain_update, + }) + } + + async fn apply_anchors( + &self, + tx_update: &mut TxUpdate, + pending_anchors: &[(Txid, usize)], + ) -> Result<(), Error> { + if pending_anchors.is_empty() { + return Ok(()); + } + let anchors = self.batch_fetch_anchors(pending_anchors).await?; + for (txid, anchor) in anchors { + tx_update.anchors.insert((anchor, txid)); + } + Ok(()) + } + + async fn fetch_tx(&self, txid: Txid) -> Result, Error> { + if let Some(tx) = self.caches.txs.lock().expect("tx cache poisoned").get(&txid) { + return Ok(tx.clone()); + } + + let full = self.conn.request(GetTx { txid }).await?; + let tx = Arc::new(full.tx); + self.caches + .txs + .lock() + .expect("tx cache poisoned") + .insert(txid, tx.clone()); + Ok(tx) + } + + async fn process_spk_batch( + &self, + start_time: u64, + tx_update: &mut TxUpdate, + batch: Vec<(u32, SpkWithExpectedTxids)>, + pending_anchors: &mut Vec<(Txid, usize)>, + state: &mut SpkScanState, + ) -> Result { + let histories = futures::future::join_all( + batch + .iter() + .map(|(_, spk)| self.conn.request(GetHistory::from_script(spk.spk.clone()))), + ) + .await; + + for ((spk_index, spk), history_res) in batch.into_iter().zip(histories) { + let history = history_res?; + + if history.is_empty() { + match state.unused_spk_count.checked_add(1) { + Some(i) if i < state.stop_gap => state.unused_spk_count = i, + _ => return Ok(BatchOutcome::Stop), + } + } else { + state.last_active_index = Some(spk_index); + state.unused_spk_count = 0; + } + + let history_set: HashSet = history.iter().map(|tx| tx.txid()).collect(); + for &txid in spk.expected_txids.difference(&history_set) { + tx_update.evicted_ats.insert((txid, start_time)); + } + + for tx in history { + let txid = tx.txid(); + tx_update.txs.push(self.fetch_tx(txid).await?); + let height = tx.electrum_height(); + if height > 0 { + pending_anchors.push((txid, height as usize)); + } else { + tx_update.seen_ats.insert((txid, start_time)); + } + } + } + + Ok(BatchOutcome::Continue) + } + + async fn populate_with_txids( + &self, + start_time: u64, + tx_update: &mut TxUpdate, + txids: Vec, + pending_anchors: &mut Vec<(Txid, usize)>, + ) -> Result<(), Error> { + let mut txs = Vec::<(Txid, Arc)>::new(); + let mut scripts = Vec::new(); + for txid in txids { + match self.fetch_tx(txid).await { + Ok(tx) => { + let spk = tx + .output + .first() + .expect("tx must have an output") + .script_pubkey + .clone(); + txs.push((txid, tx)); + scripts.push(spk); + } + // A "not found" (server response) error means the txid is unknown: skip it. + Err(Error::Response(_)) => continue, + Err(e) => return Err(e), + } + } + + let histories = futures::future::join_all( + scripts + .iter() + .map(|spk| self.conn.request(GetHistory::from_script(spk.clone()))), + ) + .await; + + for ((txid, tx), history_res) in txs.into_iter().zip(histories) { + let history = history_res?; + if let Some(entry) = history.into_iter().find(|entry| entry.txid() == txid) { + let height = entry.electrum_height(); + if height > 0 { + pending_anchors.push((txid, height as usize)); + } else { + tx_update.seen_ats.insert((txid, start_time)); + } + } + tx_update.txs.push(tx); + } + + Ok(()) + } + + async fn populate_with_outpoints( + &self, + start_time: u64, + tx_update: &mut TxUpdate, + outpoints: Vec, + pending_anchors: &mut Vec<(Txid, usize)>, + ) -> Result<(), Error> { + let mut ops_spks_txs = Vec::new(); + for op in outpoints { + if let Ok(tx) = self.fetch_tx(op.txid).await { + if let Some(txout) = tx.output.get(op.vout as usize) { + ops_spks_txs.push((op, txout.script_pubkey.clone(), tx)); + } + } + } + + let unique_spks: Vec = ops_spks_txs + .iter() + .map(|(_, spk, _)| spk.clone()) + .collect::>() + .into_iter() + .collect(); + + let histories = futures::future::join_all( + unique_spks + .iter() + .map(|spk| self.conn.request(GetHistory::from_script(spk.clone()))), + ) + .await; + + let mut spk_map: HashMap> = HashMap::new(); + for (spk, history_res) in unique_spks.into_iter().zip(histories) { + spk_map.insert(spk, history_res?); + } + + for (outpoint, spk, tx) in ops_spks_txs { + let Some(spk_history) = spk_map.get(&spk) else { + continue; + }; + + let mut has_residing = false; + let mut has_spending = false; + + for res in spk_history { + if has_residing && has_spending { + break; + } + let res_txid = res.txid(); + + if !has_residing && res_txid == outpoint.txid { + has_residing = true; + tx_update.txs.push(tx.clone()); + let height = res.electrum_height(); + if height > 0 { + pending_anchors.push((res_txid, height as usize)); + } else { + tx_update.seen_ats.insert((res_txid, start_time)); + } + } + + if !has_spending && res_txid != outpoint.txid { + let res_tx = self.fetch_tx(res_txid).await?; + has_spending = res_tx + .input + .iter() + .any(|txin| txin.previous_output == outpoint); + if !has_spending { + continue; + } + tx_update.txs.push(res_tx); + let height = res.electrum_height(); + if height > 0 { + pending_anchors.push((res_txid, height as usize)); + } else { + tx_update.seen_ats.insert((res_txid, start_time)); + } + } + } + } + + Ok(()) + } + + async fn batch_fetch_anchors( + &self, + txs_with_heights: &[(Txid, usize)], + ) -> Result, Error> { + let mut results = Vec::with_capacity(txs_with_heights.len()); + let mut to_fetch = Vec::new(); + + let mut needed_heights: Vec = + txs_with_heights.iter().map(|&(_, h)| h as u32).collect(); + needed_heights.sort_unstable(); + needed_heights.dedup(); + + let mut height_to_hash = HashMap::with_capacity(needed_heights.len()); + + let mut missing_heights = Vec::new(); + { + let cache = self.caches.headers.lock().expect("header cache poisoned"); + for &height in &needed_heights { + if let Some(header) = cache.get(&height) { + height_to_hash.insert(height, header.block_hash()); + } else { + missing_heights.push(height); + } + } + } + + if !missing_heights.is_empty() { + let headers = futures::future::join_all( + missing_heights + .iter() + .map(|&height| self.conn.request(HeaderReq { height })), + ) + .await; + + let mut cache = self.caches.headers.lock().expect("header cache poisoned"); + for (height, header_res) in missing_heights.into_iter().zip(headers) { + let header = header_res?.header; + height_to_hash.insert(height, header.block_hash()); + cache.insert(height, header); + } + } + + { + let anchor_cache = self.caches.anchors.lock().expect("anchor cache poisoned"); + for &(txid, height) in txs_with_heights { + let hash = height_to_hash[&(height as u32)]; + if let Some(anchor) = anchor_cache.get(&(txid, hash)) { + results.push((txid, *anchor)); + } else { + to_fetch.push((txid, height)); + } + } + } + + let proofs = futures::future::join_all(to_fetch.iter().map(|&(txid, height)| { + self.conn.request(GetTxMerkle { + txid, + height: height as u32, + }) + })) + .await; + + for ((txid, height), proof_res) in to_fetch.into_iter().zip(proofs) { + let proof = proof_res?; + + let mut header = { + let cache = self.caches.headers.lock().expect("header cache poisoned"); + cache + .get(&(height as u32)) + .copied() + .expect("header already fetched above") + }; + + let mut valid = proof.expected_merkle_root(txid) == header.merkle_root; + if !valid { + header = self + .conn + .request(HeaderReq { + height: height as u32, + }) + .await? + .header; + self.caches + .headers + .lock() + .expect("header cache poisoned") + .insert(height as u32, header); + valid = proof.expected_merkle_root(txid) == header.merkle_root; + } + + if valid { + let hash = header.block_hash(); + let anchor = ConfirmationBlockTime { + confirmation_time: header.time as u64, + block_id: BlockId { + height: height as u32, + hash, + }, + }; + self.caches + .anchors + .lock() + .expect("anchor cache poisoned") + .insert((txid, hash), anchor); + results.push((txid, anchor)); + } + } + + Ok(results) + } + + async fn fetch_prev_txout( + &self, + tx_update: &mut TxUpdate, + ) -> Result<(), Error> { + let mut no_dup = HashSet::::new(); + let txs: Vec> = tx_update.txs.clone(); + for tx in &txs { + if !tx.is_coinbase() && no_dup.insert(tx.compute_txid()) { + for vin in &tx.input { + let outpoint = vin.previous_output; + let prev_tx = self.fetch_tx(outpoint.txid).await?; + let txout = prev_tx + .output + .get(outpoint.vout as usize) + .ok_or_else(|| { + Error::connection(format!("prevout {outpoint} does not exist")) + })? + .clone(); + tx_update.txouts.insert(outpoint, txout); + } + } + } + Ok(()) + } + + async fn fetch_tip_and_latest_blocks( + &self, + prev_tip: CheckPoint, + ) -> Result<(CheckPoint, BTreeMap), Error> { + let new_tip_height = self.conn.request(HeadersSubscribe).await?.height; + + // If the server's tip is lower than ours, checkpoints need no updating. + if new_tip_height < prev_tip.height() { + return Ok((prev_tip, BTreeMap::new())); + } + + let mut new_blocks = { + let start_height = new_tip_height.saturating_sub(CHAIN_SUFFIX_LENGTH - 1); + let headers = self + .conn + .request(Headers { + start_height, + count: CHAIN_SUFFIX_LENGTH as usize, + }) + .await? + .headers; + (start_height..) + .zip(headers.into_iter().map(|h| h.block_hash())) + .collect::>() + }; + + let agreement_cp = { + let mut agreement_cp = Option::::None; + for cp in prev_tip.iter() { + let cp_block = cp.block_id(); + let hash = match new_blocks.get(&cp_block.height) { + Some(&hash) => hash, + None => { + let hash = self + .conn + .request(HeaderReq { + height: cp_block.height, + }) + .await? + .header + .block_hash(); + new_blocks.insert(cp_block.height, hash); + hash + } + }; + if hash == cp_block.hash { + agreement_cp = Some(cp); + break; + } + } + agreement_cp + .ok_or_else(|| Error::connection("cannot find agreement block with server"))? + }; + + let agreement_height = agreement_cp.height(); + let extension = new_blocks + .iter() + .filter(move |(height, _)| **height > agreement_height) + .map(|(&height, &hash)| BlockId { height, hash }); + let new_tip = agreement_cp + .extend(extension) + .expect("extension heights already checked to be greater than agreement height"); + + Ok((new_tip, new_blocks)) + } +} + +/// Add a corresponding checkpoint per anchor height if it does not yet exist (bounded by +/// `latest_blocks` to keep hashes consistent across re-orgs). +fn chain_update( + mut tip: CheckPoint, + latest_blocks: &BTreeMap, + anchors: impl Iterator, +) -> CheckPoint { + for (anchor, _txid) in anchors { + let height = anchor.block_id.height; + if tip.get(height).is_none() && height <= tip.height() { + let hash = latest_blocks + .get(&height) + .copied() + .unwrap_or(anchor.block_id.hash); + tip = tip.insert(BlockId { hash, height }); + } + } + tip +} diff --git a/bitcoin-wallet/src/lib.rs b/bitcoin-wallet/src/lib.rs index 3f6be7de8..7705196d0 100644 --- a/bitcoin-wallet/src/lib.rs +++ b/bitcoin-wallet/src/lib.rs @@ -1,7 +1,9 @@ mod core; +mod electrum; mod wallet; pub use core::*; +pub use electrum::Client; pub use wallet::*; pub mod primitives; diff --git a/bitcoin-wallet/src/wallet.rs b/bitcoin-wallet/src/wallet.rs index 6183aad3c..8b6e248e3 100644 --- a/bitcoin-wallet/src/wallet.rs +++ b/bitcoin-wallet/src/wallet.rs @@ -1,9 +1,9 @@ -use crate::primitives::{Confirmed, EstimateFeeRate, ScriptStatus, Subscription, Watchable}; -use crate::{BitcoinWallet, BlockHeight, RpcErrorCode, bitcoin_address, parse_rpc_error_code}; +use crate::electrum::Client; +use crate::primitives::{EstimateFeeRate, ScriptStatus, Subscription, Watchable}; +use crate::{BitcoinWallet, bitcoin_address}; use anyhow::{Context, Result, anyhow, bail}; use bdk_chain::CheckPoint; use bdk_chain::spk_client::{SyncRequest, SyncRequestBuilder}; -use bdk_electrum::electrum_client::{ElectrumApi, GetHistoryRes}; use bdk_wallet::KeychainKind; use bdk_wallet::WalletPersister; @@ -17,12 +17,9 @@ use bitcoin::bip32::Xpriv; use bitcoin::{Address, Amount, Transaction, Txid, psbt::Psbt as PartiallySignedTransaction}; use bitcoin::{Psbt, ScriptBuf, Weight}; use derive_builder::Builder; -use electrum_pool::ElectrumBalancer; use moka; use rust_decimal::Decimal; use rust_decimal::prelude::*; -use std::collections::BTreeMap; -use std::collections::HashMap; use std::fmt::Debug; use std::path::Path; use std::path::PathBuf; @@ -32,7 +29,6 @@ use std::time::Duration; use std::time::Instant; use sync_ext::{CumulativeProgressHandle, InnerSyncCallback, SyncCallbackExt}; use tokio::sync::Mutex as TokioMutex; -use tokio::sync::RwLock as TokioRwLock; use tokio::sync::watch; use tracing::{Instrument, debug_span}; @@ -120,23 +116,6 @@ pub struct Wallet { tauri_handle: TauriHandle, } -/// This is our wrapper around a bdk electrum client. -#[derive(Clone)] -pub struct Client { - /// The underlying electrum balancer for load balancing across multiple servers. - inner: Arc, - /// The history of transactions for each script. - script_history: Arc>>>, - /// The subscriptions to the status of transactions. - subscriptions: Arc>>, - /// The time of the last sync. - last_sync: Arc>, - /// How often we sync with the server. - sync_interval: Duration, - /// The height of the latest block we know about. - latest_block_height: Arc>, -} - /// Holds the configuration parameters for creating a Bitcoin wallet. /// The actual Wallet will be constructed from this configuration. #[derive(Builder, Clone)] @@ -557,20 +536,38 @@ impl Wallet { let wallet = Arc::new(wallet); let ph = progress_handle.clone(); - let full_scan_response = client.inner.call_async("full_scan_wallet", move |electrum_client| { - let callback = ph.clone().and_then(|ph| InnerSyncCallback::new(move |consumed, total| { - ph.update(consumed, total); - })).chain(InnerSyncCallback::new(move |consumed, total| { - tracing::debug!( - "Full scanning Bitcoin wallet, currently at index {}. We will scan around {} in total.", - consumed, - total - ); - }).throttle_callback(10.0)).to_full_scan_callback(Self::SCAN_STOP_GAP, 100); - let full_scan = wallet.start_full_scan().inspect(callback); - electrum_client.full_scan(full_scan, Self::SCAN_STOP_GAP as usize, Self::SCAN_BATCH_SIZE as usize, true) - }).await?; + // Rebuilt per attempt so the balancer can retry the full scan against a different server. + let build_request = move || { + let callback = ph + .clone() + .and_then(|ph| { + InnerSyncCallback::new(move |consumed, total| { + ph.update(consumed, total); + }) + }) + .chain( + InnerSyncCallback::new(move |consumed, total| { + tracing::debug!( + "Full scanning Bitcoin wallet, currently at index {}. We will scan around {} in total.", + consumed, + total + ); + }) + .throttle_callback(10.0), + ) + .to_full_scan_callback(Self::SCAN_STOP_GAP, 100); + + wallet.start_full_scan().inspect(callback).build() + }; + + let full_scan_response = client + .full_scan( + build_request, + Self::SCAN_STOP_GAP as usize, + Self::SCAN_BATCH_SIZE as usize, + ) + .await?; // Only create the persister once we have the full scan result let mut persister = persister_constructor()?; @@ -1015,25 +1012,23 @@ impl Wallet { ) -> Result<()> { let callback = Arc::new(SyncMutex::new(callback)); + // Rebuilt per attempt so the balancer can retry the sync against a different server. + let build_request = move || { + let callback = callback.clone(); + sync_request_factory + .clone() + .build() + .inspect(move |_, progress| { + if let Ok(mut guard) = callback.lock() { + guard.call(progress.consumed() as u64, progress.total() as u64); + } + }) + .build() + }; + let sync_response = self .electrum_client - .inner - .call_async("sync_wallet", move |client| { - let sync_request_factory = sync_request_factory.clone(); - let callback = callback.clone(); - - // Build the sync request - let sync_request = sync_request_factory - .build() - .inspect(move |_, progress| { - if let Ok(mut guard) = callback.lock() { - guard.call(progress.consumed() as u64, progress.total() as u64); - } - }) - .build(); - - client.sync(sync_request, Self::SCAN_BATCH_SIZE as usize, true) - }) + .sync(build_request, Self::SCAN_BATCH_SIZE as usize) .await?; // We only acquire the lock after the long running .sync(...) call has finished @@ -1611,494 +1606,6 @@ where } } -impl Client { - /// Create a new client with multiple electrum servers for load balancing. - pub async fn new(electrum_rpc_urls: &[String], sync_interval: Duration) -> Result { - let balancer = ElectrumBalancer::new(electrum_rpc_urls.to_vec()).await?; - let initial_last_sync = Instant::now() - .checked_sub(sync_interval) - .ok_or(anyhow!("failed to set last sync time"))?; - - Ok(Self { - inner: Arc::new(balancer), - script_history: Arc::new(TokioRwLock::new(BTreeMap::new())), - last_sync: Arc::new(SyncMutex::new(initial_last_sync)), - sync_interval, - latest_block_height: Arc::new(SyncMutex::new(BlockHeight::from(0))), - subscriptions: Arc::new(TokioMutex::new(HashMap::new())), - }) - } - - /// Update the client state, if the refresh duration has passed. - /// - /// Optionally force an update even if the sync interval has not passed. - pub async fn update_state(&self, force: bool) -> Result<()> { - let now = Instant::now(); - - if !force { - let last_sync = *self.last_sync.lock().expect("last_sync mutex poisoned"); - if now.duration_since(last_sync) < self.sync_interval { - return Ok(()); - } - } - - self.update_script_histories().await?; - self.update_block_height().await?; - - *self.last_sync.lock().expect("last_sync mutex poisoned") = Instant::now(); - - Ok(()) - } - - /// Update the client state for a single script. - /// - /// As opposed to [`update_state`] this function does not - /// check the time since the last update before refreshing - /// It therefore also does not take a [`force`] parameter - pub async fn update_state_single(&self, script: &dyn Watchable) -> Result<()> { - self.update_script_history(script).await?; - self.update_block_height().await?; - - Ok(()) - } - - /// Update the block height. - async fn update_block_height(&self) -> Result<()> { - let latest_block = self - .inner - .call_async("block_headers_subscribe", |client| { - client.inner.block_headers_subscribe() - }) - .await - .context("Failed to subscribe to header notifications")?; - let latest_block_height = BlockHeight::try_from(latest_block)?; - - let mut current = self - .latest_block_height - .lock() - .expect("latest_block_height mutex poisoned"); - if latest_block_height > *current { - tracing::trace!( - block_height = u32::from(latest_block_height), - "Got notification for new block" - ); - *current = latest_block_height; - } - - Ok(()) - } - - /// Update the script histories. - async fn update_script_histories(&self) -> Result<()> { - let scripts: Vec<_> = self.script_history.read().await.keys().cloned().collect(); - - // No need to do any network request if we have nothing to fetch - if scripts.is_empty() { - return Ok(()); - } - - // Concurrently fetch the script histories from ALL electrum servers - let results = self - .inner - .join_all("batch_script_get_history", { - let scripts = scripts.clone(); - - move |client| { - let script_refs: Vec<_> = scripts.iter().map(|s| s.as_script()).collect(); - client.inner.batch_script_get_history(script_refs) - } - }) - .await?; - - let successful_results: Vec>> = results - .iter() - .filter_map(|r| r.as_ref().ok()) - .cloned() - .collect(); - - // If we didn't get a single successful request, we have to fail - if successful_results.is_empty() { - if let Some(Err(e)) = results.into_iter().find(|r| r.is_err()) { - return Err(e.into()); - } - } - - // Iterate through each script we fetched and find the highest - // returned entry at any Electrum node - let mut script_history = self.script_history.write().await; - for (script_index, script) in scripts.iter().enumerate() { - let all_history_for_script: Vec = successful_results - .iter() - .filter_map(|server_result| server_result.get(script_index)) - .flatten() - .cloned() - .collect(); - - let mut best_history: BTreeMap = BTreeMap::new(); - for item in all_history_for_script { - best_history - .entry(item.tx_hash) - .and_modify(|current| { - if item.height > current.height { - *current = item.clone(); - } - }) - .or_insert(item); - } - - let final_history: Vec = best_history.into_values().collect(); - script_history.insert(script.clone(), final_history); - } - - Ok(()) - } - - /// Update the script history of a single script. - pub async fn update_script_history(&self, script: &dyn Watchable) -> Result<()> { - let (script_buf, _) = script.script_and_txid(); - let script_clone = script_buf.clone(); - - // Call all electrum servers in parallel to get script history. - let results = self - .inner - .join_all("script_get_history", move |client| { - client.inner.script_get_history(script_clone.as_script()) - }) - .await?; - - // Collect all successful history entries from all servers. - let mut all_history_items: Vec = Vec::new(); - let mut any_success = false; - let mut first_error = None; - - for result in results { - match result { - Ok(history) => { - any_success = true; - all_history_items.extend(history); - } - Err(e) => { - if first_error.is_none() { - first_error = Some(e); - } - } - } - } - - // If any of the calls succeeded, that is fine. Only if none - // succeeded we return the error. - if !any_success && let Some(err) = first_error { - return Err(err.into()); - } - - // Use a map to find the best (highest confirmation) entry for each transaction. - let mut best_history: BTreeMap = BTreeMap::new(); - for item in all_history_items { - best_history - .entry(item.tx_hash) - .and_modify(|current| { - if item.height > current.height { - *current = item.clone(); - } - }) - .or_insert(item); - } - - let final_history: Vec = best_history.into_values().collect(); - - self.script_history - .write() - .await - .insert(script_buf, final_history); - - Ok(()) - } - - /// Broadcast a transaction to all known electrum servers in parallel. - /// Returns the results from all servers - at least one success indicates successful broadcast. - pub async fn transaction_broadcast_all( - &self, - transaction: &Transaction, - ) -> Result>> { - // Broadcast to all electrum servers in parallel - let results = self.inner.broadcast_all(transaction.clone()).await?; - - // Add the transaction to the cache if at least one broadcast succeeded - if results.iter().any(|r| r.is_ok()) { - // Note: Perhaps it is better to only populate caches of the Electrum nodes - // that accepted our transaction? - self.inner.populate_tx_cache(vec![transaction.clone()]); - } - - Ok(results) - } - - /// Get the status of a script. - pub async fn status_of_script( - &self, - script: &dyn Watchable, - force: bool, - ) -> Result { - let (script_buf, txid) = script.script_and_txid(); - - let is_first_time = { - let mut history = self.script_history.write().await; - if history.contains_key(&script_buf) { - false - } else { - history.insert(script_buf.clone(), vec![]); - true - } - }; - - if is_first_time { - // Immediately refetch the status of the script - // when we first subscribe to it. - self.update_state_single(script).await?; - } else if force { - // Immediately refetch the status of the script - // when [`force`] is set to true - self.update_state_single(script).await?; - } else { - // Otherwise, don't force a refetch. - self.update_state(false).await?; - } - - let history_guard = self.script_history.read().await; - let history = history_guard.get(&script_buf); - - let history_of_tx: Vec<&GetHistoryRes> = history - .into_iter() - .flatten() - .filter(|entry| entry.tx_hash == txid) - .collect(); - - // Destructure history_of_tx into the last entry and the rest. - let [rest @ .., last] = history_of_tx.as_slice() else { - // If there is no history of the transaction, it is unseen. - return Ok(ScriptStatus::Unseen); - }; - - // There should only be one entry per txid, we will ignore the rest - if !rest.is_empty() { - tracing::warn!(%txid, "Found multiple history entries for the same txid. Ignoring all but the last one."); - } - - let latest_block_height = *self - .latest_block_height - .lock() - .expect("latest_block_height mutex poisoned"); - - match last.height { - // If the height is 0 or less, the transaction is still in the mempool. - ..=0 => Ok(ScriptStatus::InMempool), - // Otherwise, the transaction has been included in a block. - height => Ok(ScriptStatus::Confirmed( - Confirmed::from_inclusion_and_latest_block( - u32::try_from(height)?, - u32::from(latest_block_height), - ), - )), - } - } - - /// Get a transaction from the Electrum server. - /// Fails if the transaction is not found. - pub async fn get_tx(&self, txid: Txid) -> Result>> { - match self - .inner - .call_async_with_multi_error("get_raw_transaction", move |client| { - use bitcoin::consensus::Decodable; - client.inner.transaction_get_raw(&txid).and_then(|raw| { - let mut cursor = std::io::Cursor::new(&raw); - bitcoin::Transaction::consensus_decode(&mut cursor).map_err(|e| { - bdk_electrum::electrum_client::Error::Protocol( - format!("Failed to deserialize transaction: {}", e).into(), - ) - }) - }) - }) - .await - { - Ok(tx) => { - let tx = Arc::new(tx); - // Note: Perhaps it is better to only populate caches of the Electrum nodes - // that accepted our transaction? - self.inner.populate_tx_cache(vec![(*tx).clone()]); - Ok(Some(tx)) - } - Err(multi_error) => { - // Check if any error indicates the transaction doesn't exist - let has_not_found = multi_error.any(|error| { - let error_str = error.to_string(); - - // Check for specific error patterns that indicate "not found" - if error_str.contains("\"code\": Number(-5)") - || error_str.contains("No such mempool or blockchain transaction") - || error_str.contains("missing transaction") - { - return true; - } - - // Also try to parse the RPC error code if possible - let err_anyhow = anyhow::anyhow!(error_str); - if let Ok(error_code) = parse_rpc_error_code(&err_anyhow) { - if error_code == i64::from(RpcErrorCode::RpcInvalidAddressOrKey) { - return true; - } - } - - false - }); - - if has_not_found { - tracing::trace!( - txid = %txid, - error_count = multi_error.len(), - "Transaction not found indicated by one or more Electrum servers" - ); - Ok(None) - } else { - let err = anyhow::anyhow!(multi_error); - Err(err.context("Failed to get transaction from the Electrum server")) - } - } - } - } - - /// Estimate the fee rate to be included in a block at the given offset. - /// Calls: https://electrum-protocol.readthedocs.io/en/latest/protocol-methods.html#blockchain.estimatefee - /// Calls under the hood: https://developer.bitcoin.org/reference/rpc/estimatesmartfee.html - /// - /// This uses estimatesmartfee of bitcoind - pub async fn estimate_fee_rate(&self, target_block: u32) -> Result { - // Get the fee rate in Bitcoin per kilobyte - let btc_per_kvb = self - .inner - .call_async("estimate_fee", move |client| { - client.inner.estimate_fee(target_block as usize) - }) - .await?; - - // If the fee rate is less than 0, return an error - // The Electrum server returns a value <= 0 if it cannot estimate the fee rate. - // See: https://github.com/romanz/electrs/blob/ed0ef2ee22efb45fcf0c7f3876fd746913008de3/src/electrum.rs#L239-L245 - // https://github.com/romanz/electrs/blob/ed0ef2ee22efb45fcf0c7f3876fd746913008de3/src/electrum.rs#L31 - if btc_per_kvb <= 0.0 { - return Err(anyhow!( - "Fee rate returned by Electrum server is less than 0" - )); - } - - // Convert to sat / kB without ever constructing an Amount from the float - // Simply by multiplying the float with the satoshi value of 1 BTC. - // Truncation is allowed here because we are converting to sats and rounding down sats will - // not lose us any precision (because there is no fractional satoshi). - #[allow( - clippy::cast_possible_truncation, - clippy::cast_sign_loss, - clippy::cast_precision_loss - )] - let sats_per_kvb = (btc_per_kvb * Amount::ONE_BTC.to_sat() as f64).ceil() as u64; - - // Convert to sat / kwu (kwu = kB × 4) - let sat_per_kwu = sats_per_kvb / 4; - - // Construct the fee rate - let fee_rate = FeeRate::from_sat_per_kwu(sat_per_kwu); - - Ok(fee_rate) - } - - /// Calculates the fee_rate needed to be included in a block at the given offset. - /// We calculate how many vMB we are away from the tip of the mempool. - /// This method adapts faster to sudden spikes in the mempool. - async fn estimate_fee_rate_from_histogram(&self, target_block: u32) -> Result { - // Assume we want to get into the next block: - // We want to be 80% of the block size away from the tip of the mempool. - const HISTOGRAM_SAFETY_MARGIN: f32 = 0.8; - - // First we fetch the fee histogram from the Electrum server - let fee_histogram = self - .inner - .call_async("get_fee_histogram", move |client| { - client.inner.raw_call("mempool.get_fee_histogram", vec![]) - }) - .await?; - - // Parse the histogram as array of [fee, vsize] pairs - let histogram: Vec<(f64, u64)> = serde_json::from_value(fee_histogram)?; - - // If the histogram is empty, we return an error - if histogram.is_empty() { - return Err(anyhow!( - "The mempool seems to be empty therefore we cannot estimate the fee rate from the histogram" - )); - } - - // Sort the histogram by fee rate - let mut histogram = histogram; - histogram.sort_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); - - // Estimate block size (typically ~1MB = 1,000,000 vbytes) - let estimated_block_size = 1_000_000u64; - #[allow(clippy::cast_precision_loss)] - let target_distance_from_tip = - (estimated_block_size * target_block as u64) as f32 * HISTOGRAM_SAFETY_MARGIN; - - // Find cumulative vsize and corresponding fee rate - let mut cumulative_vsize = 0u64; - for (fee_rate, vsize) in histogram.clone() { - cumulative_vsize += vsize; - #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] - if cumulative_vsize >= target_distance_from_tip as u64 { - #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] - let sat_per_vb = fee_rate.ceil() as u64; - return FeeRate::from_sat_per_vb(sat_per_vb) - .context("Failed to create fee rate from histogram"); - } - } - - // If we get here, the entire mempool is less than the target distance from the tip. - // We return the lowest fee rate in the histogram. - #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] - let sat_per_vb = histogram - .first() - .expect("The histogram should not be empty") - .0 - .ceil() as u64; - FeeRate::from_sat_per_vb(sat_per_vb) - .context("Failed to create fee rate from histogram (all mempool is less than the target distance from the tip)") - } - - /// Get the minimum relay fee rate from the Electrum server. - async fn min_relay_fee(&self) -> Result { - let min_relay_btc_per_kvb = self - .inner - .call_async("relay_fee", |client| client.inner.relay_fee()) - .await?; - - // Convert to sat / kB without ever constructing an Amount from the float - // Simply by multiplying the float with the satoshi value of 1 BTC. - // Truncation is allowed here because we are converting to sats and rounding down sats will - // not lose us any precision (because there is no fractional satoshi). - #[allow( - clippy::cast_possible_truncation, - clippy::cast_sign_loss, - clippy::cast_precision_loss - )] - let sats_per_kvb = (min_relay_btc_per_kvb * Amount::ONE_BTC.to_sat() as f64).ceil() as u64; - - // Convert to sat / kwu (kwu = kB × 4) - let sat_per_kwu = sats_per_kvb / 4; - - // Construct the fee rate - let fee_rate = FeeRate::from_sat_per_kwu(sat_per_kwu); - - Ok(fee_rate) - } -} - #[derive(Clone)] pub struct SyncRequestBuilderFactory { chain_tip: bdk_wallet::chain::CheckPoint, @@ -2203,59 +1710,6 @@ impl BitcoinWallet for Wallet { } } -impl EstimateFeeRate for Client { - async fn estimate_feerate(&self, target_block: u32) -> Result { - // Now that the Electrum client methods are async, we can parallelize the calls - let (electrum_conservative_fee_rate, electrum_histogram_fee_rate) = tokio::join!( - self.estimate_fee_rate(target_block), - self.estimate_fee_rate_from_histogram(target_block) - ); - - match (electrum_conservative_fee_rate, electrum_histogram_fee_rate) { - // If both the histogram and conservative fee rate are successful, we use the higher one - (Ok(electrum_conservative_fee_rate), Ok(electrum_histogram_fee_rate)) => { - tracing::debug!( - electrum_conservative_fee_rate_sat_vb = - electrum_conservative_fee_rate.to_sat_per_vb_ceil(), - electrum_histogram_fee_rate_sat_vb = - electrum_histogram_fee_rate.to_sat_per_vb_ceil(), - "Successfully fetched fee rates from both sources. We will use the higher one" - ); - - Ok(electrum_conservative_fee_rate.max(electrum_histogram_fee_rate)) - } - // If the conservative fee rate fails, we use the histogram fee rate - (Err(electrum_conservative_fee_rate_error), Ok(electrum_histogram_fee_rate)) => { - tracing::warn!( - electrum_conservative_fee_rate_error = ?electrum_conservative_fee_rate_error, - electrum_histogram_fee_rate_sat_vb = electrum_histogram_fee_rate.to_sat_per_vb_ceil(), - "Failed to fetch conservative fee rate, using histogram fee rate" - ); - Ok(electrum_histogram_fee_rate) - } - // If the histogram fee rate fails, we use the conservative fee rate - (Ok(electrum_conservative_fee_rate), Err(electrum_histogram_fee_rate_error)) => { - tracing::warn!( - electrum_histogram_fee_rate_error = ?electrum_histogram_fee_rate_error, - electrum_conservative_fee_rate_sat_vb = electrum_conservative_fee_rate.to_sat_per_vb_ceil(), - "Failed to fetch histogram fee rate, using conservative fee rate" - ); - Ok(electrum_conservative_fee_rate) - } - // If both the histogram and conservative fee rate fail, we return an error - (Err(electrum_conservative_fee_rate_error), Err(electrum_histogram_fee_rate_error)) => { - Err(electrum_conservative_fee_rate_error - .context(electrum_histogram_fee_rate_error) - .context("Failed to fetch both the conservative and histogram fee rates from Electrum")) - } - } - } - - async fn min_relay_fee(&self) -> Result { - Client::min_relay_fee(self).await - } -} - /// Extension trait for our custom concurrent sync implementation. mod sync_ext { use std::collections::HashMap; diff --git a/electrum-pool/Cargo.toml b/electrum-pool/Cargo.toml index 60b2473f6..04ce17939 100644 --- a/electrum-pool/Cargo.toml +++ b/electrum-pool/Cargo.toml @@ -5,10 +5,18 @@ authors = ["eigenwallet Team "] edition = "2024" [dependencies] +anyhow = { workspace = true } backoff = { workspace = true } -bdk_electrum = { workspace = true, features = ["use-rustls-ring"] } bitcoin = { workspace = true } +electrum_streaming_client = { workspace = true } futures = { workspace = true } once_cell = { workspace = true } -tokio = { workspace = true } +rustls-native-certs = { workspace = true } +serde_json = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true, features = ["net", "io-util"] } +tokio-rustls = { workspace = true } tracing = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true, features = ["macros", "rt-multi-thread", "test-util"] } diff --git a/electrum-pool/src/connection.rs b/electrum-pool/src/connection.rs new file mode 100644 index 000000000..43fbdf38f --- /dev/null +++ b/electrum-pool/src/connection.rs @@ -0,0 +1,191 @@ +use std::sync::Arc; +use std::time::Duration; + +use electrum_streaming_client::client::AsyncRequestError; +use electrum_streaming_client::{AsyncClient, RequestExt}; +use once_cell::sync::OnceCell; +use tokio::net::TcpStream; +use tokio_rustls::TlsConnector; +use tokio_rustls::rustls::pki_types::ServerName; +use tokio_rustls::rustls::{ClientConfig, RootCertStore}; + +use crate::Error; + +/// A single live connection to one Electrum server. +/// +/// Wraps an [`AsyncClient`] over a TCP or TLS stream together with the spawned worker task that +/// drives the socket. Requests are issued with [`Connection::request`] and time out after the +/// configured duration; a timed-out or transport-failed request yields an [`Error::Connection`] +/// so the balancer can fail over and reconnect. +pub struct Connection { + url: String, + client: AsyncClient, + worker: tokio::task::JoinHandle>, + request_timeout: Duration, +} + +impl Connection { + /// Connect to the given `url` (`tcp://` or `ssl://`), spawning the client worker on the current + /// tokio runtime. The whole connect (incl. TLS handshake) is bounded by `request_timeout`. + pub async fn connect(url: &str, request_timeout: Duration) -> Result { + let target = ConnectionTarget::parse(url)?; + + let tcp = tokio::time::timeout( + request_timeout, + TcpStream::connect((target.host.as_str(), target.port)), + ) + .await + .map_err(|_| Error::connection(format!("Timed out connecting to {url}")))? + .map_err(|e| Error::connection(dns_hint(url, e)))?; + + let _ = tcp.set_nodelay(true); + + let (client, worker) = if target.use_tls { + let connector = TlsConnector::from(tls_config()); + let server_name = ServerName::try_from(target.host.clone()) + .map_err(|e| Error::connection(format!("Invalid TLS server name for {url}: {e}")))?; + let tls = tokio::time::timeout(request_timeout, connector.connect(server_name, tcp)) + .await + .map_err(|_| Error::connection(format!("Timed out during TLS handshake to {url}")))? + .map_err(|e| Error::connection(format!("TLS handshake failed for {url}: {e}")))?; + let (reader, writer) = tokio::io::split(tls); + spawn_client(reader, writer) + } else { + let (reader, writer) = tcp.into_split(); + spawn_client(reader, writer) + }; + + Ok(Self { + url: url.to_string(), + client, + worker, + request_timeout, + }) + } + + /// Issue a single tracked request and await the typed response, bounded by the request timeout. + pub async fn request(&self, req: Req) -> Result + where + Req: RequestExt + Send + Sync + 'static, + Req::Response: Send, + { + match tokio::time::timeout(self.request_timeout, self.client.send_request(req)).await { + Ok(Ok(resp)) => Ok(resp), + Ok(Err(AsyncRequestError::Response(resp_err))) => Err(Error::response(&resp_err)), + Ok(Err(AsyncRequestError::Canceled)) => { + Err(Error::connection("Request canceled (connection closed)")) + } + Ok(Err(AsyncRequestError::Dispatch(e))) => { + Err(Error::connection(format!("Failed to dispatch request: {e}"))) + } + Err(_elapsed) => Err(Error::connection("Request timed out")), + } + } + + /// The URL this connection was created from. + pub fn url(&self) -> &str { + &self.url + } +} + +impl Drop for Connection { + fn drop(&mut self) { + self.worker.abort(); + } +} + +fn spawn_client(reader: R, writer: W) -> (AsyncClient, tokio::task::JoinHandle>) +where + R: tokio::io::AsyncRead + Send + Unpin + 'static, + W: tokio::io::AsyncWrite + Send + Unpin + 'static, +{ + let (client, mut events, worker) = AsyncClient::new_tokio(reader, writer); + + // We only use request/response (callback-tracked) requests, so no notifications are produced. + // Still drain the event stream so a stray notification can never wedge the worker loop. + tokio::spawn(async move { + use futures::StreamExt; + while events.next().await.is_some() {} + }); + + (client, tokio::spawn(worker)) +} + +struct ConnectionTarget { + host: String, + port: u16, + use_tls: bool, +} + +impl ConnectionTarget { + fn parse(url: &str) -> Result { + let (scheme, rest) = url + .split_once("://") + .ok_or_else(|| Error::connection(format!("Missing scheme in Electrum URL: {url}")))?; + + let use_tls = match scheme { + "tcp" => false, + "ssl" | "tls" => true, + other => { + return Err(Error::connection(format!( + "Unsupported Electrum URL scheme `{other}` in {url}" + ))); + } + }; + + // Strip optional `user:pass@` credentials (only host:port is significant for us). + let host_port = rest.rsplit_once('@').map(|(_, hp)| hp).unwrap_or(rest); + + let (host, port) = host_port + .rsplit_once(':') + .ok_or_else(|| Error::connection(format!("Missing port in Electrum URL: {url}")))?; + + let port: u16 = port + .parse() + .map_err(|_| Error::connection(format!("Invalid port in Electrum URL: {url}")))?; + + if host.is_empty() { + return Err(Error::connection(format!("Empty host in Electrum URL: {url}"))); + } + + Ok(Self { + host: host.to_string(), + port, + use_tls, + }) + } +} + +/// Wrap a connect IO error with the legacy DNS-resolution hint for the failure kinds that most +/// commonly indicate an unresolvable/unreachable host. +fn dns_hint(url: &str, e: std::io::Error) -> String { + use std::io::ErrorKind::*; + match e.kind() { + NotFound | TimedOut | ConnectionRefused | ConnectionAborted | Other => { + format!("{url}: {e} (Most likely DNS resolution error)") + } + _ => format!("{url}: {e}"), + } +} + +fn tls_config() -> Arc { + static CONFIG: OnceCell> = OnceCell::new(); + CONFIG + .get_or_init(|| { + let mut roots = RootCertStore::empty(); + let loaded = rustls_native_certs::load_native_certs(); + for cert in loaded.certs { + let _ = roots.add(cert); + } + + let provider = Arc::new(tokio_rustls::rustls::crypto::ring::default_provider()); + let config = ClientConfig::builder_with_provider(provider) + .with_safe_default_protocol_versions() + .expect("ring provider supports the safe default protocol versions") + .with_root_certificates(roots) + .with_no_client_auth(); + + Arc::new(config) + }) + .clone() +} diff --git a/electrum-pool/src/lib.rs b/electrum-pool/src/lib.rs index f5b46eb22..61a5e9508 100644 --- a/electrum-pool/src/lib.rs +++ b/electrum-pool/src/lib.rs @@ -1,660 +1,375 @@ -use backoff::{Error as BackoffError, ExponentialBackoff}; -use bdk_electrum::BdkElectrumClient; -use bdk_electrum::electrum_client::{Client, ConfigBuilder, ElectrumApi, Error}; -use bitcoin::Transaction; -use futures::future::join_all; -use once_cell::sync::OnceCell; +//! Async, multi-server Electrum client pool built on [`electrum_streaming_client`]. +//! +//! [`ElectrumBalancer`] owns one lazily-established [`Connection`] per server URL and runs +//! operations against them with sticky round-robin failover: it stays on the current server while +//! it succeeds and advances to the next on error, retrying with exponential backoff until every +//! server has been tried at least once (or `min_retries`, whichever is larger). + +mod connection; + +pub use connection::Connection; + +use std::future::Future; +use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, RwLock}; use std::time::Duration; -use std::time::Instant; -use tokio::task::spawn_blocking; -use tracing::{debug, error, instrument, trace, warn}; -/// Round-robin load balancer for Electrum connections. -/// -/// The balancer will try each Electrum node until the provided -/// closure succeeds or all nodes have returned an I/O error. -/// Any non I/O error is immediately returned to the caller. -/// -/// Clients are created lazily on first use to avoid blocking during initialization. -pub struct ElectrumBalancer> -where - C: ElectrumClientLike, -{ - urls: Vec, - #[allow(clippy::type_complexity)] - clients: Arc>>>>>, - next: AtomicUsize, - config: ElectrumBalancerConfig, - factory: Arc + Send + Sync>, +use bitcoin::Transaction; +use electrum_streaming_client::request::BroadcastTx; +use electrum_streaming_client::{RequestExt, response}; +use futures::future::BoxFuture; +use tokio::sync::Mutex; +use tracing::{debug, instrument, trace, warn}; + +/// Error from a single Electrum operation against one server. +#[derive(Debug, Clone)] +pub enum Error { + /// Transport/connection-level failure. The balancer fails over and drops the connection so it + /// is re-established on next use. + Connection(String), + /// The Electrum server returned a JSON-RPC error. Holds the raw error JSON payload as text so + /// that callers (e.g. RPC error-code parsing) can inspect it. + Response(String), } -impl ElectrumBalancer -where - C: ElectrumClientLike, -{ - /// Helper function to get or initialize a client for a given index - fn get_or_init_client_sync(&self, idx: usize) -> Result, Error> { - // We wrap this in a closure to only lock the RwLock for as long as needed - let (client_once_cell, url, config, factory) = { - let clients = self.clients.read().expect("rwlock poisoned").clone(); - - if idx >= clients.len() { - return Err(Error::IOError(std::io::Error::new( - std::io::ErrorKind::InvalidInput, - format!("Index {} out of bounds for {} clients", idx, clients.len()), - ))); - } +impl Error { + pub fn connection(msg: impl Into) -> Self { + Error::Connection(msg.into()) + } - let once_cell = clients[idx].clone(); - let url = self.urls[idx].clone(); - let config = self.config.clone(); - let factory = self.factory.clone(); + /// Build a [`Error::Response`] from the streaming client's server error, extracting the raw + /// JSON payload (its `Display` is `"Response.error: "`). + pub fn response(err: &electrum_streaming_client::ResponseError) -> Self { + let text = err.to_string(); + let json = text + .strip_prefix("Response.error: ") + .unwrap_or(&text) + .to_string(); + Error::Response(json) + } - (once_cell, url, config, factory) - }; + /// Whether this is a transport-level failure that warrants reconnecting the server. + pub fn is_connection(&self) -> bool { + matches!(self, Error::Connection(_)) + } - let client = client_once_cell.get_or_try_init(|| factory.create_client(&url, &config))?; + /// The raw server error JSON payload, if this is a server response error. + pub fn response_json(&self) -> Option<&str> { + match self { + Error::Response(json) => Some(json), + Error::Connection(_) => None, + } + } +} - Ok(client.clone()) +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Error::Connection(msg) => write!(f, "Electrum connection error: {msg}"), + Error::Response(json) => write!(f, "Electrum server error: {json}"), + } } +} - async fn get_or_init_client_async(&self, idx: usize) -> Result, Error> { - let balancer = self.clone(); - spawn_blocking(move || balancer.get_or_init_client_sync(idx)) - .await - .map_err(|e| Error::IOError(std::io::Error::other(e.to_string())))? +impl std::error::Error for Error {} + +/// Factory that establishes a connection to a server URL. +pub trait ConnectionFactory: Send + Sync { + fn connect(&self, url: String, request_timeout: Duration) -> BoxFuture<'static, Result>; +} + +/// Default factory producing real [`Connection`]s. +pub struct DefaultConnectionFactory; + +impl ConnectionFactory for DefaultConnectionFactory { + fn connect( + &self, + url: String, + request_timeout: Duration, + ) -> BoxFuture<'static, Result> { + Box::pin(async move { Connection::connect(&url, request_timeout).await }) } +} - /// Create a new balancer from a list of Electrum URLs with default configuration. - pub async fn new_with_factory( - urls: Vec, - factory: Arc + Send + Sync>, - ) -> Result { - Self::new_with_config_and_factory(urls, ElectrumBalancerConfig::default(), factory).await +/// Configuration for the Electrum balancer. +#[derive(Clone, Debug)] +pub struct ElectrumBalancerConfig { + /// Per-request (and per-connect) timeout. + pub request_timeout: Duration, + /// Minimum number of attempts across all servers before giving up. + pub min_retries: usize, +} + +impl Default for ElectrumBalancerConfig { + fn default() -> Self { + Self { + request_timeout: Duration::from_secs(15), + min_retries: 10, + } } +} - /// Get any client from the balancer - pub async fn get_any_client(&self) -> Result, Error> { - // Try to initialize any client - for idx in 0..self.client_count() { - match self.get_or_init_client_async(idx).await { - Ok(client) => return Ok(client), - Err(e) => { - trace!( - server_url = self.urls[idx], - error = ?e, - "Failed to initialize client, trying next client" - ); - } - } +struct ConnectionSlot { + connection: Mutex>>, +} + +impl ConnectionSlot { + fn new() -> Self { + Self { + connection: Mutex::new(None), } + } +} + +/// Sticky round-robin load balancer over one [`Connection`] per Electrum server. +pub struct ElectrumBalancer { + urls: Vec, + slots: Arc>>, + next: AtomicUsize, + config: ElectrumBalancerConfig, + factory: Arc>, +} + +impl ElectrumBalancer { + /// Create a balancer over the given URLs with default configuration. + pub fn new(urls: Vec) -> Result { + Self::new_with_config(urls, ElectrumBalancerConfig::default()) + } - // Return error if no client could be initialized - Err(Error::IOError(std::io::Error::other( - "No client could be initialized", - ))) + /// Create a balancer over the given URLs with custom configuration. + pub fn new_with_config( + urls: Vec, + config: ElectrumBalancerConfig, + ) -> Result { + Self::new_with_factory(urls, config, Arc::new(DefaultConnectionFactory)) } +} - /// Create a new balancer from a list of Electrum URLs with custom configuration. - /// Clients are initialized lazily on first use. - pub async fn new_with_config_and_factory( +impl ElectrumBalancer +where + C: Send + Sync + 'static, +{ + /// Create a balancer from a connection factory. Connections are established lazily on first use. + pub fn new_with_factory( urls: Vec, config: ElectrumBalancerConfig, - factory: Arc + Send + Sync>, + factory: Arc>, ) -> Result { if urls.is_empty() { - return Err(Error::Protocol("No Electrum URLs provided".into())); + return Err(Error::connection("No Electrum URLs provided")); } debug!( servers = ?urls, server_count = urls.len(), - timeout_seconds = config.request_timeout, + timeout_ms = config.request_timeout.as_millis(), min_retries = config.min_retries, "Initializing Electrum load balancer" ); - // Create OnceCell containers for each URL - clients will be created on first use - let clients: Vec>>> = - urls.iter().map(|_| Arc::new(OnceCell::new())).collect(); + let slots = (0..urls.len()).map(|_| ConnectionSlot::new()).collect(); Ok(Self { urls, - clients: Arc::new(RwLock::new(clients)), + slots: Arc::new(slots), next: AtomicUsize::new(0), config, factory, }) } - /// Get the number of URLs (potential clients) - pub fn client_count(&self) -> usize { + /// The configured server URLs. + pub fn urls(&self) -> &Vec { + &self.urls + } + + /// The number of servers in the pool. + pub fn server_count(&self) -> usize { self.urls.len() } - /// Execute the given closure using one of the Electrum clients asynchronously. - /// - /// If the closure returns an I/O error or certificate error the balancer will try the next - /// node until all nodes have been exhausted. The last encountered error - /// is returned in that case. - #[instrument(level = "debug", skip(self, f), fields(operation = kind, total_urls = self.urls.len(), total_clients = self.client_count()))] - pub async fn call(&self, kind: &str, f: F) -> Result - where - F: Fn(&C) -> Result + Send + Sync + Clone + 'static, - T: Send + 'static, - { - let balancer = self.clone(); - let kind = kind.to_string(); + /// The balancer configuration. + pub fn config(&self) -> &ElectrumBalancerConfig { + &self.config + } - match spawn_blocking(move || balancer.call_sync(&kind, f)).await { - Ok(result) => result.map_err(|multi_error| multi_error.into()), - Err(e) => Err(Error::IOError(std::io::Error::other(e.to_string()))), + async fn get_or_connect(&self, idx: usize) -> Result, Error> { + let mut guard = self.slots[idx].connection.lock().await; + if let Some(connection) = guard.as_ref() { + return Ok(connection.clone()); } - } - /// Execute the given closure using one of the Electrum clients asynchronously. - /// - /// If the closure returns an I/O error or certificate error the balancer will try the next - /// node until all nodes have been exhausted. The last encountered error - /// is returned in that case. - #[instrument(level = "debug", skip(self, f), fields(operation = kind, total_urls = self.urls.len(), total_clients = self.client_count()))] - pub async fn call_async(&self, kind: &str, f: F) -> Result - where - F: Fn(&C) -> Result + Send + Sync + Clone + 'static, - T: Send + 'static, - { - let balancer = self.clone(); - let kind = kind.to_string(); + let connection = self + .factory + .connect(self.urls[idx].clone(), self.config.request_timeout) + .await?; + let connection = Arc::new(connection); + *guard = Some(connection.clone()); + Ok(connection) + } - match spawn_blocking(move || balancer.call_sync(&kind, f)).await { - Ok(result) => result.map_err(|multi_error| multi_error.into()), - Err(e) => Err(Error::IOError(std::io::Error::other(e.to_string()))), - } + async fn invalidate(&self, idx: usize) { + *self.slots[idx].connection.lock().await = None; } - /// Execute the given closure using one of the Electrum clients asynchronously, - /// returning the full MultiError for detailed error analysis. - /// - /// Unlike `call_async`, this method exposes the full MultiError containing all - /// individual failures, allowing the caller to inspect and make decisions based - /// on the specific types of errors encountered. - #[instrument(level = "debug", skip(self, f), fields(operation = kind, total_clients = self.client_count()))] - pub async fn call_async_with_multi_error(&self, kind: &str, f: F) -> Result - where - F: Fn(&C) -> Result + Send + Sync + Clone + 'static, - T: Send + 'static, - { - let balancer = self.clone(); - let kind_string = kind.to_string(); - let kind_for_error = kind.to_string(); - - match spawn_blocking(move || balancer.call_sync(&kind_string, f)).await { - Ok(result) => result, - Err(e) => { - let context = - format!("Failed to spawn blocking task for operation '{kind_for_error}'"); - let error = Error::IOError(std::io::Error::other(e.to_string())); - Err(MultiError::new(vec![error], context)) - } - } + fn advance(&self) { + let count = self.urls.len(); + let _ = self + .next + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| { + Some((current + 1) % count) + }); } - /// Execute the given closure using one of the Electrum clients synchronously. + /// Run the operation against a single server, failing over to the next on error. /// - /// This version blocks for client creation if needed but executes the request synchronously. - /// Used for implementing the ElectrumApi trait. - /// - /// If the closure returns an I/O error or certificate error the balancer will try the next - /// node until all nodes have been exhausted. The last encountered error - /// is returned in that case. - /// - /// Returns `MultiError` containing all individual failures, which can be inspected - /// by the caller or automatically converted to a single `Error` for compatibility. - #[instrument(level = "debug", skip(self, f), fields(operation = kind, total_clients = self.client_count(), min_retries = self.config.min_retries))] - fn call_sync(&self, kind: &str, mut f: F) -> Result + /// Stays on the last successful server (sticky) and advances on failure. Retries with + /// exponential backoff up to `max(min_retries, server_count)` attempts before returning a + /// [`MultiError`] aggregating every failure. + #[instrument(level = "debug", skip(self, f), fields(operation = kind, servers = self.urls.len()))] + pub async fn run(&self, kind: &str, f: F) -> Result where - F: FnMut(&C) -> Result, + F: Fn(Arc) -> Fut + Send + Sync, + Fut: Future> + Send, + T: Send, { - let num_clients = self.client_count(); + let allowed = std::cmp::max(self.config.min_retries, self.urls.len()); let mut errors = Vec::new(); + let mut backoff = Duration::from_millis(100); - // Try all electrum clients at least once, or min_retries (whichever is higher) - let allowed_retries = std::cmp::max(self.config.min_retries, num_clients); - - // Configure exponential backoff - let backoff_policy = ExponentialBackoff { - initial_interval: Duration::from_millis(100), - // 1.5 seconds - max_interval: Duration::from_millis(1500), - // We handle total attempts ourselves - max_elapsed_time: None, - ..ExponentialBackoff::default() - }; - - let operation_with_backoff = || { - if errors.len() >= allowed_retries { - return Err(BackoffError::permanent(())); - } - - // Get current index without incrementing + while errors.len() < allowed { let idx = self.next.load(Ordering::SeqCst); - // Get client for this index - let client = self.get_or_init_client_sync(idx).map_err(|err| { - trace!( - server_url = self.urls[idx], - attempt = errors.len(), - error = ?err, - "Client initialization failed, switching to next client" - ); - - errors.push(err); - - BackoffError::transient(()) - })?; - - // Execute the request synchronously - let start = Instant::now(); - match f(&client) { - Ok(res) => { - trace!( - server_url = self.urls[idx], - attempt = errors.len(), - duration_ms = start.elapsed().as_millis(), - "Electrum operation successful (staying with this client)" - ); - Ok(res) - } - Err(err) => { - trace!( - server_url = self.urls[idx], - attempt = errors.len(), - duration_ms = start.elapsed().as_millis(), - error = ?err, - "Electrum operation failed, switching to next client" - ); - - errors.push(err); - - Err(BackoffError::transient(())) + let connection = match self.get_or_connect(idx).await { + Ok(connection) => connection, + Err(e) => { + trace!(server_url = self.urls[idx], error = %e, "Connection failed, trying next"); + errors.push(e); + self.advance(); + Self::sleep_backoff(&mut backoff).await; + continue; } - } - }; - - // Use backoff::retry for the retry logic with exponential backoff - match backoff::retry_notify( - backoff_policy, - operation_with_backoff, - |_: (), duration: Duration| { - trace!( - backoff_duration_ms = duration.as_millis(), - "Backing off before retry" - ); - - // Advance to next client on failure - self.next - .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| { - Some((current + 1) % num_clients) - }) - .expect("fetch_update should never fail"); - }, - ) { - Ok(result) => Ok(result), - Err(_) => { - warn!( - operation = kind, - attempts = errors.len(), - total_attempts = allowed_retries, - total_clients = self.client_count(), - error_count = errors.len(), - all_errors = ?errors, - "All Electrum clients failed after exhausting retry attempts with backoff" - ); - - let context = format!( - "All {} Electrum clients failed after {} attempts for operation '{}'", - self.client_count(), - errors.len(), - kind - ); - - Err(MultiError::new(errors, context)) - } - } - } - - /// Execute the given closure on **all** Electrum nodes in parallel. - /// - /// The closure is executed in a blocking task for each client. - /// The resulting `Result`s are collected and returned in the same - /// order as the nodes were provided during construction. - #[instrument(level = "debug", skip(self, f), fields(operation = kind, total_clients = self.client_count()))] - pub async fn join_all(&self, kind: &str, f: F) -> Result>, Error> - where - F: Fn(&C) -> Result + Send + Sync + Clone + 'static, - T: Send + 'static, - { - let start_time = Instant::now(); - trace!( - operation = kind, - total_clients = self.client_count(), - "Executing parallel requests on electrum clients" - ); + }; - // Create a task for each potential client - let tasks = { - (0..self.client_count()) - .map(|idx| { - let f = f.clone(); - let balancer = self.clone(); - - tokio::spawn(async move { - match balancer.get_or_init_client_async(idx).await { - Ok(client) => tokio::task::spawn_blocking(move || f(&client)) - .await - .map_err(|e| { - Error::IOError(std::io::Error::other(format!("{e:?}"))) - })?, - Err(e) => Err(e), - } - }) - }) - .collect::>() - }; - - // Spawn the threads and wait until they all finish - let spawn_results = join_all(tasks).await; - - let mut results: Vec> = Vec::new(); - for (task_idx, res) in spawn_results.into_iter().enumerate() { - match res { - Ok(r) => results.push(r), - Err(err) if err.is_cancelled() => { - // We one task is cancelled, we do not continue - // Most likely our function got cancelled - return Err(Error::IOError(std::io::Error::other("Task cancelled"))); - } + match f(connection).await { + Ok(value) => return Ok(value), Err(e) => { - trace!(task_index = task_idx, error = ?e, "Failed to spawn thread for parallel request"); + trace!(server_url = self.urls[idx], error = %e, "Operation failed, trying next"); + if e.is_connection() { + self.invalidate(idx).await; + } + errors.push(e); + self.advance(); + if errors.len() < allowed { + Self::sleep_backoff(&mut backoff).await; + } } } } - let success_count = results.iter().filter(|r| r.is_ok()).count(); - let failure_count = results.len() - success_count; - - // Collect errors for detailed logging - let errors: Vec<(usize, &Error)> = results - .iter() - .enumerate() - .filter_map(|(idx, result)| { - if let Err(e) = result { - Some((idx, e)) - } else { - None - } - }) - .collect(); - - if failure_count > 0 { - trace!( - total_duration_ms = start_time.elapsed().as_millis(), - successful_requests = success_count, - failed_requests = failure_count, - total_requests = results.len(), - errors = ?errors, - "Parallel execution completed with errors" - ); - } else { - trace!( - total_duration_ms = start_time.elapsed().as_millis(), - successful_requests = success_count, - total_requests = results.len(), - "Parallel execution completed successfully" - ); - } - - Ok(results) - } - - /// Broadcast the given transaction to all Electrum nodes in parallel. - /// - /// The method returns a list of results in the same order as the - /// configured nodes. Errors for individual nodes do not abort the - /// others. - #[instrument(level = "debug", skip(self, tx), fields(txid = %tx.compute_txid(), total_clients = self.client_count()))] - pub async fn broadcast_all( - &self, - tx: Transaction, - ) -> Result>, Error> { - let txid = tx.compute_txid(); - let start_time = Instant::now(); - - debug!( - txid = %txid, - total_clients = self.client_count(), - "Broadcasting transaction to electrum clients" + warn!( + operation = kind, + attempts = errors.len(), + servers = self.urls.len(), + "All Electrum servers failed after exhausting retries" ); - let results = self - .join_all("transaction_broadcast", move |client| { - client.transaction_broadcast(&tx) - }) - .await?; - - let success_count = results.iter().filter(|r| r.is_ok()).count(); - - if success_count > 0 { - debug!( - txid = %txid, - successful_broadcasts = success_count, - total_attempts = results.len(), - duration_ms = start_time.elapsed().as_millis(), - "Transaction broadcast completed successfully" - ); - } else { - error!( - txid = %txid, - total_attempts = results.len(), - duration_ms = start_time.elapsed().as_millis(), - "Transaction broadcast failed on all servers" - ); - } - - Ok(results) - } - - /// Get the URLs used by this balancer - pub fn urls(&self) -> &Vec { - &self.urls - } - - /// Get the current configuration - pub fn config(&self) -> &ElectrumBalancerConfig { - &self.config - } - - /// Populate the transaction cache for all initialized clients. - pub fn populate_tx_cache(&self, txs: impl IntoIterator>>) { - // Convert transactions to Arc and collect them since we'll use them for each client - let transactions: Vec> = txs.into_iter().map(|tx| tx.into()).collect(); - let clients = self.clients.read().expect("rwlock poisoned"); - - let mut initialized_count = 0; - - // Only populate cache for already initialized clients - for client_once_cell in clients.iter() { - if let Some(client) = client_once_cell.get() { - client.populate_tx_cache(transactions.iter().cloned()); - initialized_count += 1; + Err(MultiError::new( + errors, + format!( + "All {} Electrum servers failed after {} attempts for operation '{}'", + self.urls.len(), + self.urls.len(), + kind + ), + )) + } + + /// Run the operation against every server concurrently, returning one result per server in URL + /// order. + #[instrument(level = "debug", skip(self, f), fields(operation = kind, servers = self.urls.len()))] + pub async fn join_all(&self, kind: &str, f: F) -> Vec> + where + F: Fn(Arc) -> Fut + Send + Sync, + Fut: Future> + Send, + T: Send, + { + let tasks = (0..self.urls.len()).map(|idx| { + let f = &f; + async move { + let connection = self.get_or_connect(idx).await?; + let result = f(connection).await; + if let Err(e) = &result { + if e.is_connection() { + self.invalidate(idx).await; + } + } + result } - } + }); - trace!( - transaction_count = transactions.len(), - initialized_client_count = initialized_count, - total_client_count = clients.len(), - "Populated transaction cache for initialized clients" - ); + futures::future::join_all(tasks).await } -} -impl Clone for ElectrumBalancer -where - C: ElectrumClientLike, -{ - fn clone(&self) -> Self { - Self { - urls: self.urls.clone(), - clients: self.clients.clone(), - next: AtomicUsize::new(self.next.load(Ordering::SeqCst)), - config: self.config.clone(), - factory: self.factory.clone(), - } + async fn sleep_backoff(backoff: &mut Duration) { + tokio::time::sleep(*backoff).await; + *backoff = std::cmp::min(backoff.mul_f64(1.5), Duration::from_millis(1500)); } } -/// Trait abstracting Electrum client operations needed by the balancer -pub trait ElectrumClientLike: Send + Sync + 'static { - /// Broadcast a transaction - fn transaction_broadcast(&self, tx: &Transaction) -> Result; - - /// Populate transaction cache (only for BdkElectrumClient) - fn populate_tx_cache(&self, _txs: impl Iterator>) { - // Default implementation does nothing - } -} - -impl ElectrumClientLike for BdkElectrumClient { - fn transaction_broadcast(&self, tx: &Transaction) -> Result { - self.inner.transaction_broadcast(tx) +impl ElectrumBalancer { + /// Issue a typed request against a single server with failover. + pub async fn request(&self, kind: &str, req: Req) -> Result + where + Req: RequestExt + Clone + Send + Sync + 'static, + Req::Response: Send, + { + self.run(kind, move |connection| { + let req = req.clone(); + async move { connection.request(req).await } + }) + .await } - fn populate_tx_cache(&self, txs: impl Iterator>) { - BdkElectrumClient::populate_tx_cache(self, txs) + /// Issue a typed request against every server concurrently. + pub async fn request_join_all(&self, kind: &str, req: Req) -> Vec> + where + Req: RequestExt + Clone + Send + Sync + 'static, + Req::Response: Send, + { + self.join_all(kind, move |connection| { + let req = req.clone(); + async move { connection.request(req).await } + }) + .await } -} -/// Configuration for the Electrum balancer -#[derive(Clone, Debug)] -pub struct ElectrumBalancerConfig { - /// Timeout for individual requests in seconds - pub request_timeout: u8, - /// Minimum number of retry attempts across all nodes - pub min_retries: usize, -} - -impl Default for ElectrumBalancerConfig { - fn default() -> Self { - Self { - request_timeout: 15, - min_retries: 10, - } + /// Broadcast a transaction to every server concurrently. Returns one result per server. + pub async fn broadcast_all(&self, tx: Transaction) -> Vec> { + self.request_join_all("transaction_broadcast", BroadcastTx(tx)) + .await } -} - -/// Trait for creating Electrum clients -pub trait ElectrumClientFactory { - fn create_client(&self, url: &str, config: &ElectrumBalancerConfig) -> Result, Error>; -} - -/// Default factory for BdkElectrumClient -pub struct BdkElectrumClientFactory; -impl ElectrumClientFactory> for BdkElectrumClientFactory { - fn create_client( + /// Fetch a single script's history from every server concurrently. + pub async fn script_get_history_all( &self, - url: &str, - config: &ElectrumBalancerConfig, - ) -> Result>, Error> { - let client_config = ConfigBuilder::new() - .timeout(Some(config.request_timeout)) - // TODO: Why is this set to 1? - // The goal of this crate is to extract retry logic out of the electrum client library - // and instead handle inside this crate. However, the electrum client library is quite inflexible. - // - // Setting it to 0 causes some bugs, see: https://github.com/bitcoindevkit/rust-electrum-client/issues/186 - .retry(1) - .build(); - - let client = Client::from_config(url, client_config).map_err(|e| { - // Wrap connection errors with DNS resolution context - match &e { - Error::IOError(io_err) if io_err.kind() == std::io::ErrorKind::NotFound => { - Error::IOError(std::io::Error::new( - std::io::ErrorKind::NotFound, - format!("{e} (Most likely DNS resolution error)"), - )) - } - Error::IOError(io_err) - if io_err.kind() == std::io::ErrorKind::TimedOut - || io_err.kind() == std::io::ErrorKind::ConnectionRefused - || io_err.kind() == std::io::ErrorKind::ConnectionAborted - || io_err.kind() == std::io::ErrorKind::Other => - { - Error::IOError(std::io::Error::new( - io_err.kind(), - format!("{e} (Most likely DNS resolution error)"), - )) - } - _ => e, // Pass through other errors unchanged - } - })?; - let bdk_client = BdkElectrumClient::new(client); - - Ok(Arc::new(bdk_client)) - } -} - -// Convenience methods for the default BdkElectrumClient case -impl ElectrumBalancer> { - /// Create a new balancer from a list of Electrum URLs with default configuration. - /// Uses the default BdkElectrumClientFactory. - pub async fn new(urls: Vec) -> Result { - Self::new_with_factory(urls, Arc::new(BdkElectrumClientFactory)).await - } - - /// Create a new balancer from a list of Electrum URLs with custom configuration. - /// Uses the default BdkElectrumClientFactory. - pub async fn new_with_config( - urls: Vec, - config: ElectrumBalancerConfig, - ) -> Result { - Self::new_with_config_and_factory(urls, config, Arc::new(BdkElectrumClientFactory)).await + script: bitcoin::ScriptBuf, + ) -> Vec, Error>> { + use electrum_streaming_client::request::GetHistory; + self.request_join_all("script_get_history", GetHistory::from_script(script)) + .await } } -/// Type alias for the default Electrum balancer using BdkElectrumClient -pub type DefaultElectrumBalancer = ElectrumBalancer>; - -/// Error type that contains multiple Electrum errors from different nodes. +/// Aggregates the per-server failures of a balancer operation. /// -/// This allows the caller to inspect all individual failures while still -/// working with the `?` operator through automatic conversion to a single Error. -#[derive(Debug)] +/// Part of the public API: consumed by RPC error-code parsing and broadcast result handling. +#[derive(Debug, Clone)] pub struct MultiError { pub errors: Vec, pub context: String, } -impl Clone for MultiError { - fn clone(&self) -> Self { - // Clone by converting each error to a string and back to an error - let cloned_errors = self - .errors - .iter() - .map(|e| Error::IOError(std::io::Error::other(e.to_string()))) - .collect(); - - Self { - errors: cloned_errors, - context: self.context.clone(), - } - } -} - impl MultiError { pub fn new(errors: Vec, context: impl Into) -> Self { Self { @@ -663,22 +378,18 @@ impl MultiError { } } - /// Get the number of errors pub fn len(&self) -> usize { self.errors.len() } - /// Check if there are no errors pub fn is_empty(&self) -> bool { self.errors.is_empty() } - /// Get an iterator over the errors pub fn iter(&self) -> impl Iterator { self.errors.iter() } - /// Check if any error matches a predicate pub fn any(&self, predicate: F) -> bool where F: Fn(&Error) -> bool, @@ -686,23 +397,12 @@ impl MultiError { self.errors.iter().any(predicate) } - /// Check if all errors match a predicate pub fn all(&self, predicate: F) -> bool where F: Fn(&Error) -> bool, { self.errors.iter().all(predicate) } - - /// Convert to a single Error (uses the last error, or creates a generic one) - pub fn into_single_error(self) -> Error { - self.errors.into_iter().next_back().unwrap_or_else(|| { - Error::IOError(std::io::Error::other(format!( - "All operations failed: {}", - self.context - ))) - }) - } } impl std::fmt::Display for MultiError { @@ -715,584 +415,168 @@ impl std::fmt::Display for MultiError { } } -impl std::error::Error for MultiError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - // Return the last error as the source - self.errors.last().and_then(|e| e.source()) - } -} - -impl From for Error { - fn from(multi_error: MultiError) -> Self { - multi_error.into_single_error() - } -} - -// Allow ? operator to work on MultiError by converting to Error -impl From for Result { - fn from(multi_error: MultiError) -> Self { - Err(multi_error.into()) - } -} +impl std::error::Error for MultiError {} #[cfg(test)] mod tests { use super::*; - use bitcoin::hashes::Hash; - use bitcoin::{ - Amount, OutPoint, ScriptBuf, Sequence, TxIn, TxOut, Witness, absolute::LockTime, - transaction::Version, - }; - use std::sync::Mutex as StdMutex; - use std::sync::atomic::{AtomicUsize, Ordering}; - - /// Mock client for testing - #[derive(Debug)] - struct MockElectrumClient { - url: String, - fail_count: Arc, - call_count: Arc, - should_fail: bool, - error_type: MockErrorType, - } + use std::sync::atomic::AtomicUsize; - #[derive(Debug, Clone)] - enum MockErrorType { - IOError, - NonRetryable, + struct MockConnection { + url: String, + calls: Arc, + outcome: MockOutcome, } - impl MockElectrumClient { - fn new(url: String) -> Self { - Self { - url, - fail_count: Arc::new(AtomicUsize::new(0)), - call_count: Arc::new(AtomicUsize::new(0)), - should_fail: false, - error_type: MockErrorType::IOError, - } - } - - fn with_failure(mut self, error_type: MockErrorType) -> Self { - self.should_fail = true; - self.error_type = error_type; - self - } - - fn call_count(&self) -> usize { - self.call_count.load(Ordering::SeqCst) - } + #[derive(Clone, Copy)] + enum MockOutcome { + Ok, + ConnectionError, + ResponseError, } - impl ElectrumClientLike for MockElectrumClient { - fn transaction_broadcast(&self, _tx: &Transaction) -> Result { - self.call_count.fetch_add(1, Ordering::SeqCst); - - if self.should_fail { - self.fail_count.fetch_add(1, Ordering::SeqCst); - match self.error_type { - MockErrorType::IOError => Err(Error::IOError(std::io::Error::new( - std::io::ErrorKind::ConnectionRefused, - format!("Mock connection failed for {}", self.url), - ))), - MockErrorType::NonRetryable => Err(Error::Protocol( - format!( - "\"code\": Number(-5) - transaction not found on {}", - self.url - ) - .into(), - )), + impl MockConnection { + async fn call(self: Arc) -> Result { + self.calls.fetch_add(1, Ordering::SeqCst); + match self.outcome { + MockOutcome::Ok => Ok(self.url.clone()), + MockOutcome::ConnectionError => Err(Error::connection(format!("io {}", self.url))), + MockOutcome::ResponseError => { + Err(Error::Response(format!("{{\"code\":-5}} {}", self.url))) } - } else { - Ok(bitcoin::Txid::from_raw_hash( - bitcoin::hashes::sha256d::Hash::from_byte_array([1; 32]), - )) } } } - /// Mock factory for creating test clients - struct MockElectrumClientFactory { - clients: Arc>>>, + struct MockFactory { + outcomes: std::collections::HashMap, + calls: std::sync::Mutex>>, } - impl MockElectrumClientFactory { - fn new() -> Self { - Self { - clients: Arc::new(StdMutex::new(Vec::new())), - } - } - - fn add_client(&self, client: MockElectrumClient) { - self.clients.lock().unwrap().push(Arc::new(client)); + impl MockFactory { + fn new(outcomes: Vec<(&str, MockOutcome)>) -> Arc { + Arc::new(Self { + outcomes: outcomes + .into_iter() + .map(|(u, o)| (u.to_string(), o)) + .collect(), + calls: std::sync::Mutex::new(std::collections::HashMap::new()), + }) } - fn get_client(&self, idx: usize) -> Option> { - self.clients.lock().unwrap().get(idx).cloned() + fn call_count(&self, url: &str) -> usize { + self.calls + .lock() + .unwrap() + .get(url) + .map(|c| c.load(Ordering::SeqCst)) + .unwrap_or(0) } } - impl ElectrumClientFactory for MockElectrumClientFactory { - fn create_client( + impl ConnectionFactory for MockFactory { + fn connect( &self, - url: &str, - _config: &ElectrumBalancerConfig, - ) -> Result, Error> { - let clients = self.clients.lock().unwrap(); - for client in clients.iter() { - if client.url == url { - return Ok(client.clone()); - } - } - - // If no pre-configured client found, create a default one - Ok(Arc::new(MockElectrumClient::new(url.to_string()))) + url: String, + _request_timeout: Duration, + ) -> BoxFuture<'static, Result> { + let outcome = self.outcomes.get(&url).copied().unwrap_or(MockOutcome::Ok); + let calls = self + .calls + .lock() + .unwrap() + .entry(url.clone()) + .or_insert_with(|| Arc::new(AtomicUsize::new(0))) + .clone(); + Box::pin(async move { Ok(MockConnection { url, calls, outcome }) }) } } - fn create_dummy_transaction() -> Transaction { - Transaction { - version: Version::TWO, - lock_time: LockTime::ZERO, - input: vec![TxIn { - previous_output: OutPoint::null(), - script_sig: ScriptBuf::new(), - sequence: Sequence::ENABLE_RBF_NO_LOCKTIME, - witness: Witness::new(), - }], - output: vec![TxOut { - value: Amount::from_sat(1000), - script_pubkey: ScriptBuf::new(), - }], + fn fast_config() -> ElectrumBalancerConfig { + ElectrumBalancerConfig { + request_timeout: Duration::from_secs(1), + min_retries: 0, } } - #[tokio::test] - async fn test_balancer_creation() { - let urls = vec![ - "tcp://localhost:50001".to_string(), - "tcp://localhost:50002".to_string(), - ]; - - let factory = Arc::new(MockElectrumClientFactory::new()); - let balancer = ElectrumBalancer::new_with_factory(urls.clone(), factory).await; - - assert!(balancer.is_ok()); - let balancer = balancer.unwrap(); - assert_eq!(balancer.client_count(), 2); - assert_eq!(balancer.urls(), &urls); - } - - #[tokio::test] - async fn test_balancer_empty_urls() { - let factory = Arc::new(MockElectrumClientFactory::new()); - let balancer = ElectrumBalancer::new_with_factory(vec![], factory).await; - + #[tokio::test(start_paused = true)] + async fn empty_urls_is_error() { + let factory = MockFactory::new(vec![]); + let balancer = ElectrumBalancer::new_with_factory(vec![], fast_config(), factory); assert!(balancer.is_err()); - match balancer { - Err(e) => assert!(e.to_string().contains("No Electrum URLs provided")), - Ok(_) => panic!("Expected error but got Ok"), - } } - #[tokio::test] - async fn test_call_sticky_behavior() { - let urls = vec![ - "tcp://localhost:50001".to_string(), - "tcp://localhost:50002".to_string(), - "tcp://localhost:50003".to_string(), - ]; - - let factory = Arc::new(MockElectrumClientFactory::new()); - for url in &urls { - factory.add_client(MockElectrumClient::new(url.clone())); - } - - let balancer = ElectrumBalancer::new_with_factory(urls, factory.clone()) - .await - .unwrap(); - - // Make several successful calls and verify sticky behavior (should stay on first client) - for _ in 0..6 { - let result = balancer - .call("test", |client| { - client.transaction_broadcast(&create_dummy_transaction()) - }) - .await; + #[tokio::test(start_paused = true)] + async fn sticky_stays_on_first_server() { + let urls = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + let factory = MockFactory::new(vec![ + ("a", MockOutcome::Ok), + ("b", MockOutcome::Ok), + ("c", MockOutcome::Ok), + ]); + let balancer = + ElectrumBalancer::new_with_factory(urls, fast_config(), factory.clone()).unwrap(); + for _ in 0..5 { + let result = balancer.run("test", |c| async move { c.call().await }).await; assert!(result.is_ok()); } - // Verify only the first client was used - assert_eq!(factory.get_client(0).unwrap().call_count(), 6); - assert_eq!(factory.get_client(1).unwrap().call_count(), 0); - assert_eq!(factory.get_client(2).unwrap().call_count(), 0); - } - - #[tokio::test] - async fn test_call_switches_on_failure() { - let urls = vec![ - "tcp://localhost:50001".to_string(), - "tcp://localhost:50002".to_string(), - "tcp://localhost:50003".to_string(), - ]; - - let factory = Arc::new(MockElectrumClientFactory::new()); - // First client fails, second succeeds, third not used - factory.add_client( - MockElectrumClient::new(urls[0].clone()).with_failure(MockErrorType::IOError), - ); - factory.add_client(MockElectrumClient::new(urls[1].clone())); - factory.add_client(MockElectrumClient::new(urls[2].clone())); - - // Use config with min_retries = 0 to test basic switching behavior - // This ensures total_attempts = max(0, 3) = 3, but behavior is cleaner - let config = ElectrumBalancerConfig { - request_timeout: 5, - min_retries: 0, - }; - - let balancer = ElectrumBalancer::new_with_config_and_factory(urls, config, factory.clone()) - .await - .unwrap(); - - // First call should try client 0 (fails), then client 1 (succeeds) - let result1 = balancer - .call("test", |client| { - client.transaction_broadcast(&create_dummy_transaction()) - }) - .await; - assert!(result1.is_ok()); - - // Second call should also try client 0 first (fails), then client 1 (succeeds) - let result2 = balancer - .call("test", |client| { - client.transaction_broadcast(&create_dummy_transaction()) - }) - .await; - assert!(result2.is_ok()); - - // Verify call counts: - // Both calls try client 0 first (fails both times), then client 1 (succeeds both times) - assert_eq!(factory.get_client(0).unwrap().call_count(), 2); // Called on both attempts - assert_eq!(factory.get_client(1).unwrap().call_count(), 2); // Called on both attempts after client 0 fails - assert_eq!(factory.get_client(2).unwrap().call_count(), 0); // Never called - } - - #[tokio::test] - async fn test_call_with_failing_client() { - let urls = vec![ - "tcp://localhost:50001".to_string(), - "tcp://localhost:50002".to_string(), - ]; - - let factory = Arc::new(MockElectrumClientFactory::new()); - // First client fails, second succeeds - factory.add_client( - MockElectrumClient::new(urls[0].clone()).with_failure(MockErrorType::IOError), - ); - factory.add_client(MockElectrumClient::new(urls[1].clone())); - - let balancer = ElectrumBalancer::new_with_factory(urls, factory.clone()) - .await - .unwrap(); - - let result = balancer - .call("test", |client| { - client.transaction_broadcast(&create_dummy_transaction()) - }) - .await; - - assert!(result.is_ok()); - - // Verify the failing client was called once and the successful client was called once - assert_eq!(factory.get_client(0).unwrap().call_count(), 1); - assert_eq!(factory.get_client(1).unwrap().call_count(), 1); + assert_eq!(factory.call_count("a"), 5); + assert_eq!(factory.call_count("b"), 0); + assert_eq!(factory.call_count("c"), 0); } - #[tokio::test] - async fn test_call_with_non_retryable_error() { - let urls = vec!["tcp://localhost:50001".to_string()]; - - let factory = Arc::new(MockElectrumClientFactory::new()); - factory.add_client( - MockElectrumClient::new(urls[0].clone()).with_failure(MockErrorType::NonRetryable), - ); - - // Use a config with min_retries = 1 to test non-retryable behavior - let config = ElectrumBalancerConfig { - request_timeout: 5, - min_retries: 1, - }; - - let balancer = ElectrumBalancer::new_with_config_and_factory(urls, config, factory.clone()) - .await - .unwrap(); - - let result = balancer - .call("test", |client| { - client.transaction_broadcast(&create_dummy_transaction()) - }) - .await; - - assert!(result.is_err()); - match result { - Err(e) => assert!(e.to_string().contains("transaction not found")), - Ok(_) => panic!("Expected error but got Ok"), - } - - // Should only be called once (no retry for non-retryable errors) - assert_eq!(factory.get_client(0).unwrap().call_count(), 1); - } - - #[tokio::test] - async fn test_call_all_clients_fail() { - let urls = vec![ - "tcp://localhost:50001".to_string(), - "tcp://localhost:50002".to_string(), - ]; - - let factory = Arc::new(MockElectrumClientFactory::new()); - factory.add_client( - MockElectrumClient::new(urls[0].clone()).with_failure(MockErrorType::IOError), - ); - factory.add_client( - MockElectrumClient::new(urls[1].clone()).with_failure(MockErrorType::IOError), - ); - - let balancer = ElectrumBalancer::new_with_factory(urls, factory.clone()) - .await - .unwrap(); - - let result = balancer - .call("test", |client| { - client.transaction_broadcast(&create_dummy_transaction()) - }) - .await; - - assert!(result.is_err()); - match result { - Err(e) => { - let error_msg = e.to_string(); - println!("Error message: {error_msg}"); - assert!( - error_msg.contains("All Electrum nodes failed") - || error_msg.contains("Mock connection failed") - ); - } - Ok(_) => panic!("Expected error but got Ok"), - } + #[tokio::test(start_paused = true)] + async fn fails_over_on_error() { + let urls = vec!["a".to_string(), "b".to_string()]; + let factory = MockFactory::new(vec![ + ("a", MockOutcome::ConnectionError), + ("b", MockOutcome::Ok), + ]); + let balancer = + ElectrumBalancer::new_with_factory(urls, fast_config(), factory.clone()).unwrap(); - // Both clients should have been tried multiple times due to min_retries - assert!(factory.get_client(0).unwrap().call_count() > 1); - assert!(factory.get_client(1).unwrap().call_count() > 1); + let result = balancer.run("test", |c| async move { c.call().await }).await; + assert_eq!(result.unwrap(), "b"); + assert_eq!(factory.call_count("a"), 1); + assert_eq!(factory.call_count("b"), 1); } - #[tokio::test] - async fn test_join_all() { - let urls = vec![ - "tcp://localhost:50001".to_string(), - "tcp://localhost:50002".to_string(), - "tcp://localhost:50003".to_string(), - ]; - - let factory = Arc::new(MockElectrumClientFactory::new()); - factory.add_client(MockElectrumClient::new(urls[0].clone())); - factory.add_client( - MockElectrumClient::new(urls[1].clone()).with_failure(MockErrorType::IOError), - ); - factory.add_client(MockElectrumClient::new(urls[2].clone())); - - let balancer = ElectrumBalancer::new_with_factory(urls, factory.clone()) - .await - .unwrap(); - - let results = balancer - .join_all("transaction_broadcast", |client| { - client.transaction_broadcast(&create_dummy_transaction()) - }) - .await; + #[tokio::test(start_paused = true)] + async fn all_fail_yields_multi_error() { + let urls = vec!["a".to_string(), "b".to_string()]; + let factory = MockFactory::new(vec![ + ("a", MockOutcome::ResponseError), + ("b", MockOutcome::ResponseError), + ]); + let balancer = + ElectrumBalancer::new_with_factory(urls, fast_config(), factory).unwrap(); + + let result = balancer.run("test", |c| async move { c.call().await }).await; + let err = result.unwrap_err(); + assert!(err.len() >= 2); + assert!(err.any(|e| e.response_json().is_some_and(|j| j.contains("-5")))); + } + + #[tokio::test(start_paused = true)] + async fn join_all_hits_every_server() { + let urls = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + let factory = MockFactory::new(vec![ + ("a", MockOutcome::Ok), + ("b", MockOutcome::ConnectionError), + ("c", MockOutcome::Ok), + ]); + let balancer = + ElectrumBalancer::new_with_factory(urls, fast_config(), factory.clone()).unwrap(); - assert!(results.is_ok()); - let results = results.unwrap(); + let results = balancer.join_all("test", |c| async move { c.call().await }).await; assert_eq!(results.len(), 3); - - // First and third should succeed, second should fail assert!(results[0].is_ok()); assert!(results[1].is_err()); assert!(results[2].is_ok()); - - // All clients should have been called - assert_eq!(factory.get_client(0).unwrap().call_count(), 1); - assert_eq!(factory.get_client(1).unwrap().call_count(), 1); - assert_eq!(factory.get_client(2).unwrap().call_count(), 1); - } - - #[tokio::test] - async fn test_broadcast_all() { - let urls = vec![ - "tcp://localhost:50001".to_string(), - "tcp://localhost:50002".to_string(), - ]; - - let factory = Arc::new(MockElectrumClientFactory::new()); - factory.add_client(MockElectrumClient::new(urls[0].clone())); - factory.add_client(MockElectrumClient::new(urls[1].clone())); - - let balancer = ElectrumBalancer::new_with_factory(urls, factory.clone()) - .await - .unwrap(); - - let tx = create_dummy_transaction(); - let results = balancer.broadcast_all(tx).await; - - assert!(results.is_ok()); - let results = results.unwrap(); - assert_eq!(results.len(), 2); - - // Both should succeed - assert!(results[0].is_ok()); - assert!(results[1].is_ok()); - - // Both clients should have been called - assert_eq!(factory.get_client(0).unwrap().call_count(), 1); - assert_eq!(factory.get_client(1).unwrap().call_count(), 1); - } - - #[tokio::test] - async fn test_config_and_urls_accessors() { - let urls = vec!["tcp://localhost:50001".to_string()]; - let config = ElectrumBalancerConfig { - request_timeout: 15, - min_retries: 7, - }; - - let factory = Arc::new(MockElectrumClientFactory::new()); - let balancer = - ElectrumBalancer::new_with_config_and_factory(urls.clone(), config.clone(), factory) - .await - .unwrap(); - - assert_eq!(balancer.urls(), &urls); - assert_eq!(balancer.config().request_timeout, 15); - assert_eq!(balancer.config().min_retries, 7); - } - - #[tokio::test] - async fn test_populate_tx_cache() { - let urls = vec!["tcp://localhost:50001".to_string()]; - - let factory = Arc::new(MockElectrumClientFactory::new()); - factory.add_client(MockElectrumClient::new(urls[0].clone())); - - let balancer = ElectrumBalancer::new_with_factory(urls, factory.clone()) - .await - .unwrap(); - - // Initialize the client first - let _ = balancer.call("test", |client| Ok(client.url.clone())).await; - - // This should not panic (MockElectrumClient has default implementation) - let txs = vec![create_dummy_transaction()]; - balancer.populate_tx_cache(txs); - } - - #[tokio::test] - async fn test_multi_error_functionality() { - let urls = vec![ - "tcp://localhost:50001".to_string(), - "tcp://localhost:50002".to_string(), - "tcp://localhost:50003".to_string(), - ]; - - let factory = Arc::new(MockElectrumClientFactory::new()); - factory.add_client( - MockElectrumClient::new(urls[0].clone()).with_failure(MockErrorType::IOError), - ); - factory.add_client( - MockElectrumClient::new(urls[1].clone()).with_failure(MockErrorType::NonRetryable), - ); - factory.add_client( - MockElectrumClient::new(urls[2].clone()).with_failure(MockErrorType::IOError), - ); - - let balancer = ElectrumBalancer::new_with_factory(urls, factory.clone()) - .await - .unwrap(); - - // Use call_async_with_multi_error to get the MultiError - let result = balancer - .call_async_with_multi_error("test", |client| { - client.transaction_broadcast(&create_dummy_transaction()) - }) - .await; - - assert!(result.is_err()); - let multi_error = result.unwrap_err(); - - // Check that we have multiple errors - assert!(multi_error.len() > 1); - assert!(!multi_error.is_empty()); - - // Check that we can inspect individual errors - let error_count = multi_error.errors.len(); - assert!(error_count > 0); - - // Test the `any` method to find specific error types - let has_non_retryable = - multi_error.any(|e| e.to_string().contains("transaction not found")); - assert!(has_non_retryable); - - // Test converting to single error (should work with ?) - let single_error: Error = multi_error.clone().into(); - assert!(!single_error.to_string().is_empty()); - - // Test that the ? operator works - fn test_question_mark(multi_error: MultiError) -> Result<(), Error> { - Err(multi_error)? - } - - let result = test_question_mark(multi_error); - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_call_async_with_multi_error() { - let urls = vec![ - "tcp://localhost:50001".to_string(), - "tcp://localhost:50002".to_string(), - ]; - - let factory = Arc::new(MockElectrumClientFactory::new()); - factory.add_client( - MockElectrumClient::new(urls[0].clone()).with_failure(MockErrorType::NonRetryable), - ); - factory.add_client( - MockElectrumClient::new(urls[1].clone()).with_failure(MockErrorType::IOError), - ); - - let balancer = ElectrumBalancer::new_with_factory(urls, factory.clone()) - .await - .unwrap(); - - let result = balancer - .call_async_with_multi_error("test", |client| { - client.transaction_broadcast(&create_dummy_transaction()) - }) - .await; - - assert!(result.is_err()); - let multi_error = result.unwrap_err(); - - // Should have multiple errors due to retries (min_retries = 5, with 2 clients) - assert!(multi_error.len() > 2); - - // Check that there are "transaction not found" type errors - let has_not_found = multi_error.any(|e| e.to_string().contains("transaction not found")); - assert!(has_not_found); - - // And I/O errors - let has_io_error = multi_error.any(|e| e.to_string().contains("Mock connection failed")); - assert!(has_io_error); + assert_eq!(factory.call_count("a"), 1); + assert_eq!(factory.call_count("b"), 1); + assert_eq!(factory.call_count("c"), 1); } }