From b2f2e70913676a10a4a999d9b028202d3d34f87f Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 19 Feb 2026 14:21:41 -0500 Subject: [PATCH 1/3] Pull out encryption and shared code --- .cargo/config.toml | 2 + .clippy.toml | 2 + .github/workflows/ci.yml | 16 +- .rustfmt.toml | 5 + Cargo.toml | 42 +- README.md | 4 +- benches/pipe.rs | 70 ++-- benches/throughput.rs | 108 ++--- examples-nodejs/run.js | 3 +- examples/replication.rs | 102 ++--- src/builder.rs | 44 -- src/channels.rs | 68 +-- src/constants.rs | 7 - src/crypto/cipher.rs | 186 --------- src/crypto/curve.rs | 101 ----- src/crypto/handshake.rs | 216 +++------- src/crypto/mod.rs | 8 +- src/duplex.rs | 65 --- src/error.rs | 25 ++ src/lib.rs | 122 +----- src/message.rs | 870 ++++++++++++++++++--------------------- src/mqueue.rs | 150 +++++++ src/protocol.rs | 467 +++++++-------------- src/reader.rs | 231 ----------- src/schema.rs | 655 +++++++++++++++-------------- src/stream.rs | 108 +++++ src/test_utils.rs | 204 +++++++++ src/util.rs | 40 +- src/writer.rs | 173 -------- tests/_util.rs | 100 ++--- tests/basic.rs | 74 ++-- tests/js/.gitignore | 2 + tests/js/mod.rs | 37 +- tests/js/package.json | 3 +- tests/js_interop.rs | 239 +++++------ 35 files changed, 1866 insertions(+), 2683 deletions(-) create mode 100644 .cargo/config.toml create mode 100644 .clippy.toml create mode 100644 .rustfmt.toml delete mode 100644 src/builder.rs delete mode 100644 src/crypto/cipher.rs delete mode 100644 src/crypto/curve.rs delete mode 100644 src/duplex.rs create mode 100644 src/error.rs create mode 100644 src/mqueue.rs delete mode 100644 src/reader.rs create mode 100644 src/stream.rs create mode 100644 src/test_utils.rs delete mode 100644 src/writer.rs create mode 100644 tests/js/.gitignore diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..2e07606 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[target.wasm32-unknown-unknown] +rustflags = ['--cfg', 'getrandom_backend="wasm_js"'] diff --git a/.clippy.toml b/.clippy.toml new file mode 100644 index 0000000..b6bb929 --- /dev/null +++ b/.clippy.toml @@ -0,0 +1,2 @@ +# We want to know ASAP if we have a needless_pass_by_mut or needless_pass_by_value, etc added to the public API +avoid-breaking-exported-api = false diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5b58d59..f8589d1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,11 +31,9 @@ jobs: - name: Run tests run: | cargo check --all-targets - cargo check --all-targets --no-default-features --features tokio - cargo check --all-targets --no-default-features --features async-std - cargo test --features js_interop_tests - cargo test --no-default-features --features js_interop_tests,tokio - cargo test --no-default-features --features js_interop_tests,async-std + cargo check --all-targets --no-default-features + cargo test --features js_tests + cargo test --no-default-features --features js_tests cargo test --benches build-extra: @@ -48,15 +46,13 @@ jobs: targets: wasm32-unknown-unknown - name: Build WASM run: | - cargo build --target=wasm32-unknown-unknown --no-default-features --features wasm-bindgen,tokio - cargo build --target=wasm32-unknown-unknown --no-default-features --features wasm-bindgen,async-std + cargo build --target=wasm32-unknown-unknown --no-default-features --features wasm-bindgen - name: Build release run: | - cargo build --release --no-default-features --features tokio - cargo build --release --no-default-features --features async-std + cargo build --release --no-default-features - name: Build examples run: | - cargo build --example replication + cargo build --example replication lint: runs-on: ubuntu-latest diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 0000000..adbe5db --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1,5 @@ +# groups 'use' statements by crate +imports_granularity = "crate" +# formats code within doc tests +# requires: cargo +nightly fmt (otherwise rustfmt will warn, but pass) +format_code_in_doc_comments = true diff --git a/Cargo.toml b/Cargo.toml index d77679f..1a88c75 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ authors = [ documentation = "https://docs.rs/hypercore-protocol" repository = "https://github.com/datrs/hypercore-protocol-rs" readme = "README.md" -edition = "2021" +edition = "2024" keywords = ["dat", "p2p", "replication", "hypercore", "protocol"] categories = [ "asynchronous", @@ -26,12 +26,10 @@ bench = false [dependencies] async-channel = "1" -snow = { version = "0.9", features = ["risky-raw-split"] } -bytes = "1" +snow = { version = "0.10", features = ["risky-raw-split"] } rand = "0.8" blake2 = "0.10" hex = "0.4" -async-trait = "0.1" tracing = "0.1" pretty-hash = "0.4" futures-timer = "3" @@ -39,17 +37,22 @@ futures-lite = "1" sha2 = "0.10" curve25519-dalek = "4" crypto_secretstream = "0.2" +futures = "0.3.31" +compact-encoding = "2" +thiserror = "2.0.12" +hypercore_handshake = "0.6.0" -[dependencies.hypercore] -version = "0.14.0" -default-features = false +[dev-dependencies.hypercore] +features = ["shared-core"] +version = "0.16.0" +[dependencies.hypercore_schema] +version = "0.2.0" [dev-dependencies] async-std = { version = "1.12.0", features = ["attributes", "unstable"] } async-compat = "0.2.1" tokio = { version = "1.27.0", features = ["macros", "net", "process", "rt", "rt-multi-thread", "sync", "time"] } -env_logger = "0.7.1" anyhow = "1.0.28" instant = "0.1" criterion = { version = "0.4", features = ["async_std"] } @@ -57,23 +60,26 @@ pretty-bytes = "0.2.2" duplexify = "1.1.0" sluice = "0.5.4" futures = "0.3.13" -log = "0.4" -test-log = { version = "0.2.11", default-features = false, features = ["trace"] } -tracing-subscriber = { version = "0.3.16", features = ["env-filter", "fmt"] } +tracing-subscriber = { version = "0.3.19", features = ["env-filter", "fmt"] } +tracing-tree = "0.4.0" +tokio-util = { version = "0.7.14", features = ["compat"] } +uint24le_framing = { version = "0.2.0" } + +[dev-dependencies.rusty_nodejs_repl] +version = "0.4.0" +features = ["serde", "integration_utils"] [features] -default = ["tokio", "sparse"] wasm-bindgen = [ "futures-timer/wasm-bindgen" ] -sparse = ["hypercore/sparse"] -cache = ["hypercore/cache"] -tokio = ["hypercore/tokio"] -async-std = ["hypercore/async-std"] # Used only in interoperability tests under tests/js-interop which use the javascript version of hypercore # to verify that this crate works. To run them, use: -# cargo test --features js_interop_tests -js_interop_tests = [] +# cargo test --features js_tests +js_tests = [] + +[target.'cfg(target_arch = "wasm32")'.dependencies] +getrandom = { version = "0.3", features = ["wasm_js"] } [profile.bench] # debug = true diff --git a/README.md b/README.md index b8ed180..fada9df 100644 --- a/README.md +++ b/README.md @@ -72,10 +72,10 @@ node examples-nodejs/run.js node ## Development -To test interoperability with Javascript, enable the `js_interop_tests` feature: +To test interoperability with Javascript, enable the `js_tests` feature: ```bash -cargo test --features js_interop_tests +cargo test --features js_tests ``` Run benches with: diff --git a/benches/pipe.rs b/benches/pipe.rs index 630146c..5ab2755 100644 --- a/benches/pipe.rs +++ b/benches/pipe.rs @@ -1,24 +1,27 @@ +#[path = "../src/test_utils.rs"] +mod test_utils; use async_std::task; -use criterion::{criterion_group, criterion_main, Criterion, Throughput}; -use futures::io::{AsyncRead, AsyncWrite}; +#[path = "../tests/_util.rs"] +mod _util; +use criterion::{Criterion, Throughput, criterion_group, criterion_main}; use futures::stream::StreamExt; -use hypercore_protocol::{schema::*, Duplex}; -use hypercore_protocol::{Channel, Event, Message, Protocol, ProtocolBuilder}; -use log::*; +use hypercore_protocol::{Channel, Event, Message, Protocol, schema::*}; +use hypercore_schema::DataBlock; use pretty_bytes::converter::convert as pretty_bytes; -use sluice::pipe::pipe; -use std::io::Result; -use std::time::Instant; +use std::{io::Result, time::Instant}; +use tracing::{debug, error}; + +use crate::_util::create_pair; const COUNT: u64 = 1000; const SIZE: u64 = 100; const CONNS: u64 = 10; fn bench_throughput(c: &mut Criterion) { - env_logger::from_env(env_logger::Env::default().default_filter_or("error")).init(); + test_utils::log(); let mut group = c.benchmark_group("pipe"); group.sample_size(10); - group.throughput(Throughput::Bytes(SIZE * COUNT * CONNS as u64)); + group.throughput(Throughput::Bytes(SIZE * COUNT * CONNS)); group.bench_function("pipe_echo", |b| { b.iter(|| { task::block_on(async move { @@ -38,17 +41,7 @@ criterion_group!(benches, bench_throughput); criterion_main!(benches); async fn run_echo(i: u64) -> Result<()> { - // let cap: usize = SIZE as usize * 10; - let (ar, bw) = pipe(); - let (br, aw) = pipe(); - - let encrypted = true; - let a = ProtocolBuilder::new(true) - .encrypted(encrypted) - .connect_rw(ar, aw); - let b = ProtocolBuilder::new(false) - .encrypted(encrypted) - .connect_rw(br, bw); + let (a, b) = create_pair(); let ta = task::spawn(async move { onconnection(i, a).await }); let tb = task::spawn(async move { onconnection(i, b).await }); ta.await?; @@ -58,11 +51,7 @@ async fn run_echo(i: u64) -> Result<()> { // The onconnection handler is called for each incoming connection (if server) // or once when connected (if client). -async fn onconnection(i: u64, mut protocol: Protocol>) -> Result -where - R: AsyncRead + Send + Unpin + 'static, - W: AsyncWrite + Send + Unpin + 'static, -{ +async fn onconnection(i: u64, mut protocol: Protocol) -> Result { let key = [0u8; 32]; let is_initiator = protocol.is_initiator(); // let mut len: u64 = 0; @@ -72,7 +61,7 @@ where debug!("[{}] EVENT {:?}", is_initiator, event); match event { Event::Handshake(_) => { - protocol.open(key.clone()).await?; + protocol.open(key).await?; } Event::DiscoveryKey(_dkey) => {} Event::Channel(channel) => { @@ -92,7 +81,7 @@ where } Some(Err(err)) => { error!("ERROR {:?}", err); - return Err(err.into()); + return Err(err); } None => return Ok(0), } @@ -127,20 +116,17 @@ async fn on_channel_init(i: u64, mut channel: Channel) -> Result { let start = std::time::Instant::now(); while let Some(message) = channel.next().await { - match message { - Message::Data(mut data) => { - len += value_len(&data); - debug!("[a] recv {}", index(&data)); - if index(&data) >= COUNT { - debug!("close at {}", index(&data)); - channel.close().await?; - break; - } else { - increment_index(&mut data); - channel.send(Message::Data(data)).await?; - } + if let Message::Data(mut data) = message { + len += value_len(&data); + debug!("[a] recv {}", index(&data)); + if index(&data) >= COUNT { + debug!("close at {}", index(&data)); + channel.close().await?; + break; + } else { + increment_index(&mut data); + channel.send(Message::Data(data)).await?; } - _ => {} } } // let bytes = (COUNT * SIZE) as f64; @@ -149,8 +135,6 @@ async fn on_channel_init(i: u64, mut channel: Channel) -> Result { } fn msg_data(index: u64, value: Vec) -> Message { - use hypercore::DataBlock; - Message::Data(Data { request: index, fork: 0, diff --git a/benches/throughput.rs b/benches/throughput.rs index 76d6874..3b1f3e9 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -1,13 +1,23 @@ -use async_std::net::{Shutdown, TcpListener, TcpStream}; -use async_std::task; -use criterion::{criterion_group, criterion_main, Criterion, Throughput}; -use futures::future::Either; -use futures::io::{AsyncRead, AsyncWrite}; -use futures::stream::{FuturesUnordered, StreamExt}; -use hypercore_protocol::{schema::*, Duplex}; -use hypercore_protocol::{Channel, Event, Message, ProtocolBuilder}; -use log::*; +#[path = "../src/test_utils.rs"] +mod test_utils; +use async_std::{ + net::{TcpListener, TcpStream}, + task, +}; +use criterion::{Criterion, Throughput, criterion_group, criterion_main}; +use futures::{ + future::Either, + stream::{FuturesUnordered, StreamExt}, +}; +use hypercore_handshake::{ + Cipher, + state_machine::{SecStream, hc_specific::generate_keypair}, +}; +use hypercore_protocol::{Channel, Event, Message, Protocol, schema::*}; +use hypercore_schema::DataBlock; use std::time::Instant; +use tracing::{debug, info, trace}; +use uint24le_framing::Uint24LELengthPrefixedFraming; const PORT: usize = 11011; const SIZE: u64 = 1000; @@ -15,8 +25,8 @@ const COUNT: u64 = 200; const CLIENTS: usize = 1; fn bench_throughput(c: &mut Criterion) { - env_logger::from_env(env_logger::Env::default().default_filter_or("error")).init(); - let address = format!("localhost:{}", PORT); + test_utils::log(); + let address = format!("localhost:{PORT}"); let mut group = c.benchmark_group("throughput"); let data = vec![1u8; SIZE as usize]; @@ -45,8 +55,7 @@ fn bench_throughput(c: &mut Criterion) { let mut futures: FuturesUnordered<_> = streams .into_iter() .map(|s| async move { - onconnection(s.clone(), s.clone(), true).await; - s.shutdown(Shutdown::Both) + onconnection(s, true).await; }) .collect(); while let Some(_res) = futures.next().await {} @@ -64,23 +73,22 @@ criterion_main!(server_benches); async fn start_server(address: &str) -> futures::channel::oneshot::Sender<()> { let listener = TcpListener::bind(&address).await.unwrap(); - log::info!("listening on {}", listener.local_addr().unwrap()); + info!("listening on {}", listener.local_addr().unwrap()); let (kill_tx, mut kill_rx) = futures::channel::oneshot::channel(); task::spawn(async move { let mut incoming = listener.incoming(); // let kill_rx = &mut kill_rx; loop { match futures::future::select(incoming.next(), &mut kill_rx).await { - Either::Left((next, _)) => match next { - Some(Ok(stream)) => { + Either::Left((next, _)) => { + if let Some(Ok(stream)) = next { let peer_addr = stream.peer_addr().unwrap(); debug!("new connection from {}", peer_addr); task::spawn(async move { - onconnection(stream.clone(), stream, false).await; + onconnection(stream.clone(), false).await; }); } - _ => {} - }, + } Either::Right((_, _)) => return, } } @@ -88,32 +96,34 @@ async fn start_server(address: &str) -> futures::channel::oneshot::Sender<()> { kill_tx } -async fn onconnection(reader: R, writer: W, is_initiator: bool) -> Duplex -where - R: AsyncRead + Send + Unpin + 'static, - W: AsyncWrite + Send + Unpin + 'static, -{ +async fn onconnection(reader: TcpStream, is_initiator: bool) { let key = [0u8; 32]; - let mut protocol = ProtocolBuilder::new(is_initiator) - .encrypted(false) - .connect_rw(reader, writer); + let framed = Uint24LELengthPrefixedFraming::new(reader); + let cipher = if is_initiator { + let ss = SecStream::new_initiator_xx(&[]).unwrap(); + Cipher::new(Some(Box::new(framed)), ss.into()) + } else { + let keypair = generate_keypair().unwrap(); + let ss = SecStream::new_responder_xx(&keypair, &[]).unwrap(); + Cipher::new(Some(Box::new(framed)), ss.into()) + }; + let mut protocol = Protocol::new(Box::new(cipher)); while let Some(Ok(event)) = protocol.next().await { // eprintln!("RECV EVENT [{}] {:?}", protocol.is_initiator(), event); match event { Event::Handshake(_) => { - protocol.open(key.clone()).await.unwrap(); + protocol.open(key).await.unwrap(); } Event::DiscoveryKey(_) => {} Event::Channel(channel) => { task::spawn(onchannel(channel, is_initiator)); } Event::Close(_dkey) => { - return protocol.release(); + return; } _ => {} } } - protocol.release() } async fn onchannel(mut channel: Channel, is_initiator: bool) { @@ -127,9 +137,8 @@ async fn onchannel(mut channel: Channel, is_initiator: bool) { async fn channel_server(channel: &mut Channel) { while let Some(message) = channel.next().await { - match message { - Message::Data(_) => channel.send(message).await.unwrap(), - _ => {} + if let Message::Data(_) = message { + channel.send(message).await.unwrap() } } } @@ -140,31 +149,26 @@ async fn channel_client(channel: &mut Channel) { let message = msg_data(0, data.clone()); channel.send(message).await.unwrap(); while let Some(message) = channel.next().await { - match message { - Message::Data(ref msg) => { - if index(msg) < COUNT { - let message = msg_data(index(msg) + 1, data.clone()); - channel.send(message).await.unwrap(); - } else { - let time = start.elapsed(); - let bytes = COUNT * SIZE; - trace!( - "client completed. {} blocks, {} bytes, {:?}", - index(msg), - bytes, - time - ); - break; - } + if let Message::Data(ref msg) = message { + if index(msg) < COUNT { + let message = msg_data(index(msg) + 1, data.clone()); + channel.send(message).await.unwrap(); + } else { + let time = start.elapsed(); + let bytes = COUNT * SIZE; + trace!( + "client completed. {} blocks, {} bytes, {:?}", + index(msg), + bytes, + time + ); + break; } - _ => {} } } } fn msg_data(index: u64, value: Vec) -> Message { - use hypercore::DataBlock; - Message::Data(Data { request: index, fork: 0, diff --git a/examples-nodejs/run.js b/examples-nodejs/run.js index c96541f..ac77bba 100644 --- a/examples-nodejs/run.js +++ b/examples-nodejs/run.js @@ -37,7 +37,8 @@ function startRust (mode, key, color, name) { color: color || 'blue', env: { ...process.env, - RUST_LOG_STYLE: 'always' + RUST_LOG_STYLE: 'always', + RUST_LOG: 'trace' } }) return rust diff --git a/examples/replication.rs b/examples/replication.rs index bf65b72..cadc613 100644 --- a/examples/replication.rs +++ b/examples/replication.rs @@ -1,25 +1,28 @@ +#[path = "../src/test_utils.rs"] +mod test_utils; use anyhow::Result; -use async_std::net::{TcpListener, TcpStream}; -use async_std::prelude::*; -use async_std::sync::{Arc, Mutex}; -use async_std::task; -use env_logger::Env; +use async_std::{ + net::{TcpListener, TcpStream}, + prelude::*, + sync::{Arc, Mutex}, + task, +}; use futures_lite::stream::StreamExt; -use hypercore::{ - Hypercore, HypercoreBuilder, PartialKeypair, RequestBlock, RequestUpgrade, Storage, - VerifyingKey, +use hypercore::{Hypercore, HypercoreBuilder, PartialKeypair, Storage, VerifyingKey}; + +use hypercore_handshake::{ + Cipher, + state_machine::{SecStream, hc_specific::generate_keypair}, }; -use log::*; -use std::collections::HashMap; -use std::convert::TryInto; -use std::env; -use std::fmt::Debug; +use hypercore_schema::{RequestBlock, RequestUpgrade}; +use std::{collections::HashMap, convert::TryInto, env, fmt::Debug}; +use tracing::{error, info, instrument}; +use uint24le_framing::Uint24LELengthPrefixedFraming; -use hypercore_protocol::schema::*; -use hypercore_protocol::{discovery_key, Channel, Event, Message, ProtocolBuilder}; +use hypercore_protocol::{Channel, Event, Message, Protocol, discovery_key, schema::*}; fn main() { - init_logger(); + test_utils::log(); if env::args().count() < 3 { usage(); } @@ -65,12 +68,11 @@ fn main() { hypercore_store.add(hypercore_wrapper); let hypercore_store = Arc::new(hypercore_store); - let result = match mode.as_ref() { + let _ = match mode.as_ref() { "server" => tcp_server(address, onconnection, hypercore_store).await, "client" => tcp_client(address, onconnection, hypercore_store).await, _ => panic!("{:?}", usage()), }; - log_if_error(&result); }); } @@ -84,17 +86,28 @@ fn usage() { // or once when connected (if client). // Unfortunately, everything that touches the hypercore_store or a hypercore has to be generic // at the moment. +#[instrument(skip_all, ret)] async fn onconnection( stream: TcpStream, is_initiator: bool, hypercore_store: Arc, ) -> Result<()> { info!("onconnection, initiator: {}", is_initiator); - let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream); + + let framed = Uint24LELengthPrefixedFraming::new(stream); + let cipher = if is_initiator { + let ss = SecStream::new_initiator_xx(&[])?; + Cipher::new(Some(Box::new(framed)), ss.into()) + } else { + let keypair = generate_keypair().unwrap(); + let ss = SecStream::new_responder_xx(&keypair, &[])?; + Cipher::new(Some(Box::new(framed)), ss.into()) + }; + let mut protocol = Protocol::new(Box::new(cipher)); info!("protocol created, polling for next()"); while let Some(event) = protocol.next().await { - let event = event?; info!("protocol event {:?}", event); + let event = event?; match event { Event::Handshake(_) => { if is_initiator { @@ -126,17 +139,17 @@ struct HypercoreStore { hypercores: HashMap>, } impl HypercoreStore { - pub fn new() -> Self { + fn new() -> Self { let hypercores = HashMap::new(); Self { hypercores } } - pub fn add(&mut self, hypercore: HypercoreWrapper) { + fn add(&mut self, hypercore: HypercoreWrapper) { let hdkey = hex::encode(hypercore.discovery_key); self.hypercores.insert(hdkey, Arc::new(hypercore)); } - pub fn get(&self, discovery_key: &[u8; 32]) -> Option<&Arc> { + fn get(&self, discovery_key: &[u8; 32]) -> Option<&Arc> { let hdkey = hex::encode(discovery_key); self.hypercores.get(&hdkey) } @@ -151,7 +164,7 @@ struct HypercoreWrapper { } impl HypercoreWrapper { - pub fn from_memory_hypercore(hypercore: Hypercore) -> Self { + fn from_memory_hypercore(hypercore: Hypercore) -> Self { let key = hypercore.key_pair().public.to_bytes(); HypercoreWrapper { key, @@ -160,11 +173,11 @@ impl HypercoreWrapper { } } - pub fn key(&self) -> &[u8; 32] { + fn key(&self) -> &[u8; 32] { &self.key } - pub fn onpeer(&self, mut channel: Channel) { + fn onpeer(&self, mut channel: Channel) { let mut peer_state = PeerState::default(); let mut hypercore = self.hypercore.clone(); task::spawn(async move { @@ -299,6 +312,8 @@ async fn onmessage( start: info.length, length: peer_state.remote_length - info.length, }), + manifest: false, + priority: 0, }; messages.push(Message::Request(msg)); } @@ -366,7 +381,9 @@ async fn onmessage( println!(); println!("### Results"); println!(); - println!("Replication succeeded if this prints '0: hi', '1: ola', '2: hello' and '3: mundo':"); + println!( + "Replication succeeded if this prints '0: hi', '1: ola', '2: hello' and '3: mundo':" + ); println!(); for i in 0..new_info.contiguous_length { println!( @@ -405,6 +422,8 @@ async fn onmessage( block: Some(request_block), seek: None, upgrade: None, + manifest: false, + priority: 0, })); } channel.send_batch(&messages).await.unwrap(); @@ -414,20 +433,9 @@ async fn onmessage( Ok(()) } -/// Init EnvLogger, logging info, warn and error messages to stdout. -pub fn init_logger() { - env_logger::from_env(Env::default().default_filter_or("info")).init(); -} - -/// Log a result if it's an error. -pub fn log_if_error(result: &Result<()>) { - if let Err(err) = result.as_ref() { - log::error!("error: {}", err); - } -} - /// A simple async TCP server that calls an async function for each incoming connection. -pub async fn tcp_server( +#[instrument(skip_all, ret)] +async fn tcp_server( address: String, onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, context: C, @@ -437,22 +445,22 @@ where C: Clone + Send + 'static, { let listener = TcpListener::bind(&address).await?; - log::info!("listening on {}", listener.local_addr()?); + tracing::info!("listening on {}", listener.local_addr()?); let mut incoming = listener.incoming(); while let Some(Ok(stream)) = incoming.next().await { let context = context.clone(); let peer_addr = stream.peer_addr().unwrap(); - log::info!("new connection from {}", peer_addr); + tracing::info!("new connection from {}", peer_addr); task::spawn(async move { - let result = onconnection(stream, false, context).await; - log_if_error(&result); - log::info!("connection closed from {}", peer_addr); + let _ = onconnection(stream, false, context).await; + tracing::info!("connection closed from {}", peer_addr); }); } Ok(()) } /// A simple async TCP client that calls an async function when connected. +#[instrument(skip_all, ret)] pub async fn tcp_client( address: String, onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, @@ -462,8 +470,8 @@ where F: Future> + Send, C: Clone + Send + 'static, { - log::info!("attempting connection to {address}"); + tracing::info!("attempting connection to {address}"); let stream = TcpStream::connect(&address).await?; - log::info!("connected to {address}"); + tracing::info!("connected to {address}"); onconnection(stream, true, context).await } diff --git a/src/builder.rs b/src/builder.rs deleted file mode 100644 index d797654..0000000 --- a/src/builder.rs +++ /dev/null @@ -1,44 +0,0 @@ -use crate::Protocol; -use crate::{duplex::Duplex, protocol::Options}; -use futures_lite::io::{AsyncRead, AsyncWrite}; - -/// Build a Protocol instance with options. -#[derive(Debug)] -pub struct Builder(Options); - -impl Builder { - /// Create a protocol builder as initiator (true) or responder (false). - pub fn new(initiator: bool) -> Self { - Self(Options::new(initiator)) - } - - /// Set encrypted option. Defaults to true. - pub fn encrypted(mut self, encrypted: bool) -> Self { - self.0.encrypted = encrypted; - self - } - - /// Set handshake option. Defaults to true. - pub fn handshake(mut self, handshake: bool) -> Self { - self.0.noise = handshake; - self - } - - /// Create the protocol from a stream that implements AsyncRead + AsyncWrite + Clone. - pub fn connect(self, io: IO) -> Protocol - where - IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, - { - Protocol::new(io, self.0) - } - - /// Create the protocol from an AsyncRead reader and AsyncWrite writer. - pub fn connect_rw(self, reader: R, writer: W) -> Protocol> - where - R: AsyncRead + Send + Unpin + 'static, - W: AsyncWrite + Send + Unpin + 'static, - { - let io = Duplex::new(reader, writer); - Protocol::new(io, self.0) - } -} diff --git a/src/channels.rs b/src/channels.rs index c2e22f8..ed8904d 100644 --- a/src/channels.rs +++ b/src/channels.rs @@ -1,19 +1,23 @@ -use crate::message::ChannelMessage; -use crate::schema::*; -use crate::util::{map_channel_err, pretty_hash}; -use crate::Message; -use crate::{discovery_key, DiscoveryKey, Key}; +use crate::{ + DiscoveryKey, Key, Message, discovery_key, + message::ChannelMessage, + schema::*, + util::{map_channel_err, pretty_hash}, +}; use async_channel::{Receiver, Sender, TrySendError}; -use futures_lite::ready; -use futures_lite::stream::Stream; -use std::collections::HashMap; -use std::fmt; -use std::io::{Error, ErrorKind, Result}; -use std::pin::Pin; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; -use std::task::Poll; -use tracing::debug; +use futures_lite::{ready, stream::Stream}; +use std::{ + collections::HashMap, + fmt, + io::{Error, ErrorKind, Result}, + pin::Pin, + sync::{ + Arc, + atomic::{AtomicBool, Ordering}, + }, + task::Poll, +}; +use tracing::instrument; /// A protocol channel. /// @@ -86,14 +90,13 @@ impl Channel { } /// Send a message over the channel. - pub async fn send(&mut self, message: Message) -> Result<()> { + pub async fn send(&self, message: Message) -> Result<()> { if self.closed() { return Err(Error::new( ErrorKind::ConnectionAborted, "Channel is closed", )); } - debug!("TX:\n{message:?}\n"); let message = ChannelMessage::new(self.local_id as u64, message); self.outbound_tx .send(vec![message]) @@ -102,7 +105,7 @@ impl Channel { } /// Send a batch of messages over the channel. - pub async fn send_batch(&mut self, messages: &[Message]) -> Result<()> { + pub async fn send_batch(&self, messages: &[Message]) -> Result<()> { // In javascript this is cork()/uncork(), e.g.: // // https://github.com/holepunchto/hypercore/blob/c338b9aaa4442d35bc9d283d2c242b86a46de6d4/lib/replicator.js#L402-L418 @@ -122,11 +125,9 @@ impl Channel { let messages = messages .iter() - .map(|message| { - debug!("TX:\n{message:?}\n"); - ChannelMessage::new(self.local_id as u64, message.clone()) - }) + .map(|message| ChannelMessage::new(self.local_id as u64, message.clone())) .collect(); + self.outbound_tx .send(messages) .await @@ -151,7 +152,7 @@ impl Channel { } /// Send a close message and close this channel. - pub async fn close(&mut self) -> Result<()> { + pub async fn close(&self) -> Result<()> { if self.closed() { return Ok(()); } @@ -165,7 +166,7 @@ impl Channel { /// Signal the protocol to produce Event::LocalSignal. If you want to send a message /// to the channel level, see take_receiver() and local_sender(). - pub async fn signal_local_protocol(&mut self, name: &str, data: Vec) -> Result<()> { + pub async fn signal_local_protocol(&self, name: &str, data: Vec) -> Result<()> { self.send(Message::LocalSignal((name.to_string(), data))) .await?; Ok(()) @@ -249,6 +250,7 @@ impl ChannelHandle { self.remote_state.as_ref().map(|s| s.remote_id) } + #[instrument(skip_all, fields(local_id = local_id))] pub(crate) fn attach_local(&mut self, local_id: usize, key: Key) { let local_state = LocalState { local_id, key }; self.local_state = Some(local_state); @@ -276,6 +278,7 @@ impl ChannelHandle { Ok((&local_state.key, remote_state.remote_capability.as_ref())) } + #[instrument(skip_all)] pub(crate) fn open(&mut self, outbound_tx: Sender>) -> Channel { let local_state = self .local_state @@ -311,14 +314,14 @@ impl ChannelHandle { &mut self, message: Message, ) -> std::io::Result<()> { - if let Some(inbound_tx) = self.inbound_tx.as_mut() { - if let Err(err) = inbound_tx.try_send(message) { - match err { - TrySendError::Full(e) => { - return Err(error(format!("Sending to channel failed: {e}").as_str())) - } - TrySendError::Closed(_) => {} + if let Some(inbound_tx) = self.inbound_tx.as_mut() + && let Err(err) = inbound_tx.try_send(message) + { + match err { + TrySendError::Full(e) => { + return Err(error(format!("Sending to channel failed: {e}").as_str())); } + TrySendError::Closed(_) => {} } } Ok(()) @@ -433,6 +436,7 @@ impl ChannelMap { self.channels.remove(&hdkey); } + #[instrument(skip(self))] pub(crate) fn prepare_to_verify(&self, local_id: usize) -> Result<(&Key, Option<&Vec>)> { let channel_handle = self .get_local(local_id) @@ -507,5 +511,5 @@ impl ChannelMap { } fn error(message: &str) -> Error { - Error::new(ErrorKind::Other, message) + Error::other(message) } diff --git a/src/constants.rs b/src/constants.rs index 77285ee..1efbbed 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -1,15 +1,8 @@ /// Seed for the discovery key hash pub(crate) const DISCOVERY_NS_BUF: &[u8] = b"hypercore"; -/// Default timeout (in seconds) -pub(crate) const DEFAULT_TIMEOUT: u32 = 20; - /// Default keepalive interval (in seconds) pub(crate) const DEFAULT_KEEPALIVE: u32 = 10; -// 16,78MB is the max encrypted wire message size (will be much smaller usually). -// This limitation stems from the 24bit header. -pub(crate) const MAX_MESSAGE_SIZE: u64 = 0xFFFFFF; - /// v10: Protocol name pub(crate) const PROTOCOL_NAME: &str = "hypercore/alpha"; diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs deleted file mode 100644 index c0e54a9..0000000 --- a/src/crypto/cipher.rs +++ /dev/null @@ -1,186 +0,0 @@ -use super::HandshakeResult; -use crate::util::{stat_uint24_le, write_uint24_le, UINT_24_LENGTH}; -use blake2::{ - digest::{typenum::U32, FixedOutput, Update}, - Blake2bMac, -}; -use crypto_secretstream::{Header, Key, PullStream, PushStream, Tag}; -use rand::rngs::OsRng; -use std::convert::TryInto; -use std::io; - -const STREAM_ID_LENGTH: usize = 32; -const KEY_LENGTH: usize = 32; -const HEADER_MSG_LEN: usize = UINT_24_LENGTH + STREAM_ID_LENGTH + Header::BYTES; - -pub(crate) struct DecryptCipher { - pull_stream: PullStream, -} - -pub(crate) struct EncryptCipher { - push_stream: PushStream, -} - -impl std::fmt::Debug for DecryptCipher { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "DecryptCipher(crypto_secretstream)") - } -} - -impl std::fmt::Debug for EncryptCipher { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "EncryptCipher(crypto_secretstream)") - } -} - -impl DecryptCipher { - pub(crate) fn from_handshake_rx_and_init_msg( - handshake_result: &HandshakeResult, - init_msg: &[u8], - ) -> io::Result { - if init_msg.len() < 32 + 24 { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - format!( - "Received too short init message, {} < {}.", - init_msg.len(), - 32 + 24 - ), - )); - } - - let key: [u8; KEY_LENGTH] = handshake_result.split_rx[..KEY_LENGTH] - .try_into() - .expect("split_rx with incorrect length"); - let key = Key::from(key); - let handshake_hash = handshake_result.handshake_hash.clone(); - let is_initiator = handshake_result.is_initiator; - - // Read the received message from the other peer - let mut expected_stream_id: [u8; 32] = [0; 32]; - write_stream_id(&handshake_hash, !is_initiator, &mut expected_stream_id); - let remote_stream_id: [u8; 32] = init_msg[0..32] - .try_into() - .expect("stream id slice with incorrect length"); - if expected_stream_id != remote_stream_id { - return Err(io::Error::new( - io::ErrorKind::PermissionDenied, - "Received stream id does not match expected".to_string(), - )); - } - - let header: [u8; 24] = init_msg[32..] - .try_into() - .expect("header slice with incorrect length"); - let pull_stream = PullStream::init(Header::from(header), &key); - Ok(Self { pull_stream }) - } - - pub(crate) fn decrypt( - &mut self, - buf: &mut [u8], - header_len: usize, - body_len: usize, - ) -> io::Result { - let (to_decrypt, _tag) = self.decrypt_buf(&buf[header_len..header_len + body_len])?; - let decrypted_len = to_decrypt.len(); - write_uint24_le(decrypted_len, buf); - let decrypted_end = 3 + to_decrypt.len(); - buf[3..decrypted_end].copy_from_slice(to_decrypt.as_slice()); - // Set extra bytes in the buffer to 0 - let encrypted_end = header_len + body_len; - buf[decrypted_end..encrypted_end].fill(0x00); - Ok(decrypted_end) - } - - pub(crate) fn decrypt_buf(&mut self, buf: &[u8]) -> io::Result<(Vec, Tag)> { - let mut to_decrypt = buf.to_vec(); - let tag = &self.pull_stream.pull(&mut to_decrypt, &[]).map_err(|err| { - io::Error::new(io::ErrorKind::Other, format!("Decrypt failed: {err}")) - })?; - Ok((to_decrypt, *tag)) - } -} - -impl EncryptCipher { - pub(crate) fn from_handshake_tx( - handshake_result: &HandshakeResult, - ) -> std::io::Result<(Self, Vec)> { - let key: [u8; KEY_LENGTH] = handshake_result.split_tx[..KEY_LENGTH] - .try_into() - .expect("split_tx with incorrect length"); - let key = Key::from(key); - - let mut header_message: [u8; HEADER_MSG_LEN] = [0; HEADER_MSG_LEN]; - write_uint24_le(STREAM_ID_LENGTH + Header::BYTES, &mut header_message); - write_stream_id( - &handshake_result.handshake_hash, - handshake_result.is_initiator, - &mut header_message[UINT_24_LENGTH..UINT_24_LENGTH + STREAM_ID_LENGTH], - ); - - let (header, push_stream) = PushStream::init(OsRng, &key); - let header = header.as_ref(); - header_message[UINT_24_LENGTH + STREAM_ID_LENGTH..].copy_from_slice(header); - let msg = header_message.to_vec(); - Ok((Self { push_stream }, msg)) - } - - /// Get the length needed for encryption, that includes padding. - pub(crate) fn safe_encrypted_len(&self, plaintext_len: usize) -> usize { - // ChaCha20-Poly1305 uses padding in two places, use two 15 bytes as a safe - // extra room. - // https://mailarchive.ietf.org/arch/msg/cfrg/u734TEOSDDWyQgE0pmhxjdncwvw/ - plaintext_len + 2 * 15 - } - - /// Encrypts message in the given buffer to the same buffer, returns number of bytes - /// of total message. - pub(crate) fn encrypt(&mut self, buf: &mut [u8]) -> io::Result { - let stat = stat_uint24_le(buf); - if let Some((header_len, body_len)) = stat { - let mut to_encrypt = buf[header_len..header_len + body_len as usize].to_vec(); - self.push_stream - .push(&mut to_encrypt, &[], Tag::Message) - .map_err(|err| { - io::Error::new(io::ErrorKind::Other, format!("Encrypt failed: {err}")) - })?; - let encrypted_len = to_encrypt.len(); - write_uint24_le(encrypted_len, buf); - buf[header_len..header_len + encrypted_len].copy_from_slice(to_encrypt.as_slice()); - Ok(3 + encrypted_len) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("Could not encrypt invalid data, len: {}", buf.len()), - )) - } - } -} - -// NB: These values come from Javascript-side -// -// const [NS_INITIATOR, NS_RESPONDER] = crypto.namespace('hyperswarm/secret-stream', 2) -// -// at https://github.com/hyperswarm/secret-stream/blob/master/index.js -const NS_INITIATOR: [u8; 32] = [ - 0xa9, 0x31, 0xa0, 0x15, 0x5b, 0x5c, 0x09, 0xe6, 0xd2, 0x86, 0x28, 0x23, 0x6a, 0xf8, 0x3c, 0x4b, - 0x8a, 0x6a, 0xf9, 0xaf, 0x60, 0x98, 0x6e, 0xde, 0xed, 0xe9, 0xdc, 0x5d, 0x63, 0x19, 0x2b, 0xf7, -]; -const NS_RESPONDER: [u8; 32] = [ - 0x74, 0x2c, 0x9d, 0x83, 0x3d, 0x43, 0x0a, 0xf4, 0xc4, 0x8a, 0x87, 0x05, 0xe9, 0x16, 0x31, 0xee, - 0xcf, 0x29, 0x54, 0x42, 0xbb, 0xca, 0x18, 0x99, 0x6e, 0x59, 0x70, 0x97, 0x72, 0x3b, 0x10, 0x61, -]; - -fn write_stream_id(handshake_hash: &[u8], is_initiator: bool, out: &mut [u8]) { - let mut hasher = - Blake2bMac::::new_with_salt_and_personal(handshake_hash, &[], &[]).unwrap(); - if is_initiator { - hasher.update(&NS_INITIATOR); - } else { - hasher.update(&NS_RESPONDER); - } - let result = hasher.finalize_fixed(); - let result = result.as_slice(); - out.copy_from_slice(result); -} diff --git a/src/crypto/curve.rs b/src/crypto/curve.rs deleted file mode 100644 index 48ed841..0000000 --- a/src/crypto/curve.rs +++ /dev/null @@ -1,101 +0,0 @@ -use hypercore::{generate_signing_key, SecretKey, SigningKey, VerifyingKey}; -use sha2::Digest; -use snow::{ - params::{CipherChoice, DHChoice, HashChoice}, - resolvers::CryptoResolver, - types::{Cipher, Dh, Hash, Random}, -}; -use std::convert::TryInto; - -/// Wraps ed25519-dalek compatible keypair -#[derive(Default)] -struct Ed25519 { - privkey: [u8; 32], - pubkey: [u8; 32], -} - -impl Dh for Ed25519 { - fn name(&self) -> &'static str { - "Ed25519" - } - - fn pub_len(&self) -> usize { - 32 - } - - fn priv_len(&self) -> usize { - 32 - } - - fn set(&mut self, privkey: &[u8]) { - let secret: SecretKey = privkey - .try_into() - .expect("Can't use given bytes as SecretKey"); - let public: VerifyingKey = SigningKey::from(&secret).verifying_key(); - self.privkey[..privkey.len()].copy_from_slice(privkey); - let public_key_bytes = public.as_bytes(); - self.pubkey[..public_key_bytes.len()].copy_from_slice(public_key_bytes); - } - - fn generate(&mut self, _: &mut dyn Random) { - // NB: Given Random can't be used with ed25519_dalek's SigningKey::generate(), - // use OS's random here from hypercore. - let signing_key = generate_signing_key(); - let secret_key_bytes = signing_key.to_bytes(); - self.privkey[..secret_key_bytes.len()].copy_from_slice(&secret_key_bytes); - let verifying_key = signing_key.verifying_key(); - let public_key_bytes = verifying_key.as_bytes(); - self.pubkey[..public_key_bytes.len()].copy_from_slice(public_key_bytes); - } - - fn pubkey(&self) -> &[u8] { - &self.pubkey - } - - fn privkey(&self) -> &[u8] { - &self.privkey - } - - fn dh(&self, pubkey: &[u8], out: &mut [u8]) -> Result<(), snow::Error> { - let sk: [u8; 32] = sha2::Sha512::digest(self.privkey).as_slice()[..32] - .try_into() - .unwrap(); - // PublicKey is a CompressedEdwardsY in dalek. So we decompress it to get the - // EdwardsPoint and use variable base multiplication. - let cey = - curve25519_dalek::edwards::CompressedEdwardsY::from_slice(&pubkey[..self.pub_len()]) - .map_err(|_| snow::Error::Dh)?; - let pubkey: curve25519_dalek::edwards::EdwardsPoint = match cey.decompress() { - Some(ep) => Ok(ep), - None => Err(snow::Error::Dh), - }?; - let result = pubkey.mul_clamped(sk); - let result: [u8; 32] = *result.compress().as_bytes(); - out[..result.len()].copy_from_slice(result.as_slice()); - Ok(()) - } -} - -#[derive(Default)] -pub(super) struct CurveResolver; - -impl CryptoResolver for CurveResolver { - fn resolve_dh(&self, choice: &DHChoice) -> Option> { - match *choice { - DHChoice::Curve25519 => Some(Box::::default()), - _ => None, - } - } - - fn resolve_rng(&self) -> Option> { - None - } - - fn resolve_hash(&self, _choice: &HashChoice) -> Option> { - None - } - - fn resolve_cipher(&self, _choice: &CipherChoice) -> Option> { - None - } -} diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 64db407..1b8ec4c 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -1,16 +1,14 @@ -use super::curve::CurveResolver; -use crate::util::wrap_uint24_le; +//! Handshake result and capability verification for hypercore replication. +//! +//! This module handles capability verification using the handshake hash from +//! the underlying encrypted connection (e.g., from hyperswarm/hyperdht). + use blake2::{ - digest::{typenum::U32, FixedOutput, Update}, Blake2bMac, + digest::{FixedOutput, Update, typenum::U32}, }; -use snow::resolvers::{DefaultResolver, FallbackResolver}; -use snow::{Builder, Error as SnowError, HandshakeState}; use std::io::{Error, ErrorKind, Result}; -const CIPHERKEYLEN: usize = 32; -const HANDSHAKE_PATTERN: &str = "Noise_XX_Ed25519_ChaChaPoly_BLAKE2b"; - // These the output of, see `hash_namespace` test below for how they are produced // https://github.com/hypercore-protocol/hypercore/blob/70b271643c4e4b1e5ecae5bb579966dfe6361ff3/lib/caps.js#L9 const REPLICATE_INITIATOR: [u8; 32] = [ @@ -22,17 +20,48 @@ const REPLICATE_RESPONDER: [u8; 32] = [ 0x4E, 0x9, 0x26, 0x26, 0x2, 0x56, 0x86, 0x5A, 0xCC, 0xC0, 0xBF, 0x15, 0xBD, 0x79, 0x12, 0x7D, ]; +/// Result of a Noise handshake, used for capability verification. +/// +/// When using hypercore-protocol with hyperswarm, the Noise handshake is performed +/// at the transport layer (hyperdht). This struct holds the information needed +/// for capability verification when opening channels. #[derive(Debug, Clone, Default)] -pub(crate) struct HandshakeResult { +pub struct HandshakeResult { pub(crate) is_initiator: bool, - pub(crate) local_pubkey: Vec, - pub(crate) remote_pubkey: Vec, + /// Local public key (32 bytes) + pub local_pubkey: Vec, + /// Remote public key (32 bytes) + pub remote_pubkey: Vec, pub(crate) handshake_hash: Vec, - pub(crate) split_tx: [u8; CIPHERKEYLEN], - pub(crate) split_rx: [u8; CIPHERKEYLEN], } impl HandshakeResult { + /// Create a HandshakeResult for a pre-encrypted connection. + /// + /// This is used when the Noise handshake was performed at a lower layer + /// (e.g., hyperswarm/hyperdht) and we're reusing the encrypted channel. + /// The handshake_hash is used for capability verification. + /// + /// # Arguments + /// * `is_initiator` - Whether this peer initiated the connection + /// * `local_pubkey` - This peer's Noise public key + /// * `remote_pubkey` - The remote peer's Noise public key + /// * `handshake_hash` - The 64-byte handshake hash from the Noise handshake + pub fn from_pre_encrypted( + is_initiator: bool, + local_pubkey: [u8; 32], + remote_pubkey: [u8; 32], + handshake_hash: Vec, + ) -> Self { + Self { + is_initiator, + local_pubkey: local_pubkey.to_vec(), + remote_pubkey: remote_pubkey.to_vec(), + handshake_hash, + } + } + + /// Compute capability for opening a channel with the given key. pub(crate) fn capability(&self, key: &[u8]) -> Option> { Some(replicate_capability( self.is_initiator, @@ -41,6 +70,7 @@ impl HandshakeResult { )) } + /// Compute expected remote capability for the given key. pub(crate) fn remote_capability(&self, key: &[u8]) -> Option> { Some(replicate_capability( !self.is_initiator, @@ -49,6 +79,7 @@ impl HandshakeResult { )) } + /// Verify a remote peer's capability for opening a channel. pub(crate) fn verify_remote_capability( &self, capability: Option>, @@ -69,161 +100,8 @@ impl HandshakeResult { } } -pub(crate) struct Handshake { - result: HandshakeResult, - state: HandshakeState, - payload: Vec, - tx_buf: Vec, - rx_buf: Vec, - complete: bool, - did_receive: bool, -} - -impl Handshake { - pub(crate) fn new(is_initiator: bool) -> Result { - let (state, local_pubkey) = build_handshake_state(is_initiator).map_err(map_err)?; - - let payload = vec![]; - let result = HandshakeResult { - is_initiator, - local_pubkey, - ..Default::default() - }; - Ok(Self { - state, - result, - payload, - tx_buf: vec![0u8; 512], - rx_buf: vec![0u8; 512], - complete: false, - did_receive: false, - }) - } - - pub(crate) fn start(&mut self) -> Result>> { - if self.is_initiator() { - let tx_len = self.send()?; - let wrapped = wrap_uint24_le(&self.tx_buf[..tx_len].to_vec()); - Ok(Some(wrapped)) - } else { - Ok(None) - } - } - - pub(crate) fn complete(&self) -> bool { - self.complete - } - - pub(crate) fn is_initiator(&self) -> bool { - self.result.is_initiator - } - - fn recv(&mut self, msg: &[u8]) -> Result { - self.state - .read_message(msg, &mut self.rx_buf) - .map_err(map_err) - } - fn send(&mut self) -> Result { - self.state - .write_message(&self.payload, &mut self.tx_buf) - .map_err(map_err) - } - - pub(crate) fn read(&mut self, msg: &[u8]) -> Result>> { - // eprintln!("hs read len {}", msg.len()); - if self.complete() { - return Err(Error::new(ErrorKind::Other, "Handshake read after finish")); - } - - let _rx_len = self.recv(msg)?; - - if !self.is_initiator() && !self.did_receive { - self.did_receive = true; - let tx_len = self.send()?; - let wrapped = wrap_uint24_le(&self.tx_buf[..tx_len].to_vec()); - return Ok(Some(wrapped)); - } - - let tx_buf = if self.is_initiator() { - let tx_len = self.send()?; - let wrapped = wrap_uint24_le(&self.tx_buf[..tx_len].to_vec()); - Some(wrapped) - } else { - None - }; - - let split = self.state.dangerously_get_raw_split(); - if self.is_initiator() { - self.result.split_tx = split.0; - self.result.split_rx = split.1; - } else { - self.result.split_tx = split.1; - self.result.split_rx = split.0; - } - self.result.remote_pubkey = self - .state - .get_remote_static() - .expect("Could not read remote static key after handshake") - .to_vec(); - self.result.handshake_hash = self.state.get_handshake_hash().to_vec(); - self.complete = true; - Ok(tx_buf) - } - - pub(crate) fn into_result(self) -> Result { - if !self.complete() { - Err(Error::new(ErrorKind::Other, "Handshake is not complete")) - } else { - Ok(self.result) - } - } -} - -fn build_handshake_state( - is_initiator: bool, -) -> std::result::Result<(HandshakeState, Vec), SnowError> { - use snow::params::{ - BaseChoice, CipherChoice, DHChoice, HandshakeChoice, HandshakeModifierList, - HandshakePattern, HashChoice, NoiseParams, - }; - // NB: HANDSHAKE_PATTERN.parse() doesn't work because the pattern has "Ed25519" - // instead of "25519". - let noise_params = NoiseParams::new( - HANDSHAKE_PATTERN.to_string(), - BaseChoice::Noise, - HandshakeChoice { - pattern: HandshakePattern::XX, - modifiers: HandshakeModifierList { list: vec![] }, - }, - DHChoice::Curve25519, - CipherChoice::ChaChaPoly, - HashChoice::Blake2b, - ); - let builder: Builder<'_> = Builder::with_resolver( - noise_params, - Box::new(FallbackResolver::new( - Box::::default(), - Box::::default(), - )), - ); - let key_pair = builder.generate_keypair().unwrap(); - let builder = builder.local_private_key(&key_pair.private); - let handshake_state = if is_initiator { - tracing::debug!("building initiator"); - builder.build_initiator()? - } else { - tracing::debug!("building responder"); - builder.build_responder()? - }; - Ok((handshake_state, key_pair.public)) -} - -fn map_err(e: SnowError) -> Error { - Error::new(ErrorKind::PermissionDenied, format!("Handshake error: {e}")) -} - /// Create a hash used to indicate replication capability. -/// See https://github.com/hypercore-protocol/hypercore/blob/70b271643c4e4b1e5ecae5bb579966dfe6361ff3/lib/caps.js#L11 +/// See JavaScript [here](https://github.com/hypercore-protocol/hypercore/blob/70b271643c4e4b1e5ecae5bb579966dfe6361ff3/lib/caps.js#L11). fn replicate_capability(is_initiator: bool, key: &[u8], handshake_hash: &[u8]) -> Vec { let seed = if is_initiator { REPLICATE_INITIATOR @@ -236,6 +114,6 @@ fn replicate_capability(is_initiator: bool, key: &[u8], handshake_hash: &[u8]) - hasher.update(&seed); hasher.update(key); let hash = hasher.finalize_fixed(); - let capability = hash.as_slice().to_vec(); - capability + + hash.as_slice().to_vec() } diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 66bb62d..54f711f 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -1,5 +1,3 @@ -mod cipher; -mod curve; -mod handshake; -pub(crate) use cipher::{DecryptCipher, EncryptCipher}; -pub(crate) use handshake::{Handshake, HandshakeResult}; +/// Handshake result and capability verification. +pub(crate) mod handshake; +pub(crate) use handshake::HandshakeResult; diff --git a/src/duplex.rs b/src/duplex.rs deleted file mode 100644 index fe79c1b..0000000 --- a/src/duplex.rs +++ /dev/null @@ -1,65 +0,0 @@ -use futures_lite::{AsyncRead, AsyncWrite}; -use std::io; -use std::pin::Pin; -use std::task::{Context, Poll}; - -#[derive(Clone, Debug, PartialEq)] -/// Duplex IO stream from reader and writer halves. -/// -/// Convert an AsyncRead reader and AsyncWrite writer into a -/// AsyncRead + AsyncWrite stream -pub struct Duplex -where - R: AsyncRead + Send + Unpin + 'static, - W: AsyncWrite + Send + Unpin + 'static, -{ - reader: R, - writer: W, -} - -impl Duplex -where - R: AsyncRead + Send + Unpin + 'static, - W: AsyncWrite + Send + Unpin + 'static, -{ - /// Create a new Duplex stream from a reader and a writer. - pub fn new(reader: R, writer: W) -> Self { - Self { reader, writer } - } -} - -impl AsyncRead for Duplex -where - R: AsyncRead + Send + Unpin + 'static, - W: AsyncWrite + Send + Unpin + 'static, -{ - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - Pin::new(&mut self.reader).poll_read(cx, buf) - } -} - -impl AsyncWrite for Duplex -where - R: AsyncRead + Send + Unpin + 'static, - W: AsyncWrite + Send + Unpin + 'static, -{ - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.writer).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.writer).poll_flush(cx) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.writer).poll_close(cx) - } -} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..e35da03 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,25 @@ +/// Error type for this crate +#[derive(Debug, thiserror::Error)] +pub enum Error { + /// Error from the [`snow`] crate + #[error("Error from `snow`: {0}")] + Snow(#[from] snow::Error), + /// Error from [`crypto_secretstream`] crate + #[error("Error from `crypto_secretstream`: {0}")] + SecretStream(crypto_secretstream::aead::Error), + /// Error from [`std::io`] + #[error("{0}")] + FromStdIo(#[from] std::io::Error), +} + +impl From for Error { + fn from(value: crypto_secretstream::aead::Error) -> Self { + Error::SecretStream(value) + } +} + +impl From for std::io::Error { + fn from(value: Error) -> Self { + std::io::Error::other(value) + } +} diff --git a/src/lib.rs b/src/lib.rs index 531a068..892aa5c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,141 +4,51 @@ //! protocol implementation in [the original Javascript version][holepunch-hypercore] aiming //! for interoperability with LTS version. //! -//! This crate is built on top of the [hypercore] crate, which defines some structs used here. +//! This crate is built on top of the [hypercore](https://crates.io/crates/hypercore) crate, which defines some structs used here. //! //! ## Design //! -//! This crate does not include any IO related code, it is up to the user to supply a streaming IO -//! handler that implements the [AsyncRead] and [AsyncWrite] traits. +//! This crate expects to receive a pre-encrypted, message-framed connection (e.g., from hyperswarm). +//! The underlying stream should implement `Stream> + Sink>`. //! -//! When opening a Hypercore protocol stream on an IO handler, the protocol will perform a Noise -//! handshake followed by libsodium's [crypto_secretstream] to setup a secure and authenticated -//! connection. After that, each side can request any number of channels on the protocol. A +//! After construction, each side can request any number of channels on the protocol. A //! channel is opened with a [Key], a 32 byte buffer. Channels are only opened if both peers //! opened a channel for the same key. It is automatically verified that both parties know the -//! key without transmitting the key itself. +//! key without transmitting the key itself using the handshake hash from the underlying connection. //! //! On a channel, the predefined messages, including a custom Extension message, of the Hypercore //! protocol can be sent and received. //! -//! ## Features -//! -//! ### `sparse` (default) -//! -//! When using disk storage for hypercore, clearing values may create sparse files. On by default. -//! -//! ### `async-std` (default) -//! -//! Use the async-std runtime, on by default. Either this or `tokio` is mandatory. -//! -//! ### `tokio` -//! -//! Use the tokio runtime. Either this or `async_std` is mandatory. -//! -//! ### `wasm-bindgen` -//! -//! Enable for WASM runtime support. -//! -//! ### `cache` -//! -//! Use a moka cache for hypercore's merkle tree nodes to speed-up reading. -//! -//! ## Example -//! -//! The following example opens a TCP server on localhost and connects to that server. Both ends -//! then open a channel with the same key and exchange a message. -//! -//! ```no_run -//! # async_std::task::block_on(async { -//! use hypercore_protocol::{ProtocolBuilder, Event, Message}; -//! use hypercore_protocol::schema::*; -//! use async_std::prelude::*; -//! // Start a tcp server. -//! let listener = async_std::net::TcpListener::bind("localhost:8000").await.unwrap(); -//! async_std::task::spawn(async move { -//! let mut incoming = listener.incoming(); -//! while let Some(Ok(stream)) = incoming.next().await { -//! async_std::task::spawn(async move { -//! onconnection(stream, false).await -//! }); -//! } -//! }); -//! -//! // Connect a client. -//! let stream = async_std::net::TcpStream::connect("localhost:8000").await.unwrap(); -//! onconnection(stream, true).await; -//! -//! /// Start Hypercore protocol on a TcpStream. -//! async fn onconnection (stream: async_std::net::TcpStream, is_initiator: bool) { -//! // A peer either is the initiator or a connection or is being connected to. -//! let name = if is_initiator { "dialer" } else { "listener" }; -//! // A key for the channel we want to open. Usually, this is a pre-shared key that both peers -//! // know about. -//! let key = [3u8; 32]; -//! // Create the protocol. -//! let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream); -//! -//! // Iterate over the protocol events. This is required to "drive" the protocol. -//! -//! while let Some(Ok(event)) = protocol.next().await { -//! eprintln!("{} received event {:?}", name, event); -//! match event { -//! // The handshake event is emitted after the protocol is fully established. -//! Event::Handshake(_remote_key) => { -//! protocol.open(key.clone()).await; -//! }, -//! // A Channel event is emitted for each established channel. -//! Event::Channel(mut channel) => { -//! // A Channel can be sent to other tasks. -//! async_std::task::spawn(async move { -//! // A Channel can both send messages and is a stream of incoming messages. -//! channel.send(Message::Want(Want { start: 0, length: 1 })).await; -//! while let Some(message) = channel.next().await { -//! eprintln!("{} received message: {:?}", name, message); -//! } -//! }); -//! }, -//! _ => {} -//! } -//! } -//! } -//! # }) -//! ``` -//! -//! Find more examples in the [Github repository][examples]. -//! //! [holepunch-hypercore]: https://github.com/holepunchto/hypercore -//! [datrs-hypercore]: https://github.com/datrs/hypercore -//! [AsyncRead]: futures_lite::AsyncRead -//! [AsyncWrite]: futures_lite::AsyncWrite -//! [examples]: https://github.com/datrs/hypercore-protocol-rs#examples -#![forbid(unsafe_code, future_incompatible, rust_2018_idioms)] +#![forbid(unsafe_code)] #![deny(missing_debug_implementations, nonstandard_style)] -#![warn(missing_docs, unreachable_pub)] +#![warn(missing_docs, clippy::needless_pass_by_ref_mut, unreachable_pub)] -mod builder; mod channels; mod constants; mod crypto; -mod duplex; +mod error; mod message; +mod mqueue; mod protocol; -mod reader; +mod stream; +#[cfg(test)] +mod test_utils; mod util; -mod writer; /// The wire messages used by the protocol. pub mod schema; +pub use error::Error; -pub use builder::Builder as ProtocolBuilder; pub use channels::Channel; // Export the needed types for Channel::take_receiver, and Channel::local_sender() pub use async_channel::{ Receiver as ChannelReceiver, SendError as ChannelSendError, Sender as ChannelSender, }; -pub use duplex::Duplex; -pub use hypercore; // Re-export hypercore pub use message::Message; pub use protocol::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; +pub use stream::BoxedStream; pub use util::discovery_key; +// Export handshake result for constructing Protocol +pub use crypto::handshake::HandshakeResult; diff --git a/src/message.rs b/src/message.rs index 27b74c1..11c91fd 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,353 +1,122 @@ use crate::schema::*; -use crate::util::{stat_uint24_le, write_uint24_le}; -use hypercore::encoding::{ - CompactEncoding, EncodingError, EncodingErrorKind, HypercoreState, State, +use compact_encoding::{ + CompactEncoding, EncodingError, EncodingErrorKind, VecEncodable, decode_usize, take_array, + write_array, }; use pretty_hash::fmt as pretty_fmt; -use std::fmt; -use std::io; - -/// The type of a data frame. -#[derive(Debug, Clone, PartialEq)] -pub(crate) enum FrameType { - Raw, - Message, -} - -/// Encode data into a buffer. -/// -/// This trait is implemented on data frames and their components -/// (channel messages, messages, and individual message types through prost). -pub(crate) trait Encoder: Sized + fmt::Debug { - /// Calculates the length that the encoded message needs. - fn encoded_len(&mut self) -> Result; - - /// Encodes the message to a buffer. - /// - /// An error will be returned if the buffer does not have sufficient capacity. - fn encode(&mut self, buf: &mut [u8]) -> Result; -} - -impl Encoder for &[u8] { - fn encoded_len(&mut self) -> Result { - Ok(self.len()) - } - - fn encode(&mut self, buf: &mut [u8]) -> Result { - let len = self.encoded_len()?; - if len > buf.len() { - return Err(EncodingError::new( - EncodingErrorKind::Overflow, - &format!("Length does not fit buffer, {} > {}", len, buf.len()), - )); - } - buf[..len].copy_from_slice(&self[..]); - Ok(len) - } -} - -/// A frame of data, either a buffer or a message. -#[derive(Clone, PartialEq)] -pub(crate) enum Frame { - /// A raw batch binary buffer. Used in the handshaking phase. - RawBatch(Vec>), - /// Message batch, containing one or more channel messsages. Used for everything after the handshake. - MessageBatch(Vec), -} - -impl fmt::Debug for Frame { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Frame::RawBatch(batch) => write!(f, "Frame(RawBatch <{}>)", batch.len()), - Frame::MessageBatch(messages) => write!(f, "Frame({messages:?})"), - } - } -} - -impl From for Frame { - fn from(m: ChannelMessage) -> Self { - Self::MessageBatch(vec![m]) - } -} - -impl From> for Frame { - fn from(m: Vec) -> Self { - Self::RawBatch(vec![m]) - } -} - -impl Frame { - /// Decodes a frame from a buffer containing multiple concurrent messages. - pub(crate) fn decode_multiple(buf: &[u8], frame_type: &FrameType) -> Result { - match frame_type { - FrameType::Raw => { - let mut index = 0; - let mut raw_batch: Vec> = vec![]; - while index < buf.len() { - // There might be zero bytes in between, and with LE, the next message will - // start with a non-zero - if buf[index] == 0 { - index += 1; - continue; - } - let stat = stat_uint24_le(&buf[index..]); - if let Some((header_len, body_len)) = stat { - raw_batch.push( - buf[index + header_len..index + header_len + body_len as usize] - .to_vec(), - ); - index += header_len + body_len as usize; - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid data in raw batch", - )); - } - } - Ok(Frame::RawBatch(raw_batch)) - } - FrameType::Message => { - let mut index = 0; - let mut combined_messages: Vec = vec![]; - while index < buf.len() { - // There might be zero bytes in between, and with LE, the next message will - // start with a non-zero - if buf[index] == 0 { - index += 1; - continue; - } - - let stat = stat_uint24_le(&buf[index..]); - if let Some((header_len, body_len)) = stat { - let (frame, length) = Self::decode_message( - &buf[index + header_len..index + header_len + body_len as usize], - )?; - if length != body_len as usize { - tracing::warn!( - "Did not know what to do with all the bytes, got {} but decoded {}. \ - This may be because the peer implements a newer protocol version \ - that has extra fields.", - body_len, - length - ); - } - if let Frame::MessageBatch(messages) = frame { - for message in messages { - combined_messages.push(message); - } - } else { - unreachable!("Can not get Raw messages"); - } - index += header_len + body_len as usize; - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid data in multi-message chunk", - )); - } +use std::{fmt, io}; +use tracing::{debug, instrument, trace, warn}; + +const OPEN_MESSAGE_PREFIX: [u8; 2] = [0, 1]; +const CLOSE_MESSAGE_PREFIX: [u8; 2] = [0, 3]; +const MULTI_MESSAGE_PREFIX: [u8; 2] = [0, 0]; +const CHANNEL_CHANGE_SEPERATOR: [u8; 1] = [0]; + +#[instrument(skip_all err)] +pub(crate) fn decode_unframed_channel_messages( + buf: &[u8], +) -> Result<(Vec, usize), io::Error> { + let og_len = buf.len(); + if og_len >= 3 && buf[0] == 0x00 { + // batch of NOT open/close messages + if buf[1] == 0x00 { + let (_, mut buf) = take_array::<2>(buf)?; + // Batch of messages + let mut messages: Vec = vec![]; + + // First, there is the original channel + let mut current_channel; + (current_channel, buf) = u64::decode(buf)?; + while !buf.is_empty() { + // Length of the message is inbetween here + let channel_message_length; + (channel_message_length, buf) = decode_usize(buf)?; + if channel_message_length > buf.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "received invalid message length: [{channel_message_length}] +\tbut we have [{}] remaining bytes. +\tInitial buffer size [{og_len}]", + buf.len() + ), + )); } - Ok(Frame::MessageBatch(combined_messages)) - } - } - } - - /// Decode a frame from a buffer. - pub(crate) fn decode(buf: &[u8], frame_type: &FrameType) -> Result { - match frame_type { - FrameType::Raw => Ok(Frame::RawBatch(vec![buf.to_vec()])), - FrameType::Message => { - let (frame, _) = Self::decode_message(buf)?; - Ok(frame) - } - } - } - - fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { - if buf.len() >= 3 && buf[0] == 0x00 { - if buf[1] == 0x00 { - // Batch of messages - let mut messages: Vec = vec![]; - let mut state = State::new_with_start_and_end(2, buf.len()); - - // First, there is the original channel - let mut current_channel: u64 = state.decode(buf)?; - while state.start() < state.end() { - // Length of the message is inbetween here - let channel_message_length: usize = state.decode(buf)?; - if state.start() + channel_message_length > state.end() { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!( - "received invalid message length, {} + {} > {}", - state.start(), - channel_message_length, - state.end() - ), - )); - } - // Then the actual message - let (channel_message, _) = ChannelMessage::decode( - &buf[state.start()..state.start() + channel_message_length], - current_channel, - )?; - messages.push(channel_message); - state.add_start(channel_message_length)?; - // After that, if there is an extra 0x00, that means the channel - // changed. This works because of LE encoding, and channels starting - // from the index 1. - if state.start() < state.end() && buf[state.start()] == 0x00 { - state.add_start(1)?; - current_channel = state.decode(buf)?; - } + // Then the actual message + let channel_message; + let bl = buf.len(); + (channel_message, buf) = ChannelMessage::decode_with_channel(buf, current_channel)?; + trace!( + "Decoded ChannelMessage::{:?} using [{} bytes]", + channel_message.message, + bl - buf.len() + ); + messages.push(channel_message); + // After that, if there is an extra 0x00, that means the channel + // changed. This works because of LE encoding, and channels starting + // from the index 1. + if !buf.is_empty() && buf[0] == 0x00 { + (current_channel, buf) = u64::decode(buf)?; } - Ok((Frame::MessageBatch(messages), state.start())) - } else if buf[1] == 0x01 { - // Open message - let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else if buf[1] == 0x03 { - // Close message - let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid special message", - )) } - } else if buf.len() >= 2 { - // Single message - let mut state = State::from_buffer(buf); - let channel: u64 = state.decode(buf)?; - let (channel_message, length) = ChannelMessage::decode(&buf[state.start()..], channel)?; - Ok(( - Frame::MessageBatch(vec![channel_message]), - state.start() + length, - )) + Ok((messages, og_len - buf.len())) + } else if buf[1] == 0x01 { + // Open message + let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; + Ok((vec![channel_message], length + 2)) + } else if buf[1] == 0x03 { + // Close message + let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; + Ok((vec![channel_message], length + 2)) } else { Err(io::Error::new( io::ErrorKind::InvalidData, - format!("received too short message, {buf:02X?}"), + "received invalid special message", )) } - } - - fn preencode(&mut self, state: &mut State) -> Result { - match self { - Self::RawBatch(raw_batch) => { - for raw in raw_batch { - state.add_end(raw.as_slice().encoded_len()?)?; - } - } - #[allow(clippy::comparison_chain)] - Self::MessageBatch(messages) => { - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else { - (*state).preencode(&messages[0].channel)?; - state.add_end(messages[0].encoded_len()?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.add_end(2)?; - let mut current_channel: u64 = messages[0].channel; - state.preencode(¤t_channel)?; - for message in messages.iter_mut() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.add_end(1)?; - state.preencode(&message.channel)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.preencode(&message_length)?; - state.add_end(message_length)?; - } - } - } - } - Ok(state.end()) + } else if buf.len() >= 2 { + trace!("Decoding single ChannelMessage"); + // Single message + let og_len = buf.len(); + let (channel_message, buf) = ChannelMessage::decode_from_channel_and_message(buf)?; + Ok((vec![channel_message], og_len - buf.len())) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("received too short message, {buf:?}"), + )) } } -impl Encoder for Frame { - fn encoded_len(&mut self) -> Result { - let body_len = self.preencode(&mut State::new())?; - match self { - Self::RawBatch(_) => Ok(body_len), - Self::MessageBatch(_) => Ok(3 + body_len), - } - } - - fn encode(&mut self, buf: &mut [u8]) -> Result { - let mut state = State::new(); - let header_len = if let Self::RawBatch(_) = self { 0 } else { 3 }; - let body_len = self.preencode(&mut state)?; - let len = body_len + header_len; - if buf.len() < len { - return Err(EncodingError::new( - EncodingErrorKind::Overflow, - &format!("Length does not fit buffer, {} > {}", len, buf.len()), - )); - } - match self { - Self::RawBatch(ref raw_batch) => { - for raw in raw_batch { - raw.as_slice().encode(buf)?; - } - } - #[allow(clippy::comparison_chain)] - Self::MessageBatch(ref mut messages) => { - write_uint24_le(body_len, buf); - let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(1_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(3_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else { - state.encode(&messages[0].channel, buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; - let mut current_channel: u64 = messages[0].channel; - state.encode(¤t_channel, buf)?; - for message in messages.iter_mut() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.encode(&(0_u8), buf)?; - state.encode(&message.channel, buf)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.encode(&message_length, buf)?; - state.add_start(message.encode(&mut buf[state.start()..])?)?; - } +fn vec_channel_messages_encoded_size(messages: &[ChannelMessage]) -> Result { + Ok(match messages { + [] => 0, + [msg] => match msg.message { + Message::Open(_) | Message::Close(_) => 2 + msg.encoded_size()?, + _ => msg.encoded_size()?, + }, + msgs => { + let mut out = MULTI_MESSAGE_PREFIX.len(); + let mut current_channel: u64 = messages[0].channel; + out += current_channel.encoded_size()?; + for message in msgs.iter() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel + out += CHANNEL_CHANGE_SEPERATOR.len() + message.channel.encoded_size()?; + current_channel = message.channel; } + let message_length = message.message.encoded_size()?; + out += message_length + (message_length as u64).encoded_size()?; } - }; - Ok(len) - } + out + } + }) } /// A protocol message. #[derive(Debug, Clone, PartialEq)] -#[allow(missing_docs)] +#[expect(missing_docs)] pub enum Message { Open(Open), Close(Close), @@ -365,6 +134,114 @@ pub enum Message { LocalSignal((String, Vec)), } +macro_rules! message_from { + ($($val:ident),+) => { + $( + impl From<$val> for Message { + fn from(value: $val) -> Self { + Message::$val(value) + } + } + )* + } +} +message_from!( + Open, + Close, + Synchronize, + Request, + Cancel, + Data, + NoData, + Want, + Unwant, + Bitfield, + Range, + Extension +); + +macro_rules! decode_message { + ($type:ty, $buf:expr) => {{ + let (x, rest) = <$type>::decode($buf)?; + (Message::from(x), rest) + }}; +} + +impl CompactEncoding for Message { + fn encoded_size(&self) -> Result { + let typ_size = if let Self::Open(_) | Self::Close(_) = &self { + 0 + } else { + self.typ().encoded_size()? + }; + let msg_size = match self { + Self::LocalSignal(_) => Ok(0), + Self::Open(x) => x.encoded_size(), + Self::Close(x) => x.encoded_size(), + Self::Synchronize(x) => x.encoded_size(), + Self::Request(x) => x.encoded_size(), + Self::Cancel(x) => x.encoded_size(), + Self::Data(x) => x.encoded_size(), + Self::NoData(x) => x.encoded_size(), + Self::Want(x) => x.encoded_size(), + Self::Unwant(x) => x.encoded_size(), + Self::Bitfield(x) => x.encoded_size(), + Self::Range(x) => x.encoded_size(), + Self::Extension(x) => x.encoded_size(), + }?; + Ok(typ_size + msg_size) + } + + #[instrument(skip_all, fields(name = self.name()))] + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + debug!("Encoding {self:?}"); + let rest = if let Self::Open(_) | Self::Close(_) = &self { + buffer + } else { + self.typ().encode(buffer)? + }; + match self { + Self::Open(x) => x.encode(rest), + Self::Close(x) => x.encode(rest), + Self::Synchronize(x) => x.encode(rest), + Self::Request(x) => x.encode(rest), + Self::Cancel(x) => x.encode(rest), + Self::Data(x) => x.encode(rest), + Self::NoData(x) => x.encode(rest), + Self::Want(x) => x.encode(rest), + Self::Unwant(x) => x.encode(rest), + Self::Bitfield(x) => x.encode(rest), + Self::Range(x) => x.encode(rest), + Self::Extension(x) => x.encode(rest), + Self::LocalSignal(_) => unimplemented!("do not encode LocalSignal"), + } + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let (typ, rest) = u64::decode(buffer)?; + Ok(match typ { + 0 => decode_message!(Synchronize, rest), + 1 => decode_message!(Request, rest), + 2 => decode_message!(Cancel, rest), + 3 => decode_message!(Data, rest), + 4 => decode_message!(NoData, rest), + 5 => decode_message!(Want, rest), + 6 => decode_message!(Unwant, rest), + 7 => decode_message!(Bitfield, rest), + 8 => decode_message!(Range, rest), + 9 => decode_message!(Extension, rest), + _ => { + return Err(EncodingError::new( + EncodingErrorKind::InvalidData, + &format!("Invalid message type to decode: {typ}"), + )); + } + }) + } +} impl Message { /// Wire type of this message. pub(crate) fn typ(&self) -> u64 { @@ -382,71 +259,23 @@ impl Message { value => unimplemented!("{} does not have a type", value), } } - - /// Decode a message from a buffer based on type. - pub(crate) fn decode(buf: &[u8], typ: u64) -> Result<(Self, usize), EncodingError> { - let mut state = HypercoreState::from_buffer(buf); - let message = match typ { - 0 => Ok(Self::Synchronize((*state).decode(buf)?)), - 1 => Ok(Self::Request(state.decode(buf)?)), - 2 => Ok(Self::Cancel((*state).decode(buf)?)), - 3 => Ok(Self::Data(state.decode(buf)?)), - 4 => Ok(Self::NoData((*state).decode(buf)?)), - 5 => Ok(Self::Want((*state).decode(buf)?)), - 6 => Ok(Self::Unwant((*state).decode(buf)?)), - 7 => Ok(Self::Bitfield((*state).decode(buf)?)), - 8 => Ok(Self::Range((*state).decode(buf)?)), - 9 => Ok(Self::Extension((*state).decode(buf)?)), - _ => Err(EncodingError::new( - EncodingErrorKind::InvalidData, - &format!("Invalid message type to decode: {typ}"), - )), - }?; - Ok((message, state.start())) - } - - /// Pre-encodes a message to state, returns length - pub(crate) fn preencode(&self, state: &mut HypercoreState) -> Result { + /// Get the name of the message + pub fn name(&self) -> &'static str { match self { - Self::Open(ref message) => state.0.preencode(message)?, - Self::Close(ref message) => state.0.preencode(message)?, - Self::Synchronize(ref message) => state.0.preencode(message)?, - Self::Request(ref message) => state.preencode(message)?, - Self::Cancel(ref message) => state.0.preencode(message)?, - Self::Data(ref message) => state.preencode(message)?, - Self::NoData(ref message) => state.0.preencode(message)?, - Self::Want(ref message) => state.0.preencode(message)?, - Self::Unwant(ref message) => state.0.preencode(message)?, - Self::Bitfield(ref message) => state.0.preencode(message)?, - Self::Range(ref message) => state.0.preencode(message)?, - Self::Extension(ref message) => state.0.preencode(message)?, - Self::LocalSignal(_) => 0, - }; - Ok(state.end()) - } - - /// Encodes a message to a given buffer, using preencoded state, results size - pub(crate) fn encode( - &self, - state: &mut HypercoreState, - buf: &mut [u8], - ) -> Result { - match self { - Self::Open(ref message) => state.0.encode(message, buf)?, - Self::Close(ref message) => state.0.encode(message, buf)?, - Self::Synchronize(ref message) => state.0.encode(message, buf)?, - Self::Request(ref message) => state.encode(message, buf)?, - Self::Cancel(ref message) => state.0.encode(message, buf)?, - Self::Data(ref message) => state.encode(message, buf)?, - Self::NoData(ref message) => state.0.encode(message, buf)?, - Self::Want(ref message) => state.0.encode(message, buf)?, - Self::Unwant(ref message) => state.0.encode(message, buf)?, - Self::Bitfield(ref message) => state.0.encode(message, buf)?, - Self::Range(ref message) => state.0.encode(message, buf)?, - Self::Extension(ref message) => state.0.encode(message, buf)?, - Self::LocalSignal(_) => 0, - }; - Ok(state.start()) + Message::Open(_) => "Open", + Message::Close(_) => "Close", + Message::Synchronize(_) => "Synchronize", + Message::Request(_) => "Request", + Message::Cancel(_) => "Cancel", + Message::Data(_) => "Data", + Message::NoData(_) => "NoData", + Message::Want(_) => "Want", + Message::Unwant(_) => "Unwant", + Message::Bitfield(_) => "Bitfield", + Message::Range(_) => "Range", + Message::Extension(_) => "Extension", + Message::LocalSignal(_) => "LocalSignal", + } } } @@ -479,7 +308,6 @@ impl fmt::Display for Message { pub(crate) struct ChannelMessage { pub(crate) channel: u64, pub(crate) message: Message, - state: Option, } impl PartialEq for ChannelMessage { @@ -494,14 +322,21 @@ impl fmt::Debug for ChannelMessage { } } +impl fmt::Display for ChannelMessage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "ChannelMessage {{ channel {}, message {} }}", + self.channel, + self.message.name() + ) + } +} + impl ChannelMessage { /// Create a new message. pub(crate) fn new(channel: u64, message: Message) -> Self { - Self { - channel, - message, - state: None, - } + Self { channel, message } } /// Consume self and return (channel, Message). @@ -513,23 +348,24 @@ impl ChannelMessage { /// /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it + #[instrument(skip_all, err)] pub(crate) fn decode_open_message(buf: &[u8]) -> io::Result<(Self, usize)> { - if buf.len() <= 5 { + debug!("Decode ChannelMessage::Open"); + let og_len = buf.len(); + if og_len <= 5 { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "received too short Open message", )); } - let mut state = State::new_with_start_and_end(0, buf.len()); - let open_msg: Open = state.decode(buf)?; + let (open_msg, buf) = Open::decode(buf)?; Ok(( Self { channel: open_msg.channel, message: Message::Open(open_msg), - state: None, }, - state.start(), + og_len - buf.len(), )) } @@ -538,107 +374,154 @@ impl ChannelMessage { /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it pub(crate) fn decode_close_message(buf: &[u8]) -> io::Result<(Self, usize)> { + debug!("Decode ChannelMessage::Close"); + let og_len = buf.len(); if buf.is_empty() { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "received too short Close message", )); } - let mut state = State::new_with_start_and_end(0, buf.len()); - let close_msg: Close = state.decode(buf)?; + let (close, buf) = Close::decode(buf)?; Ok(( Self { - channel: close_msg.channel, - message: Message::Close(close_msg), - state: None, + channel: close.channel, + message: Message::Close(close), }, - state.start(), + og_len - buf.len(), )) } + #[instrument(err, skip_all)] + pub(crate) fn decode_from_channel_and_message( + buf: &[u8], + ) -> Result<(Self, &[u8]), EncodingError> { + //::decode(buf) + let (channel, buf) = u64::decode(buf)?; + let (message, buf) = ::decode(buf)?; + debug!( + "Decode ChannelMessage{{ channel: {channel}, message: {} }}", + message.name() + ); + Ok((Self { channel, message }, buf)) + } /// Decode a normal channel message from a buffer. /// /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it - pub(crate) fn decode(buf: &[u8], channel: u64) -> io::Result<(Self, usize)> { + #[instrument(err, skip(buf))] + pub(crate) fn decode_with_channel(buf: &[u8], channel: u64) -> io::Result<(Self, &[u8])> { if buf.len() <= 1 { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, - "received empty message", + format!("received empty message [{buf:?}]"), )); } - let mut state = State::from_buffer(buf); - let typ: u64 = state.decode(buf)?; - let (message, length) = Message::decode(&buf[state.start()..], typ)?; - Ok(( - Self { - channel, - message, - state: None, - }, - state.start() + length, - )) + let (message, buf) = ::decode(buf)?; + Ok((Self { channel, message }, buf)) } +} - /// Performance optimization for letting calling encoded_len() already do - /// the preencode phase of compact_encoding. - fn prepare_state(&mut self) -> Result<(), EncodingError> { - if self.state.is_none() { - let state = if let Message::Open(_) = self.message { - // Open message doesn't have a type - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L41 - let mut state = HypercoreState::new(); - self.message.preencode(&mut state)?; - state - } else if let Message::Close(_) = self.message { - // Close message doesn't have a type - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L162 - let mut state = HypercoreState::new(); - self.message.preencode(&mut state)?; - state - } else { - // The header is the channel id uint followed by message type uint - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L179 - let mut state = HypercoreState::new(); - let typ = self.message.typ(); - (*state).preencode(&typ)?; - self.message.preencode(&mut state)?; - state - }; - self.state = Some(state); - } - Ok(()) +/// NB: currently this is just for a standalone channel message. ChannelMessages in a vec decode & +/// encode differently +impl CompactEncoding for ChannelMessage { + fn encoded_size(&self) -> Result { + let channel_size = if let Message::Open(_) | Message::Close(_) = &self.message { + 0 + } else { + self.channel.encoded_size()? + }; + + Ok(channel_size + self.message.encoded_size()?) + } + + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let rest = if let Message::Open(_) | Message::Close(_) = &self.message { + buffer + } else { + self.channel.encode(buffer)? + }; + ::encode(&self.message, rest) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + ChannelMessage::decode_from_channel_and_message(buffer) } } -impl Encoder for ChannelMessage { - fn encoded_len(&mut self) -> Result { - self.prepare_state()?; - Ok(self.state.as_ref().unwrap().end()) +impl VecEncodable for ChannelMessage { + #[instrument(skip_all, ret)] + fn vec_encoded_size(vec: &[Self]) -> Result + where + Self: Sized, + { + vec_channel_messages_encoded_size(vec) } - fn encode(&mut self, buf: &mut [u8]) -> Result { - self.prepare_state()?; - let state = self.state.as_mut().unwrap(); - if let Message::Open(_) = self.message { - // Open message is different in that the type byte is missing - self.message.encode(state, buf)?; - } else if let Message::Close(_) = self.message { - // Close message is different in that the type byte is missing - self.message.encode(state, buf)?; - } else { - let typ = self.message.typ(); - state.0.encode(&typ, buf)?; - self.message.encode(state, buf)?; + #[instrument(skip_all, err)] + fn vec_encode<'a>(vec: &[Self], buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> + where + Self: Sized, + { + let in_buf_len = buffer.len(); + trace!( + "Vec::encode to buf.len() = [{}]", + buffer.len() + ); + let mut rest = buffer; + match vec { + [] => Ok(rest), + [msg] => { + rest = match msg.message { + Message::Open(_) => write_array(&OPEN_MESSAGE_PREFIX, rest)?, + Message::Close(_) => write_array(&CLOSE_MESSAGE_PREFIX, rest)?, + _ => msg.channel.encode(rest)?, + }; + msg.message.encode(rest) + } + msgs => { + rest = write_array(&MULTI_MESSAGE_PREFIX, rest)?; + let mut current_channel: u64 = msgs[0].channel; + rest = current_channel.encode(rest)?; + for msg in msgs { + if msg.channel != current_channel { + rest = write_array(&CHANNEL_CHANGE_SEPERATOR, rest)?; + rest = msg.channel.encode(rest)?; + current_channel = msg.channel; + } + let msg_len = msg.message.encoded_size()?; + rest = (msg_len as u64).encode(rest)?; + rest = msg.message.encode(rest)?; + } + trace!("wrote [{}] bytes to buffer", in_buf_len - rest.len()); + Ok(rest) + } } - Ok(state.start()) + } + + fn vec_decode(buffer: &[u8]) -> Result<(Vec, &[u8]), EncodingError> + where + Self: Sized, + { + let mut combined_messages: Vec = vec![]; + let mut rest = buffer; + while !rest.is_empty() { + let (msgs, length) = decode_unframed_channel_messages(rest) + .map_err(|e| EncodingError::external(&format!("{e}")))?; + rest = &rest[length..]; + combined_messages.extend(msgs); + } + Ok((combined_messages, rest)) } } #[cfg(test)] mod tests { use super::*; - use hypercore::{ + use hypercore_schema::{ DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade, }; @@ -646,19 +529,20 @@ mod tests { ($( $msg:expr ),*) => { $( let channel = rand::random::() as u64; - let mut channel_message = ChannelMessage::new(channel, $msg); - let encoded_len = channel_message.encoded_len().expect("Failed to get encoded length"); - let mut buf = vec![0u8; encoded_len]; - let n = channel_message.encode(&mut buf[..]).expect("Failed to encode message"); - let decoded = ChannelMessage::decode(&buf[..n], channel).expect("Failed to decode message").0.into_split(); - assert_eq!(channel, decoded.0); - assert_eq!($msg, decoded.1); + let channel_message = ChannelMessage::new(channel, $msg); + let encoded_size = channel_message.encoded_size()?; + let mut buf = vec![0u8; encoded_size]; + let rest = ::encode(&channel_message, &mut buf)?; + assert!(rest.is_empty()); + let (decoded, rest) = ::decode(&buf)?; + assert!(rest.is_empty()); + assert_eq!(decoded, channel_message); )* } } #[test] - fn message_encode_decode() { + fn message_encode_decode() -> Result<(), EncodingError> { message_enc_dec! { Message::Synchronize(Synchronize{ fork: 0, @@ -685,7 +569,9 @@ mod tests { upgrade: Some(RequestUpgrade { start: 0, length: 10 - }) + }), + manifest: false, + priority: 0 }), Message::Cancel(Cancel { request: 1, @@ -739,5 +625,29 @@ mod tests { message: vec![0x44, 20] }) }; + Ok(()) + } + + #[test] + fn enc_dec_vec_chan_message() -> Result<(), EncodingError> { + let one = Message::Synchronize(Synchronize { + fork: 0, + length: 4, + remote_length: 0, + downloading: true, + uploading: true, + can_upgrade: true, + }); + let two = Message::Range(Range { + drop: false, + start: 0, + length: 4, + }); + let msgs = vec![ChannelMessage::new(1, one), ChannelMessage::new(1, two)]; + let buff = msgs.to_encoded_bytes()?; + let (result, rest) = as CompactEncoding>::decode(&buff)?; + assert!(rest.is_empty()); + assert_eq!(result, msgs); + Ok(()) } } diff --git a/src/mqueue.rs b/src/mqueue.rs new file mode 100644 index 0000000..e732491 --- /dev/null +++ b/src/mqueue.rs @@ -0,0 +1,150 @@ +//! Interface for reading and writing messages to a Stream/Sink +//! +//! This module handles encoding/decoding of `ChannelMessage` to/from raw bytes +//! over an already-encrypted connection (e.g., from hyperswarm). + +use std::{ + collections::VecDeque, + fmt::Debug, + io::Result, + pin::Pin, + task::{Context, Poll}, +}; + +use compact_encoding::CompactEncoding as _; +use futures::{Sink, Stream}; +use hypercore_handshake::{CipherTrait, state_machine::PUBLIC_KEYLEN}; +use tracing::{error, instrument, trace}; + +use crate::message::ChannelMessage; + +/// Message IO layer that encodes/decodes `ChannelMessage` over a byte stream. +/// +/// This expects the underlying stream to already be encrypted and framed +/// (e.g., a hyperswarm `Connection`). +pub(crate) struct MessageIo { + stream: Box, + write_queue: VecDeque, +} + +impl Debug for MessageIo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MessageIo") + .field("write_queue", &self.write_queue) + .finish() + } +} + +impl MessageIo { + /// Create a new MessageIo from a stream. + /// + /// The stream should be an already-encrypted, message-framed connection + /// (e.g., hyperswarm's `Connection` which implements `Stream` + /// where `CipherEvent::Message` contains the decrypted bytes). + pub(crate) fn new(stream: Box) -> Self { + Self { + stream, + write_queue: Default::default(), + } + } + pub(crate) fn remote_public_key(&self) -> Option<[u8; PUBLIC_KEYLEN]> { + self.stream.remote_public_key() + } + pub(crate) fn local_public_key(&self) -> [u8; PUBLIC_KEYLEN] { + self.stream.local_public_key() + } + pub(crate) fn handshake_hash(&self) -> Option> { + self.stream.handshake_hash() + } + + /// Enqueue an outgoing message + pub(crate) fn enqueue(&mut self, msg: ChannelMessage) { + self.write_queue.push_back(msg) + } + + /// Drive outgoing messages + #[instrument(skip_all)] + pub(crate) fn poll_outbound(&mut self, cx: &mut Context<'_>) -> Poll> { + let mut pending = true; + + while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(&mut self.stream), cx) { + pending = false; + if self.write_queue.is_empty() { + break; + } + + // Batch all queued messages + let mut messages = vec![]; + while let Some(msg) = self.write_queue.pop_front() { + messages.push(msg); + } + + let buf = match messages.to_encoded_bytes() { + Ok(x) => x, + Err(e) => { + error!(error = ?e, "error encoding messages"); + return Poll::Ready(Err(e.into())); + } + }; + + if let Err(e) = Sink::start_send(Pin::new(&mut self.stream), buf.to_vec()) { + return Poll::Ready(Err(e)); + } + + match Sink::poll_flush(Pin::new(&mut self.stream), cx) { + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + Poll::Ready(Ok(())) => {} + } + } + + if pending { + cx.waker().wake_by_ref(); + Poll::Pending + } else { + Poll::Ready(Ok(())) + } + } + + /// Poll for incoming messages + #[instrument(skip_all)] + pub(crate) fn poll_inbound( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + match Pin::new(&mut self.stream).poll_next(cx) { + Poll::Ready(Some(event)) => match event { + hypercore_handshake::CipherEvent::HandshakePayload(_x) => Poll::Pending, + hypercore_handshake::CipherEvent::Message(msg) => { + match >::decode(&msg) { + Ok((messages, _rest)) => { + for m in messages.iter() { + trace!("RX ChannelMessage::{m}"); + } + Poll::Ready(Some(Ok(messages))) + } + Err(e) => Poll::Ready(Some(Err(e.into()))), + } + } + hypercore_handshake::CipherEvent::ErrStuff(e) => Poll::Ready(Some(Err(e))), + }, + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl Stream for MessageIo { + type Item = Result>; + + #[instrument(skip_all)] + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Drive outbound messages + let _ = self.poll_outbound(cx); + // Poll for inbound messages + self.poll_inbound(cx) + } +} diff --git a/src/protocol.rs b/src/protocol.rs index 7b8d468..3c716a9 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -1,24 +1,27 @@ use async_channel::{Receiver, Sender}; -use futures_lite::io::{AsyncRead, AsyncWrite}; use futures_lite::stream::Stream; use futures_timer::Delay; -use std::collections::VecDeque; -use std::convert::TryInto; -use std::fmt; -use std::future::Future; -use std::io::{self, Error, ErrorKind, Result}; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::time::Duration; - -use crate::channels::{Channel, ChannelMap}; -use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; -use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; -use crate::message::{ChannelMessage, Frame, FrameType, Message}; -use crate::reader::ReadState; -use crate::schema::*; -use crate::util::{map_channel_err, pretty_hash}; -use crate::writer::WriteState; +use hypercore_handshake::{CipherTrait, state_machine::PUBLIC_KEYLEN}; +use std::{ + collections::VecDeque, + convert::TryInto, + fmt, + io::{self, Result}, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; +use tracing::{error, instrument}; + +use crate::{ + channels::{Channel, ChannelMap}, + constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}, + crypto::HandshakeResult, + message::{ChannelMessage, Message}, + mqueue::MessageIo, + schema::*, + util::{map_channel_err, pretty_hash}, +}; macro_rules! return_error { ($msg:expr) => { @@ -29,31 +32,6 @@ macro_rules! return_error { } const CHANNEL_CAP: usize = 1000; -const KEEPALIVE_DURATION: Duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64); - -/// Options for a Protocol instance. -#[derive(Debug)] -pub(crate) struct Options { - /// Whether this peer initiated the IO connection for this protoccol - pub(crate) is_initiator: bool, - /// Enable or disable the handshake. - /// Disabling the handshake will also disable capabilitity verification. - /// Don't disable this if you're not 100% sure you want this. - pub(crate) noise: bool, - /// Enable or disable transport encryption. - pub(crate) encrypted: bool, -} - -impl Options { - /// Create with default options. - pub(crate) fn new(is_initiator: bool) -> Self { - Self { - is_initiator, - noise: true, - encrypted: true, - } - } -} /// Remote public key (32 bytes). pub(crate) type RemotePublicKey = [u8; 32]; @@ -67,7 +45,7 @@ pub type Key = [u8; 32]; #[derive(PartialEq)] pub enum Event { /// Emitted after the handshake with the remote peer is complete. - /// This is the first event (if the handshake is not disabled). + /// This is the first event. Handshake(RemotePublicKey), /// Emitted when the remote peer opens a channel that we did not yet open. DiscoveryKey(DiscoveryKey), @@ -111,116 +89,77 @@ impl fmt::Debug for Event { } } -/// Protocol state -#[allow(clippy::large_enum_variant)] -pub(crate) enum State { - NotInitialized, - // The Handshake struct sits behind an option only so that we can .take() - // it out, it's never actually empty when in State::Handshake. - Handshake(Option), - SecretStream(Option), - Established, -} - -impl fmt::Debug for State { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - State::NotInitialized => write!(f, "NotInitialized"), - State::Handshake(_) => write!(f, "Handshaking"), - State::SecretStream(_) => write!(f, "SecretStream"), - State::Established => write!(f, "Established"), - } - } -} - -/// A Protocol stream. -pub struct Protocol { - write_state: WriteState, - read_state: ReadState, - io: IO, - state: State, - options: Options, - handshake: Option, +/// A Protocol stream for replicating hypercores over an encrypted connection. +/// +/// The protocol expects an already-encrypted, message-framed connection +/// (e.g., from hyperswarm). The `HandshakeResult` provides the handshake hash +/// and public keys needed for capability verification. +pub struct Protocol { + io: MessageIo, + is_initiator: bool, channels: ChannelMap, command_rx: Receiver, command_tx: CommandTx, outbound_rx: Receiver>, outbound_tx: Sender>, + #[allow(dead_code)] // TODO: Implement keepalive keepalive: Delay, queued_events: VecDeque, + handshake_emitted: bool, } -impl std::fmt::Debug for Protocol { +impl std::fmt::Debug for Protocol { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Protocol") - .field("write_state", &self.write_state) - .field("read_state", &self.read_state) - //.field("io", &self.io) - .field("state", &self.state) - .field("options", &self.options) - .field("handshake", &self.handshake) + .field("is_initiator", &self.is_initiator) .field("channels", &self.channels) - .field("command_rx", &self.command_rx) - .field("command_tx", &self.command_tx) - .field("outbound_rx", &self.outbound_rx) - .field("outbound_tx", &self.outbound_tx) - .field("keepalive", &self.keepalive) + .field("handshake_emitted", &self.handshake_emitted) .field("queued_events", &self.queued_events) .finish() } } -impl Protocol -where - IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, -{ +impl Protocol { /// Create a new protocol instance. - pub(crate) fn new(io: IO, options: Options) -> Self { + /// + /// # Arguments + /// * `stream` - An already-encrypted, message-framed connection (e.g., hyperswarm `Connection`) + pub fn new(stream: Box) -> Self { let (command_tx, command_rx) = async_channel::bounded(CHANNEL_CAP); let (outbound_tx, outbound_rx): ( Sender>, Receiver>, - ) = async_channel::bounded(1); + ) = async_channel::bounded(CHANNEL_CAP); + + let is_initiator = stream.is_initiator(); + Protocol { - io, - read_state: ReadState::new(), - write_state: WriteState::new(), - options, - state: State::NotInitialized, + io: MessageIo::new(stream), + is_initiator, channels: ChannelMap::new(), - handshake: None, command_rx, command_tx: CommandTx(command_tx), outbound_tx, outbound_rx, keepalive: Delay::new(Duration::from_secs(DEFAULT_KEEPALIVE as u64)), queued_events: VecDeque::new(), + handshake_emitted: false, } } /// Whether this protocol stream initiated the underlying IO connection. pub fn is_initiator(&self) -> bool { - self.options.is_initiator + self.is_initiator } /// Get your own Noise public key. - /// - /// Empty before the handshake completed. - pub fn public_key(&self) -> Option<&[u8]> { - match &self.handshake { - None => None, - Some(handshake) => Some(handshake.local_pubkey.as_slice()), - } + pub fn public_key(&self) -> [u8; PUBLIC_KEYLEN] { + self.io.local_public_key() } /// Get the remote's Noise public key. - /// - /// Empty before the handshake completed. - pub fn remote_public_key(&self) -> Option<&[u8]> { - match &self.handshake { - None => None, - Some(handshake) => Some(handshake.remote_pubkey.as_slice()), - } + pub fn remote_public_key(&self) -> Option<[u8; PUBLIC_KEYLEN]> { + self.io.remote_public_key() } /// Get a sender to send commands. @@ -229,7 +168,7 @@ where } /// Give a command to the protocol. - pub async fn command(&mut self, command: Command) -> Result<()> { + pub async fn command(&self, command: Command) -> Result<()> { self.command_tx.send(command).await } @@ -237,7 +176,7 @@ where /// /// Once the other side proofed that it also knows the `key`, the channel is emitted as /// `Event::Channel` on the protocol event stream. - pub async fn open(&mut self, key: Key) -> Result<()> { + pub async fn open(&self, key: Key) -> Result<()> { self.command_tx.open(key).await } @@ -246,16 +185,27 @@ where self.channels.iter().map(|c| c.discovery_key()) } - /// Stop the protocol and return the inner reader and writer. - pub fn release(self) -> IO { - self.io - } - + #[instrument(skip_all, fields(initiator = ?self.is_initiator()))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - if let State::NotInitialized = this.state { - return_error!(this.init()); + // Initiator needs to send and receive a message before proceeding + if this.is_initiator && this.io.handshake_hash().is_none() { + return_error!(this.poll_outbound_write(cx)); + return_error!(this.poll_inbound_read(cx)); + if this.io.handshake_hash().is_none() { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + } + // Emit handshake event on first poll + if !this.handshake_emitted { + if let Some(remote_pubkey) = this.io.remote_public_key() { + this.handshake_emitted = true; + return Poll::Ready(Ok(Event::Handshake(remote_pubkey))); + } else { + cx.waker().wake_by_ref(); + } } // Drain queued events first. @@ -266,10 +216,8 @@ where // Read and process incoming messages. return_error!(this.poll_inbound_read(cx)); - if let State::Established = this.state { - // Check for commands, but only once the connection is established. - return_error!(this.poll_commands(cx)); - } + // Check for commands. + return_error!(this.poll_commands(cx)); // Poll the keepalive timer. this.poll_keepalive(cx); @@ -285,43 +233,21 @@ where } } - fn init(&mut self) -> Result<()> { - tracing::debug!( - "protocol init, state {:?}, options {:?}", - self.state, - self.options - ); - match self.state { - State::NotInitialized => {} - _ => return Ok(()), - }; - - self.state = if self.options.noise { - let mut handshake = Handshake::new(self.options.is_initiator)?; - // If the handshake start returns a buffer, send it now. - if let Some(buf) = handshake.start()? { - self.queue_frame_direct(buf.to_vec()).unwrap(); - } - self.read_state.set_frame_type(FrameType::Raw); - State::Handshake(Some(handshake)) - } else { - self.read_state.set_frame_type(FrameType::Message); - State::Established - }; - - Ok(()) - } - /// Poll commands. fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> { while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) { - self.on_command(command)?; + if let Err(e) = self.on_command(command) { + error!(error = ?e, "Error handling command"); + return Err(e); + } } Ok(()) } - /// Poll the keepalive timer and queue a ping message if needed. - fn poll_keepalive(&mut self, cx: &mut Context<'_>) { + /// TODO Poll the keepalive timer and queue a ping message if needed. + fn poll_keepalive(&self, _cx: &mut Context<'_>) { + /* + const KEEPALIVE_DURATION: Duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64); if Pin::new(&mut self.keepalive).poll(cx).is_ready() { if let State::Established = self.state { // 24 bit header for the empty message, hence the 3 @@ -330,8 +256,10 @@ where } self.keepalive.reset(KEEPALIVE_DURATION); } + */ } + // just handles Close and LocalSignal fn on_outbound_message(&mut self, message: &ChannelMessage) -> bool { // If message is close, close the local channel. if let ChannelMessage { @@ -354,37 +282,37 @@ where true } - /// Poll for inbound messages and processs them. + /// Poll for inbound messages and process them. + #[instrument(skip_all, err)] fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { - let msg = self.read_state.poll_reader(cx, &mut self.io); - match msg { - Poll::Ready(Ok(message)) => { - self.on_inbound_frame(message)?; + match self.io.poll_inbound(cx) { + Poll::Ready(Some(result)) => { + let messages = result?; + self.on_inbound_channel_messages(messages)?; } - Poll::Ready(Err(e)) => return Err(e), + Poll::Ready(None) => return Ok(()), Poll::Pending => return Ok(()), } } } /// Poll for outbound messages and write them. + #[instrument(skip_all)] fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { - if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) { + // Drive outbound IO + if let Poll::Ready(Err(e)) = self.io.poll_outbound(cx) { + error!(err = ?e, "error from poll_outbound"); return Err(e); } - if !self.write_state.can_park_frame() || !matches!(self.state, State::Established) { - return Ok(()); - } - + // Send messages from outbound_rx match Pin::new(&mut self.outbound_rx).poll_next(cx) { Poll::Ready(Some(mut messages)) => { if !messages.is_empty() { messages.retain(|message| self.on_outbound_message(message)); - if !messages.is_empty() { - let frame = Frame::MessageBatch(messages); - self.write_state.park_frame(frame); + for msg in messages { + self.io.enqueue(msg); } } } @@ -394,121 +322,16 @@ where } } - fn on_inbound_frame(&mut self, frame: Frame) -> Result<()> { - match frame { - Frame::RawBatch(raw_batch) => { - let mut processed_state: Option = None; - for buf in raw_batch { - let state_name: String = format!("{:?}", self.state); - match self.state { - State::Handshake(_) => self.on_handshake_message(buf)?, - State::SecretStream(_) => self.on_secret_stream_message(buf)?, - State::Established => { - if let Some(processed_state) = processed_state.as_ref() { - let previous_state = if self.options.encrypted { - State::SecretStream(None) - } else { - State::Handshake(None) - }; - if processed_state == &format!("{previous_state:?}") { - // This is the unlucky case where the batch had two or more messages where - // the first one was correctly identified as Raw but everything - // after that should have been (decrypted and) a MessageBatch. Correct the mistake - // here post-hoc. - let buf = self.read_state.decrypt_buf(&buf)?; - let frame = Frame::decode(&buf, &FrameType::Message)?; - self.on_inbound_frame(frame)?; - continue; - } - } - unreachable!( - "May not receive raw frames in Established state" - ) - } - _ => unreachable!( - "May not receive raw frames outside of handshake or secretstream state, was {:?}", - self.state - ), - }; - if processed_state.is_none() { - processed_state = Some(state_name) - } - } - Ok(()) - } - Frame::MessageBatch(channel_messages) => match self.state { - State::Established => { - for channel_message in channel_messages { - self.on_inbound_message(channel_message)? - } - Ok(()) - } - _ => unreachable!("May not receive message batch frames when not established"), - }, - } - } - - fn on_handshake_message(&mut self, buf: Vec) -> Result<()> { - let mut handshake = match &mut self.state { - State::Handshake(handshake) => handshake.take().unwrap(), - _ => unreachable!("May not call on_handshake_message when not in Handshake state"), - }; - - if let Some(response_buf) = handshake.read(&buf)? { - self.queue_frame_direct(response_buf.to_vec()).unwrap(); + #[instrument(skip_all)] + fn on_inbound_channel_messages(&mut self, channel_messages: Vec) -> Result<()> { + for channel_message in channel_messages { + self.on_inbound_message(channel_message)? } - - if !handshake.complete() { - self.state = State::Handshake(Some(handshake)); - } else { - let handshake_result = handshake.into_result()?; - - if self.options.encrypted { - // The cipher will be put to use to the writer only after the peer's answer has come - let (cipher, init_msg) = EncryptCipher::from_handshake_tx(&handshake_result)?; - self.state = State::SecretStream(Some(cipher)); - - // Send the secret stream init message header to the other side - self.queue_frame_direct(init_msg).unwrap(); - } else { - // Skip secret stream and go straight to Established, then notify about - // handshake - self.read_state.set_frame_type(FrameType::Message); - let remote_public_key = parse_key(&handshake_result.remote_pubkey)?; - self.queue_event(Event::Handshake(remote_public_key)); - self.state = State::Established; - } - // Store handshake result - self.handshake = Some(handshake_result); - } - Ok(()) - } - - fn on_secret_stream_message(&mut self, buf: Vec) -> Result<()> { - let encrypt_cipher = match &mut self.state { - State::SecretStream(encrypt_cipher) => encrypt_cipher.take().unwrap(), - _ => { - unreachable!("May not call on_secret_stream_message when not in SecretStream state") - } - }; - let handshake_result = &self - .handshake - .as_ref() - .expect("Handshake result must be set before secret stream"); - let decrypt_cipher = DecryptCipher::from_handshake_rx_and_init_msg(handshake_result, &buf)?; - self.read_state.upgrade_with_decrypt_cipher(decrypt_cipher); - self.write_state.upgrade_with_encrypt_cipher(encrypt_cipher); - self.read_state.set_frame_type(FrameType::Message); - - // Lastly notify that handshake is ready and set state to established - let remote_public_key = parse_key(&handshake_result.remote_pubkey)?; - self.queue_event(Event::Handshake(remote_public_key)); - self.state = State::Established; Ok(()) } + #[instrument(skip_all)] fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> { - // let channel_message = ChannelMessage::decode(buf)?; let (remote_id, message) = channel_message.into_split(); match message { Message::Open(msg) => self.on_open(remote_id, msg)?, @@ -520,6 +343,7 @@ where Ok(()) } + #[instrument(skip(self))] fn on_command(&mut self, command: Command) -> Result<()> { match command { Command::Open(key) => self.command_open(key), @@ -529,6 +353,7 @@ where } /// Open a Channel with the given key. Adding it to our channel map + #[instrument(skip_all)] fn command_open(&mut self, key: Key) -> Result<()> { // Create a new channel. let channel_handle = self.channels.attach_local(key); @@ -552,8 +377,7 @@ where capability, }); let channel_message = ChannelMessage::new(channel, message); - self.write_state - .queue_frame(Frame::MessageBatch(vec![channel_message])); + self.io.enqueue(channel_message); Ok(()) } @@ -570,6 +394,7 @@ where Ok(()) } + #[instrument(skip(self))] fn on_open(&mut self, ch: u64, msg: Open) -> Result<()> { let discovery_key: DiscoveryKey = parse_key(&msg.discovery_key)?; let channel_handle = @@ -586,18 +411,16 @@ where Ok(()) } + #[instrument(skip(self))] fn queue_event(&mut self, event: Event) { self.queued_events.push_back(event); } - fn queue_frame_direct(&mut self, body: Vec) -> Result { - let mut frame = Frame::RawBatch(vec![body]); - self.write_state.try_queue_direct(&mut frame) - } - + #[instrument(skip(self))] fn accept_channel(&mut self, local_id: usize) -> Result<()> { let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?; - self.verify_remote_capability(remote_capability.cloned(), key)?; + self.verify_remote_capability(remote_capability.cloned(), key) + .expect("TODO channel can only be accepted after first message")?; let channel = self.channels.accept(local_id, self.outbound_tx.clone())?; self.queue_event(Event::Channel(channel)); Ok(()) @@ -624,57 +447,77 @@ where Ok(()) } + #[instrument(skip_all)] fn capability(&self, key: &[u8]) -> Option> { - match self.handshake.as_ref() { - Some(handshake) => handshake.capability(key), - None => None, - } - } - - fn verify_remote_capability(&self, capability: Option>, key: &[u8]) -> Result<()> { - match self.handshake.as_ref() { - Some(handshake) => handshake.verify_remote_capability(capability, key), - None => Err(Error::new( - ErrorKind::PermissionDenied, - "Missing handshake state for capability verification", - )), - } + let is_initiator = self.is_initiator; + let remote_pubkey = self.remote_public_key()?; + let local_pubkey = self.public_key(); + let handshake_hash = self.io.handshake_hash()?; + HandshakeResult::from_pre_encrypted( + is_initiator, + local_pubkey, + remote_pubkey, + handshake_hash.to_vec(), + ) + .capability(key) + } + + #[instrument(skip_all)] + fn verify_remote_capability( + &self, + capability: Option>, + key: &[u8], + ) -> Option> { + let is_initiator = self.is_initiator; + let remote_pubkey = self.remote_public_key()?; + let local_pubkey = self.public_key(); + let handshake_hash = self.io.handshake_hash()?; + Some( + HandshakeResult::from_pre_encrypted( + is_initiator, + local_pubkey, + remote_pubkey, + handshake_hash.to_vec(), + ) + .verify_remote_capability(capability, key), + ) } } -impl Stream for Protocol -where - IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, -{ +impl Stream for Protocol { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Protocol::poll_next(self, cx).map(Some) + match Protocol::poll_next(self, cx) { + Poll::Ready(Ok(e)) => Poll::Ready(Some(Ok(e))), + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), + Poll::Pending => Poll::Pending, + } } } -/// Send [Command](Command)s to the [Protocol](Protocol). +/// Send [`Command`]s to the [`Protocol`]. #[derive(Clone, Debug)] pub struct CommandTx(Sender); impl CommandTx { /// Send a protocol command - pub async fn send(&mut self, command: Command) -> Result<()> { + pub async fn send(&self, command: Command) -> Result<()> { self.0.send(command).await.map_err(map_channel_err) } /// Open a protocol channel. /// /// The channel will be emitted on the main protocol. - pub async fn open(&mut self, key: Key) -> Result<()> { + pub async fn open(&self, key: Key) -> Result<()> { self.send(Command::Open(key)).await } /// Close a protocol channel. - pub async fn close(&mut self, discovery_key: DiscoveryKey) -> Result<()> { + pub async fn close(&self, discovery_key: DiscoveryKey) -> Result<()> { self.send(Command::Close(discovery_key)).await } /// Send a local signal event to the protocol. - pub async fn signal_local(&mut self, name: &str, data: Vec) -> Result<()> { + pub async fn signal_local(&self, name: &str, data: Vec) -> Result<()> { self.send(Command::SignalLocal((name.to_string(), data))) .await } diff --git a/src/reader.rs b/src/reader.rs deleted file mode 100644 index 51b370b..0000000 --- a/src/reader.rs +++ /dev/null @@ -1,231 +0,0 @@ -use crate::crypto::DecryptCipher; -use futures_lite::io::AsyncRead; -use futures_timer::Delay; -use std::future::Future; -use std::io::{Error, ErrorKind, Result}; -use std::pin::Pin; -use std::task::{Context, Poll}; - -use crate::constants::{DEFAULT_TIMEOUT, MAX_MESSAGE_SIZE}; -use crate::message::{Frame, FrameType}; -use crate::util::stat_uint24_le; -use std::time::Duration; - -const TIMEOUT: Duration = Duration::from_secs(DEFAULT_TIMEOUT as u64); -const READ_BUF_INITIAL_SIZE: usize = 1024 * 128; - -#[derive(Debug)] -pub(crate) struct ReadState { - /// The read buffer. - buf: Vec, - /// The start of the not-yet-processed byte range in the read buffer. - start: usize, - /// The end of the not-yet-processed byte range in the read buffer. - end: usize, - /// The logical state of the reading (either header or body). - step: Step, - /// The timeout after which the connection is closed. - timeout: Delay, - /// Optional decryption cipher. - cipher: Option, - /// The frame type to be passed to the decoder. - frame_type: FrameType, -} - -impl ReadState { - pub(crate) fn new() -> ReadState { - ReadState { - buf: vec![0u8; READ_BUF_INITIAL_SIZE], - start: 0, - end: 0, - step: Step::Header, - timeout: Delay::new(TIMEOUT), - cipher: None, - frame_type: FrameType::Raw, - } - } -} - -#[derive(Debug)] -enum Step { - Header, - Body { - header_len: usize, - body_len: usize, - }, - /// Multiple messages one after another - Batch, -} - -impl ReadState { - pub(crate) fn upgrade_with_decrypt_cipher(&mut self, decrypt_cipher: DecryptCipher) { - self.cipher = Some(decrypt_cipher); - } - - /// Decrypts a given buf with stored cipher, if present. Used to correct - /// the rare mistake that more than two messages came in where the first - /// one created the cipher, and the next one should have been decrypted - /// but wasn't. - pub(crate) fn decrypt_buf(&mut self, buf: &[u8]) -> Result> { - if let Some(cipher) = self.cipher.as_mut() { - Ok(cipher.decrypt_buf(buf)?.0) - } else { - Ok(buf.to_vec()) - } - } - - pub(crate) fn set_frame_type(&mut self, frame_type: FrameType) { - self.frame_type = frame_type; - } - - pub(crate) fn poll_reader( - &mut self, - cx: &mut Context<'_>, - mut reader: &mut R, - ) -> Poll> - where - R: AsyncRead + Unpin, - { - let mut incomplete = true; - loop { - if !incomplete { - if let Some(result) = self.process() { - return Poll::Ready(result); - } - } else { - incomplete = false; - } - let n = match Pin::new(&mut reader).poll_read(cx, &mut self.buf[self.end..]) { - Poll::Ready(Ok(n)) if n > 0 => n, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - // If the reader is pending, poll the timeout. - Poll::Pending | Poll::Ready(Ok(_)) => { - // Return Pending if the timeout is pending, or an error if the - // timeout expired (i.e. returned Poll::Ready). - return Pin::new(&mut self.timeout) - .poll(cx) - .map(|()| Err(Error::new(ErrorKind::TimedOut, "Remote timed out"))); - } - }; - - let end = self.end + n; - let (success, segments) = create_segments(&self.buf[self.start..end])?; - if success { - if let Some(ref mut cipher) = self.cipher { - let mut dec_end = self.start; - for (index, header_len, body_len) in segments { - let de = cipher.decrypt( - &mut self.buf[self.start + index..end], - header_len, - body_len, - )?; - dec_end = self.start + index + de; - } - self.end = dec_end; - } else { - self.end = end; - } - } else { - // Could not segment due to buffer being full, need to cycle the buffer - // and possibly resize it too if the message is too big. - self.cycle_buf_and_resize_if_needed(segments[segments.len() - 1]); - - // Set incomplete flag to skip processing and instead poll more data - incomplete = true; - } - self.timeout.reset(TIMEOUT); - } - } - - fn cycle_buf_and_resize_if_needed(&mut self, last_segment: (usize, usize, usize)) { - let (last_index, last_header_len, last_body_len) = last_segment; - let total_incoming_length = last_index + last_header_len + last_body_len; - if self.buf.len() < total_incoming_length { - // The incoming segments will not fit into the buffer, need to resize it - self.buf.resize(total_incoming_length, 0u8); - } - let temp = self.buf[self.start..].to_vec(); - let len = temp.len(); - self.buf[..len].copy_from_slice(&temp[..]); - self.end = len; - self.start = 0; - } - - fn process(&mut self) -> Option> { - loop { - match self.step { - Step::Header => { - let stat = stat_uint24_le(&self.buf[self.start..self.end]); - if let Some((header_len, body_len)) = stat { - if body_len == 0 { - // This is a keepalive message, just remain in Step::Header - self.start += header_len; - return None; - } else if (self.start + header_len + body_len as usize) < self.end { - // There are more than one message here, create a batch from all of - // then - self.step = Step::Batch; - } else { - let body_len = body_len as usize; - if body_len > MAX_MESSAGE_SIZE as usize { - return Some(Err(Error::new( - ErrorKind::InvalidData, - "Message length above max allowed size", - ))); - } - self.step = Step::Body { - header_len, - body_len, - }; - } - } else { - return Some(Err(Error::new(ErrorKind::InvalidData, "Invalid header"))); - } - } - - Step::Body { - header_len, - body_len, - } => { - let message_len = header_len + body_len; - let range = self.start + header_len..self.start + message_len; - let frame = Frame::decode(&self.buf[range], &self.frame_type); - self.start += message_len; - self.step = Step::Header; - return Some(frame); - } - Step::Batch => { - let frame = - Frame::decode_multiple(&self.buf[self.start..self.end], &self.frame_type); - self.start = self.end; - self.step = Step::Header; - return Some(frame); - } - } - } - } -} - -#[allow(clippy::type_complexity)] -fn create_segments(buf: &[u8]) -> Result<(bool, Vec<(usize, usize, usize)>)> { - let mut index: usize = 0; - let len = buf.len(); - let mut segments: Vec<(usize, usize, usize)> = vec![]; - while index < len { - if let Some((header_len, body_len)) = stat_uint24_le(&buf[index..]) { - let body_len = body_len as usize; - segments.push((index, header_len, body_len)); - if len < index + header_len + body_len { - // The segments will not fit, return false to indicate that more needs to be read - return Ok((false, segments)); - } - index += header_len + body_len; - } else { - return Err(Error::new( - ErrorKind::InvalidData, - "Could not read header while decrypting", - )); - } - } - Ok((true, segments)) -} diff --git a/src/schema.rs b/src/schema.rs index ef58e77..09e5122 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,7 +1,11 @@ -use hypercore::encoding::{CompactEncoding, EncodingError, HypercoreState, State}; -use hypercore::{ +use compact_encoding::{ + CompactEncoding, EncodingError, map_decode, map_encode, sum_encoded_size, take_array, + take_array_mut, write_array, write_slice, +}; +use hypercore_schema::{ DataBlock, DataHash, DataSeek, DataUpgrade, Proof, RequestBlock, RequestSeek, RequestUpgrade, }; +use tracing::instrument; /// Open message #[derive(Debug, Clone, PartialEq)] @@ -16,46 +20,55 @@ pub struct Open { pub capability: Option>, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Open) -> Result { - self.preencode(&value.channel)?; - self.preencode(&value.protocol)?; - self.preencode(&value.discovery_key)?; - if value.capability.is_some() { - self.add_end(1)?; // flags for future use - self.preencode_fixed_32()?; +impl CompactEncoding for Open { + #[instrument(skip_all, ret, err)] + fn encoded_size(&self) -> Result { + let out = sum_encoded_size!(self.channel, self.protocol, self.discovery_key); + if self.capability.is_some() { + return Ok( + out + + 1 // flags for future use + + 32, // TODO capabalilities buff should always be 32 bytes, but it's a vec + ); } - Ok(self.end()) + Ok(out) } - fn encode(&mut self, value: &Open, buffer: &mut [u8]) -> Result { - self.encode(&value.channel, buffer)?; - self.encode(&value.protocol, buffer)?; - self.encode(&value.discovery_key, buffer)?; - if let Some(capability) = &value.capability { - self.add_start(1)?; // flags for future use - self.encode_fixed_32(capability, buffer)?; + #[instrument(skip_all)] + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let rest = map_encode!(buffer, self.channel, self.protocol, self.discovery_key); + if let Some(cap) = &self.capability { + let (_, rest) = take_array_mut::<1>(rest)?; + return write_slice(cap, rest); } - Ok(self.start()) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let channel: u64 = self.decode(buffer)?; - let protocol: String = self.decode(buffer)?; - let discovery_key: Vec = self.decode(buffer)?; - let capability: Option> = if self.start() < self.end() { - self.add_start(1)?; // flags for future use - let capability: Vec = self.decode_fixed_32(buffer)?.to_vec(); - Some(capability) + Ok(rest) + } + + #[instrument(skip_all, err)] + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ((channel, protocol, discovery_key), rest) = + map_decode!(buffer, [u64, String, Vec]); + // NB: Open/Close are only sent alone in their own Frame. So we're done when there is no + // more data + let (capability, rest) = if !rest.is_empty() { + let (_, rest) = take_array::<1>(rest)?; + let (capability, rest) = take_array::<32>(rest)?; + (Some(capability.to_vec()), rest) } else { - None + (None, rest) }; - Ok(Open { - channel, - protocol, - discovery_key, - capability, - }) + Ok(( + Self { + channel, + protocol, + discovery_key, + capability, + }, + rest, + )) } } @@ -66,18 +79,21 @@ pub struct Close { pub channel: u64, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Close) -> Result { - self.preencode(&value.channel) +impl CompactEncoding for Close { + fn encoded_size(&self) -> Result { + self.channel.encoded_size() } - fn encode(&mut self, value: &Close, buffer: &mut [u8]) -> Result { - self.encode(&value.channel, buffer) + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + self.channel.encode(buffer) } - fn decode(&mut self, buffer: &[u8]) -> Result { - let channel: u64 = self.decode(buffer)?; - Ok(Close { channel }) + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let (channel, rest) = u64::decode(buffer)?; + Ok((Self { channel }, rest)) } } @@ -98,40 +114,44 @@ pub struct Synchronize { pub can_upgrade: bool, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Synchronize) -> Result { - self.add_end(1)?; // flags - self.preencode(&value.fork)?; - self.preencode(&value.length)?; - self.preencode(&value.remote_length) - } - - fn encode(&mut self, value: &Synchronize, buffer: &mut [u8]) -> Result { - let mut flags: u8 = if value.can_upgrade { 1 } else { 0 }; - flags |= if value.uploading { 2 } else { 0 }; - flags |= if value.downloading { 4 } else { 0 }; - self.encode(&flags, buffer)?; - self.encode(&value.fork, buffer)?; - self.encode(&value.length, buffer)?; - self.encode(&value.remote_length, buffer) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let flags: u8 = self.decode(buffer)?; - let fork: u64 = self.decode(buffer)?; - let length: u64 = self.decode(buffer)?; - let remote_length: u64 = self.decode(buffer)?; +impl CompactEncoding for Synchronize { + fn encoded_size(&self) -> Result { + Ok(1 + sum_encoded_size!(self.fork, self.length, self.remote_length)) + } + + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let mut flags: u8 = if self.can_upgrade { 1 } else { 0 }; + flags |= if self.uploading { 2 } else { 0 }; + flags |= if self.downloading { 4 } else { 0 }; + let rest = write_array(&[flags], buffer)?; + Ok(map_encode!( + rest, + self.fork, + self.length, + self.remote_length + )) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ([flags], rest) = take_array::<1>(buffer)?; + let ((fork, length, remote_length), rest) = map_decode!(rest, [u64, u64, u64]); let can_upgrade = flags & 1 != 0; let uploading = flags & 2 != 0; let downloading = flags & 4 != 0; - Ok(Synchronize { - fork, - length, - remote_length, - can_upgrade, - uploading, - downloading, - }) + Ok(( + Synchronize { + fork, + length, + remote_length, + can_upgrade, + uploading, + downloading, + }, + rest, + )) } } @@ -150,83 +170,108 @@ pub struct Request { pub seek: Option, /// Request upgrade pub upgrade: Option, + // TODO what is this + /// Request manifest + pub manifest: bool, + // TODO what is this + // this could prob be usize + /// Request priority + pub priority: u64, } -impl CompactEncoding for HypercoreState { - fn preencode(&mut self, value: &Request) -> Result { - self.add_end(1)?; // flags - self.0.preencode(&value.id)?; - self.0.preencode(&value.fork)?; - if let Some(block) = &value.block { - self.preencode(block)?; +macro_rules! maybe_decode { + ($cond:expr, $type:ty, $buf:ident) => { + if $cond { + let (result, rest) = <$type>::decode($buf)?; + (Some(result), rest) + } else { + (None, $buf) } - if let Some(hash) = &value.hash { - self.preencode(hash)?; + }; +} + +impl CompactEncoding for Request { + fn encoded_size(&self) -> Result { + let mut out = 1; // flags + out += sum_encoded_size!(self.id, self.fork); + if let Some(block) = &self.block { + out += block.encoded_size()?; } - if let Some(seek) = &value.seek { - self.preencode(seek)?; + if let Some(hash) = &self.hash { + out += hash.encoded_size()?; } - if let Some(upgrade) = &value.upgrade { - self.preencode(upgrade)?; + if let Some(seek) = &self.seek { + out += seek.encoded_size()?; } - Ok(self.end()) - } - - fn encode(&mut self, value: &Request, buffer: &mut [u8]) -> Result { - let mut flags: u8 = if value.block.is_some() { 1 } else { 0 }; - flags |= if value.hash.is_some() { 2 } else { 0 }; - flags |= if value.seek.is_some() { 4 } else { 0 }; - flags |= if value.upgrade.is_some() { 8 } else { 0 }; - self.0.encode(&flags, buffer)?; - self.0.encode(&value.id, buffer)?; - self.0.encode(&value.fork, buffer)?; - if let Some(block) = &value.block { - self.encode(block, buffer)?; + if let Some(upgrade) = &self.upgrade { + out += upgrade.encoded_size()?; + } + if self.priority != 0 { + out += self.priority.encoded_size()?; + } + Ok(out) + } + + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let mut flags: u8 = if self.block.is_some() { 1 } else { 0 }; + flags |= if self.hash.is_some() { 2 } else { 0 }; + flags |= if self.seek.is_some() { 4 } else { 0 }; + flags |= if self.upgrade.is_some() { 8 } else { 0 }; + flags |= if self.manifest { 16 } else { 0 }; + flags |= if self.priority != 0 { 32 } else { 0 }; + let mut rest = write_array(&[flags], buffer)?; + rest = map_encode!(rest, self.id, self.fork); + + if let Some(block) = &self.block { + rest = block.encode(rest)?; } - if let Some(hash) = &value.hash { - self.encode(hash, buffer)?; + if let Some(hash) = &self.hash { + rest = hash.encode(rest)?; } - if let Some(seek) = &value.seek { - self.encode(seek, buffer)?; + if let Some(seek) = &self.seek { + rest = seek.encode(rest)?; } - if let Some(upgrade) = &value.upgrade { - self.encode(upgrade, buffer)?; + if let Some(upgrade) = &self.upgrade { + rest = upgrade.encode(rest)?; } - Ok(self.start()) + + if self.priority != 0 { + rest = self.priority.encode(rest)?; + } + + Ok(rest) } - fn decode(&mut self, buffer: &[u8]) -> Result { - let flags: u8 = self.0.decode(buffer)?; - let id: u64 = self.0.decode(buffer)?; - let fork: u64 = self.0.decode(buffer)?; - let block: Option = if flags & 1 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let hash: Option = if flags & 2 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let seek: Option = if flags & 4 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let upgrade: Option = if flags & 8 != 0 { - Some(self.decode(buffer)?) + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ([flags], rest) = take_array::<1>(buffer)?; + let ((id, fork), rest) = map_decode!(rest, [u64, u64]); + + let (block, rest) = maybe_decode!(flags & 1 != 0, RequestBlock, rest); + let (hash, rest) = maybe_decode!(flags & 2 != 0, RequestBlock, rest); + let (seek, rest) = maybe_decode!(flags & 4 != 0, RequestSeek, rest); + let (upgrade, rest) = maybe_decode!(flags & 8 != 0, RequestUpgrade, rest); + let manifest = flags & 16 != 0; + let (priority, rest) = if flags & 32 != 0 { + u64::decode(rest)? } else { - None + (0, rest) }; - Ok(Request { - id, - fork, - block, - hash, - seek, - upgrade, - }) + Ok(( + Request { + id, + fork, + block, + hash, + seek, + upgrade, + manifest, + priority, + }, + rest, + )) } } @@ -237,18 +282,21 @@ pub struct Cancel { pub request: u64, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Cancel) -> Result { - self.preencode(&value.request) +impl CompactEncoding for Cancel { + fn encoded_size(&self) -> Result { + self.request.encoded_size() } - fn encode(&mut self, value: &Cancel, buffer: &mut [u8]) -> Result { - self.encode(&value.request, buffer) + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + self.request.encode(buffer) } - fn decode(&mut self, buffer: &[u8]) -> Result { - let request: u64 = self.decode(buffer)?; - Ok(Cancel { request }) + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let (request, rest) = u64::decode(buffer)?; + Ok((Cancel { request }, rest)) } } @@ -269,93 +317,86 @@ pub struct Data { pub upgrade: Option, } -impl CompactEncoding for HypercoreState { - fn preencode(&mut self, value: &Data) -> Result { - self.add_end(1)?; // flags - self.0.preencode(&value.request)?; - self.0.preencode(&value.fork)?; - if let Some(block) = &value.block { - self.preencode(block)?; +macro_rules! opt_encoded_size { + ($opt:expr, $sum:ident) => { + if let Some(thing) = $opt { + $sum += thing.encoded_size()?; } - if let Some(hash) = &value.hash { - self.preencode(hash)?; - } - if let Some(seek) = &value.seek { - self.preencode(seek)?; - } - if let Some(upgrade) = &value.upgrade { - self.preencode(upgrade)?; - } - Ok(self.end()) - } - - fn encode(&mut self, value: &Data, buffer: &mut [u8]) -> Result { - let mut flags: u8 = if value.block.is_some() { 1 } else { 0 }; - flags |= if value.hash.is_some() { 2 } else { 0 }; - flags |= if value.seek.is_some() { 4 } else { 0 }; - flags |= if value.upgrade.is_some() { 8 } else { 0 }; - self.0.encode(&flags, buffer)?; - self.0.encode(&value.request, buffer)?; - self.0.encode(&value.fork, buffer)?; - if let Some(block) = &value.block { - self.encode(block, buffer)?; - } - if let Some(hash) = &value.hash { - self.encode(hash, buffer)?; - } - if let Some(seek) = &value.seek { - self.encode(seek, buffer)?; - } - if let Some(upgrade) = &value.upgrade { - self.encode(upgrade, buffer)?; - } - Ok(self.start()) - } + }; +} - fn decode(&mut self, buffer: &[u8]) -> Result { - let flags: u8 = self.0.decode(buffer)?; - let request: u64 = self.0.decode(buffer)?; - let fork: u64 = self.0.decode(buffer)?; - let block: Option = if flags & 1 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let hash: Option = if flags & 2 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let seek: Option = if flags & 4 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let upgrade: Option = if flags & 8 != 0 { - Some(self.decode(buffer)?) +// TODO we could write a macro where it takes a $cond that returns an opt. +// if the option is Some(T) then do T::encode(buf) +// also if some add $flag. +// This would simplify some of these impls +macro_rules! opt_encoded_bytes { + ($opt:expr, $buf:ident) => { + if let Some(thing) = $opt { + thing.encode($buf)? } else { - None - }; - Ok(Data { - request, - fork, - block, - hash, - seek, - upgrade, - }) + $buf + } + }; +} +impl CompactEncoding for Data { + fn encoded_size(&self) -> Result { + let mut out = 1; // flags + out += sum_encoded_size!(self.request, self.fork); + opt_encoded_size!(&self.block, out); + opt_encoded_size!(&self.hash, out); + opt_encoded_size!(&self.seek, out); + opt_encoded_size!(&self.upgrade, out); + Ok(out) + } + + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let mut flags: u8 = if self.block.is_some() { 1 } else { 0 }; + flags |= if self.hash.is_some() { 2 } else { 0 }; + flags |= if self.seek.is_some() { 4 } else { 0 }; + flags |= if self.upgrade.is_some() { 8 } else { 0 }; + let rest = write_array(&[flags], buffer)?; + let rest = map_encode!(rest, self.request, self.fork); + + let rest = opt_encoded_bytes!(&self.block, rest); + let rest = opt_encoded_bytes!(&self.hash, rest); + let rest = opt_encoded_bytes!(&self.seek, rest); + let rest = opt_encoded_bytes!(&self.upgrade, rest); + Ok(rest) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ([flags], rest) = take_array::<1>(buffer)?; + let ((request, fork), rest) = map_decode!(rest, [u64, u64]); + let (block, rest) = maybe_decode!(flags & 1 != 0, DataBlock, rest); + let (hash, rest) = maybe_decode!(flags & 2 != 0, DataHash, rest); + let (seek, rest) = maybe_decode!(flags & 4 != 0, DataSeek, rest); + let (upgrade, rest) = maybe_decode!(flags & 8 != 0, DataUpgrade, rest); + Ok(( + Data { + request, + fork, + block, + hash, + seek, + upgrade, + }, + rest, + )) } } impl Data { - /// Transform Data message into a Proof emptying fields - pub fn into_proof(&mut self) -> Proof { + /// Transform Data message into a [`Proof`] + pub fn into_proof(self) -> Proof { Proof { fork: self.fork, - block: self.block.take(), - hash: self.hash.take(), - seek: self.seek.take(), - upgrade: self.upgrade.take(), + block: self.block, + hash: self.hash, + seek: self.seek, + upgrade: self.upgrade, } } } @@ -367,18 +408,21 @@ pub struct NoData { pub request: u64, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &NoData) -> Result { - self.preencode(&value.request) +impl CompactEncoding for NoData { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self.request)) } - fn encode(&mut self, value: &NoData, buffer: &mut [u8]) -> Result { - self.encode(&value.request, buffer) + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(map_encode!(buffer, self.request)) } - fn decode(&mut self, buffer: &[u8]) -> Result { - let request: u64 = self.decode(buffer)?; - Ok(NoData { request }) + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let (request, rest) = u64::decode(buffer)?; + Ok((Self { request }, rest)) } } @@ -390,21 +434,22 @@ pub struct Want { /// Length pub length: u64, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Want) -> Result { - self.preencode(&value.start)?; - self.preencode(&value.length) + +impl CompactEncoding for Want { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self.start, self.length)) } - fn encode(&mut self, value: &Want, buffer: &mut [u8]) -> Result { - self.encode(&value.start, buffer)?; - self.encode(&value.length, buffer) + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(map_encode!(buffer, self.start, self.length)) } - fn decode(&mut self, buffer: &[u8]) -> Result { - let start: u64 = self.decode(buffer)?; - let length: u64 = self.decode(buffer)?; - Ok(Want { start, length }) + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ((start, length), rest) = map_decode!(buffer, [u64, u64]); + Ok((Self { start, length }, rest)) } } @@ -416,21 +461,22 @@ pub struct Unwant { /// Length pub length: u64, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Unwant) -> Result { - self.preencode(&value.start)?; - self.preencode(&value.length) + +impl CompactEncoding for Unwant { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self.start, self.length)) } - fn encode(&mut self, value: &Unwant, buffer: &mut [u8]) -> Result { - self.encode(&value.start, buffer)?; - self.encode(&value.length, buffer) + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(map_encode!(buffer, self.start, self.length)) } - fn decode(&mut self, buffer: &[u8]) -> Result { - let start: u64 = self.decode(buffer)?; - let length: u64 = self.decode(buffer)?; - Ok(Unwant { start, length }) + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ((start, length), rest) = map_decode!(buffer, [u64, u64]); + Ok((Self { start, length }, rest)) } } @@ -442,21 +488,21 @@ pub struct Bitfield { /// Bitfield in 32 bit chunks beginning from `start` pub bitfield: Vec, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Bitfield) -> Result { - self.preencode(&value.start)?; - self.preencode(&value.bitfield) +impl CompactEncoding for Bitfield { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self.start, self.bitfield)) } - fn encode(&mut self, value: &Bitfield, buffer: &mut [u8]) -> Result { - self.encode(&value.start, buffer)?; - self.encode(&value.bitfield, buffer) + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(map_encode!(buffer, self.start, self.bitfield)) } - fn decode(&mut self, buffer: &[u8]) -> Result { - let start: u64 = self.decode(buffer)?; - let bitfield: Vec = self.decode(buffer)?; - Ok(Bitfield { start, bitfield }) + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ((start, bitfield), rest) = map_decode!(buffer, [u64, Vec]); + Ok((Self { start, bitfield }, rest)) } } @@ -473,41 +519,46 @@ pub struct Range { pub length: u64, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Range) -> Result { - self.add_end(1)?; // flags - self.preencode(&value.start)?; - if value.length != 1 { - self.preencode(&value.length)?; +impl CompactEncoding for Range { + fn encoded_size(&self) -> Result { + let mut out = 1 + sum_encoded_size!(self.start); + if self.length != 1 { + out += self.length.encoded_size()?; } - Ok(self.end()) + Ok(out) } - fn encode(&mut self, value: &Range, buffer: &mut [u8]) -> Result { - let mut flags: u8 = if value.drop { 1 } else { 0 }; - flags |= if value.length == 1 { 2 } else { 0 }; - self.encode(&flags, buffer)?; - self.encode(&value.start, buffer)?; - if value.length != 1 { - self.encode(&value.length, buffer)?; + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let mut flags: u8 = if self.drop { 1 } else { 0 }; + flags |= if self.length == 1 { 2 } else { 0 }; + let rest = write_array(&[flags], buffer)?; + let rest = self.start.encode(rest)?; + if self.length != 1 { + return self.length.encode(rest); } - Ok(self.end()) + Ok(rest) } - fn decode(&mut self, buffer: &[u8]) -> Result { - let flags: u8 = self.decode(buffer)?; - let start: u64 = self.decode(buffer)?; + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ([flags], rest) = take_array::<1>(buffer)?; + let (start, rest) = u64::decode(rest)?; let drop = flags & 1 != 0; - let length: u64 = if flags & 2 != 0 { - 1 + let (length, rest) = if flags & 2 != 0 { + (1, rest) } else { - self.decode(buffer)? + u64::decode(rest)? }; - Ok(Range { - drop, - length, - start, - }) + Ok(( + Range { + drop, + length, + start, + }, + rest, + )) } } @@ -519,20 +570,20 @@ pub struct Extension { /// Message content, use empty vector for no data. pub message: Vec, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Extension) -> Result { - self.preencode(&value.name)?; - self.preencode_raw_buffer(&value.message) +impl CompactEncoding for Extension { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self.name, self.message)) } - fn encode(&mut self, value: &Extension, buffer: &mut [u8]) -> Result { - self.encode(&value.name, buffer)?; - self.encode_raw_buffer(&value.message, buffer) + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(map_encode!(buffer, self.name, self.message)) } - fn decode(&mut self, buffer: &[u8]) -> Result { - let name: String = self.decode(buffer)?; - let message: Vec = self.decode_raw_buffer(buffer)?; - Ok(Extension { name, message }) + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ((name, message), rest) = map_decode!(buffer, [String, Vec]); + Ok((Self { name, message }, rest)) } } diff --git a/src/stream.rs b/src/stream.rs new file mode 100644 index 0000000..29473bc --- /dev/null +++ b/src/stream.rs @@ -0,0 +1,108 @@ +//! Type-erased bidirectional byte stream for the protocol. +//! +//! This module provides [`BoxedStream`], which hides the concrete stream type +//! from the public API while still supporting any `Stream> + Sink>`. + +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::{Sink, Stream}; + +/// A type-erased bidirectional byte stream for protocol communication. +/// +/// This wrapper allows [`Protocol`](crate::Protocol) to have a non-generic public interface +/// while still accepting any stream type that implements the required traits. +pub struct BoxedStream { + inner: Box, +} + +impl std::fmt::Debug for BoxedStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BoxedStream").finish_non_exhaustive() + } +} + +/// Internal trait combining Stream + Sink operations for type erasure. +trait StreamSink: Send + Sync { + fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll>>; + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll>; + fn start_send(&mut self, item: Vec) -> io::Result<()>; + fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll>; + fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll>; +} + +/// Wrapper to implement StreamSink for any compatible type. +struct StreamSinkWrapper(S); + +impl StreamSink for StreamSinkWrapper +where + S: Stream> + Sink> + Unpin + Send + Sync, + >>::Error: Into, +{ + fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll>> { + Pin::new(&mut self.0).poll_next(cx) + } + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_ready(cx).map_err(Into::into) + } + + fn start_send(&mut self, item: Vec) -> io::Result<()> { + Pin::new(&mut self.0).start_send(item).map_err(Into::into) + } + + fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx).map_err(Into::into) + } + + fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_close(cx).map_err(Into::into) + } +} + +impl BoxedStream { + /// Create a new `BoxedStream` from any compatible stream. + /// + /// The stream must implement: + /// - `Stream>` for receiving messages + /// - `Sink>` for sending messages + /// - `Unpin` and `Send` + pub fn new(stream: S) -> Self + where + S: Stream> + Sink> + Unpin + Send + Sync + 'static, + >>::Error: Into, + { + BoxedStream { + inner: Box::new(StreamSinkWrapper(stream)), + } + } +} + +impl Stream for BoxedStream { + type Item = Vec; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_next(cx) + } +} + +impl Sink> for BoxedStream { + type Error = io::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { + self.inner.start_send(item) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_close(cx) + } +} diff --git a/src/test_utils.rs b/src/test_utils.rs new file mode 100644 index 0000000..4512542 --- /dev/null +++ b/src/test_utils.rs @@ -0,0 +1,204 @@ +#![expect(unused)] +use std::{ + io::{self}, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::{ + Sink, Stream, StreamExt, + channel::mpsc::{UnboundedReceiver as Receiver, UnboundedSender as Sender, unbounded}, +}; + +#[derive(Debug)] +pub(crate) struct Io { + receiver: Receiver>, + sender: Sender>, +} + +impl Default for Io { + fn default() -> Self { + let (sender, receiver) = unbounded(); + Self { sender, receiver } + } +} + +impl Stream for Io { + type Item = Vec; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.receiver).poll_next(cx) + } +} + +impl Sink> for Io { + type Error = io::Error; + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(mut self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { + Pin::new(&mut self.sender) + .start_send(item) + .map_err(|_e| io::Error::other("SendError")) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + todo!() + } +} + +#[derive(Default, Debug)] +pub(crate) struct TwoWay { + l_to_r: Io, + r_to_l: Io, +} + +impl TwoWay { + fn split_sides(self) -> (Io, Io) { + let left = Io { + sender: self.l_to_r.sender, + receiver: self.r_to_l.receiver, + }; + let right = Io { + sender: self.r_to_l.sender, + receiver: self.l_to_r.receiver, + }; + (left, right) + } +} + +pub(crate) fn log() { + static START_LOGS: std::sync::OnceLock<()> = std::sync::OnceLock::new(); + START_LOGS.get_or_init(|| { + use tracing_subscriber::{ + EnvFilter, layer::SubscriberExt as _, util::SubscriberInitExt as _, + }; + let env_filter = EnvFilter::from_default_env(); // Reads `RUST_LOG` environment variable + + // Create the hierarchical layer from tracing_tree + let tree_layer = tracing_tree::HierarchicalLayer::new(2) // 2 spaces per indent level + .with_targets(true) + .with_bracketed_fields(true) + .with_indent_lines(true) + .with_thread_ids(false) + //.with_thread_names(true) + //.with_span_modes(true) + ; + + tracing_subscriber::registry() + .with(env_filter) + .with(tree_layer) + .init(); + }); +} + +pub(crate) struct Moo { + receiver: Rx, + sender: Tx, +} + +impl + Unpin, Tx: Unpin> Stream for Moo { + type Item = RxItem; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + Pin::new(&mut this.receiver).poll_next(cx) + } +} + +impl + Unpin> Sink + for Moo +{ + type Error = io::Error; + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: TxItem) -> Result<(), Self::Error> { + let this = self.get_mut(); + Pin::new(&mut this.sender) + .start_send(item) + .map_err(|_e| io::Error::other("SendError")) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + todo!() + } +} + +/// Creaee [`Moo`] from return value of [`unbounded`] +impl From<(Tx, Rx)> for Moo { + fn from(value: (Tx, Rx)) -> Self { + Moo { + receiver: value.1, + sender: value.0, + } + } +} + +impl Moo { + /// connect two [`Moo`]s + fn connect( + self, + other: Moo, + ) -> (Moo, Moo) { + let left = Moo { + receiver: self.receiver, + sender: other.sender, + }; + let right = Moo { + receiver: other.receiver, + sender: self.sender, + }; + (left, right) + } +} + +fn result_channel() -> (Sender>, impl Stream>>) { + let (tx, rx) = unbounded::>(); + (tx, rx.map(Ok)) +} + +#[expect(clippy::type_complexity)] +pub(crate) fn create_result_connected() -> ( + Moo>>, impl Sink>>, + Moo>>, impl Sink>>, +) { + let a = Moo::from(result_channel()); + let b = Moo::from(result_channel()); + a.connect(b) +} + +#[cfg(test)] +mod test { + #[tokio::test] + async fn way_one() { + use futures::{SinkExt, StreamExt}; + let mut a = super::Io::default(); + let _ = a.send(b"hello".into()).await; + let Some(res) = a.next().await else { panic!() }; + assert_eq!(res, b"hello"); + } + + #[tokio::test] + async fn split() { + use futures::{SinkExt, StreamExt}; + let (mut left, mut right) = (super::TwoWay::default()).split_sides(); + left.send(b"hello".to_vec()).await.unwrap(); + let Some(res) = right.next().await else { + panic!(); + }; + assert_eq!(res, b"hello"); + } +} diff --git a/src/util.rs b/src/util.rs index c99ff9c..83b9d1c 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,12 +1,13 @@ use blake2::{ - digest::{typenum::U32, FixedOutput, Update}, Blake2bMac, + digest::{FixedOutput, Update, typenum::U32}, +}; +use std::{ + convert::TryInto, + io::{Error, ErrorKind}, }; -use std::convert::TryInto; -use std::io::{Error, ErrorKind}; -use crate::constants::DISCOVERY_NS_BUF; -use crate::DiscoveryKey; +use crate::{DiscoveryKey, constants::DISCOVERY_NS_BUF}; /// Calculate the discovery key of a key. /// @@ -27,32 +28,3 @@ pub(crate) fn map_channel_err(err: async_channel::SendError) -> Error { format!("Cannot forward on channel: {err}"), ) } - -pub(crate) const UINT_24_LENGTH: usize = 3; - -#[inline] -pub(crate) fn wrap_uint24_le(data: &Vec) -> Vec { - let mut buf: Vec = vec![0; 3]; - let n = data.len(); - write_uint24_le(n, &mut buf); - buf.extend(data); - buf -} - -#[inline] -pub(crate) fn write_uint24_le(n: usize, buf: &mut [u8]) { - buf[0] = (n & 255) as u8; - buf[1] = ((n >> 8) & 255) as u8; - buf[2] = ((n >> 16) & 255) as u8; -} - -#[inline] -pub(crate) fn stat_uint24_le(buffer: &[u8]) -> Option<(usize, u64)> { - if buffer.len() >= 3 { - let len = - ((buffer[0] as u32) | ((buffer[1] as u32) << 8) | ((buffer[2] as u32) << 16)) as u64; - Some((UINT_24_LENGTH, len)) - } else { - None - } -} diff --git a/src/writer.rs b/src/writer.rs deleted file mode 100644 index e3cc5da..0000000 --- a/src/writer.rs +++ /dev/null @@ -1,173 +0,0 @@ -use crate::crypto::EncryptCipher; -use crate::message::{Encoder, Frame}; - -use futures_lite::{ready, AsyncWrite}; -use std::collections::VecDeque; -use std::fmt; -use std::io::Result; -use std::pin::Pin; -use std::task::{Context, Poll}; - -const BUF_SIZE: usize = 1024 * 64; - -#[derive(Debug)] -pub(crate) enum Step { - Flushing, - Writing, - Processing, -} - -pub(crate) struct WriteState { - queue: VecDeque, - buf: Vec, - current_frame: Option, - start: usize, - end: usize, - cipher: Option, - step: Step, -} - -impl fmt::Debug for WriteState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("WriteState") - .field("queue (len)", &self.queue.len()) - .field("step", &self.step) - .field("buf (len)", &self.buf.len()) - .field("current_frame", &self.current_frame) - .field("start", &self.start) - .field("end", &self.end) - .field("cipher", &self.cipher.is_some()) - .finish() - } -} - -impl WriteState { - pub(crate) fn new() -> Self { - Self { - queue: VecDeque::new(), - buf: vec![0u8; BUF_SIZE], - current_frame: None, - start: 0, - end: 0, - cipher: None, - step: Step::Processing, - } - } - - pub(crate) fn queue_frame(&mut self, frame: F) - where - F: Into, - { - self.queue.push_back(frame.into()) - } - - pub(crate) fn try_queue_direct(&mut self, frame: &mut T) -> Result { - let promised_len = frame.encoded_len()?; - let padded_promised_len = self.safe_encrypted_len(promised_len); - if self.buf.len() < padded_promised_len { - self.buf.resize(padded_promised_len, 0u8); - } - if padded_promised_len > self.remaining() { - return Ok(false); - } - let actual_len = frame.encode(&mut self.buf[self.end..])?; - if actual_len != promised_len { - panic!( - "encoded_len() did not return that right size, expected={promised_len}, actual={actual_len}" - ); - } - self.advance(padded_promised_len)?; - Ok(true) - } - - pub(crate) fn can_park_frame(&self) -> bool { - self.current_frame.is_none() - } - - pub(crate) fn park_frame(&mut self, frame: F) - where - F: Into, - { - if self.current_frame.is_none() { - self.current_frame = Some(frame.into()) - } - } - - fn advance(&mut self, n: usize) -> Result<()> { - let end = self.end + n; - - let encrypted_end = if let Some(ref mut cipher) = self.cipher { - self.end + cipher.encrypt(&mut self.buf[self.end..end])? - } else { - end - }; - - self.end = encrypted_end; - Ok(()) - } - - pub(crate) fn upgrade_with_encrypt_cipher(&mut self, encrypt_cipher: EncryptCipher) { - self.cipher = Some(encrypt_cipher); - } - - fn remaining(&self) -> usize { - self.buf.len() - self.end - } - - fn pending(&self) -> usize { - self.end - self.start - } - - pub(crate) fn poll_send( - &mut self, - cx: &mut Context<'_>, - mut writer: &mut W, - ) -> Poll> - where - W: AsyncWrite + Unpin, - { - loop { - self.step = match self.step { - Step::Processing => { - if self.current_frame.is_none() && !self.queue.is_empty() { - self.current_frame = self.queue.pop_front(); - } - - if let Some(mut frame) = self.current_frame.take() { - if !self.try_queue_direct(&mut frame)? { - self.current_frame = Some(frame); - } - } - - if self.pending() == 0 { - return Poll::Ready(Ok(())); - } - Step::Writing - } - Step::Writing => { - let n = ready!( - Pin::new(&mut writer).poll_write(cx, &self.buf[self.start..self.end]) - )?; - self.start += n; - if self.start == self.end { - self.start = 0; - self.end = 0; - } - Step::Flushing - } - Step::Flushing => { - ready!(Pin::new(&mut writer).poll_flush(cx))?; - Step::Processing - } - } - } - } - - fn safe_encrypted_len(&self, encoded_len: usize) -> usize { - if let Some(cipher) = &self.cipher { - cipher.safe_encrypted_len(encoded_len) - } else { - encoded_len - } - } -} diff --git a/tests/_util.rs b/tests/_util.rs index 9d0f9bf..4127b42 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -1,38 +1,46 @@ -use async_std::net::TcpStream; -use async_std::prelude::*; -use async_std::task::{self, JoinHandle}; -use futures_lite::io::{AsyncRead, AsyncWrite}; -use hypercore_protocol::{Channel, DiscoveryKey, Duplex, Event, Protocol, ProtocolBuilder}; +#![allow(unused)] +use async_compat::CompatExt; +use futures_lite::StreamExt; +use hypercore_handshake::{Cipher, state_machine::SecStream}; +use hypercore_protocol::{Channel, DiscoveryKey, Event, Protocol}; use instant::Duration; use std::io; +use tokio::task::JoinHandle; +use uint24le_framing::Uint24LELengthPrefixedFraming; -pub type MemoryProtocol = Protocol>; -pub async fn create_pair_memory() -> io::Result<(MemoryProtocol, MemoryProtocol)> { - let (ar, bw) = sluice::pipe::pipe(); - let (br, aw) = sluice::pipe::pipe(); +/// Create a connected pair of test streams. +pub fn create_connected_streams() -> (Cipher, Cipher) { + let (left, right) = tokio::io::duplex(10_000); + let left = Uint24LELengthPrefixedFraming::new(left.compat()); + let right = Uint24LELengthPrefixedFraming::new(right.compat()); - let a = ProtocolBuilder::new(true); - let b = ProtocolBuilder::new(false); - let a = a.connect_rw(ar, aw); - let b = b.connect_rw(br, bw); - Ok((a, b)) + let kp = hypercore_handshake::state_machine::hc_specific::generate_keypair().unwrap(); + + let initiator = Cipher::new_init( + Box::new(left), + SecStream::new_initiator_ik(&kp.public.clone().try_into().unwrap(), &[]).unwrap(), + ); + let responder = Cipher::new_resp( + Box::new(right), + SecStream::new_responder_ik(&kp, &[]).unwrap(), + ); + dbg!(responder.get_remote_static()); + + (initiator, responder) } -pub type TcpProtocol = Protocol; -pub async fn create_pair_tcp() -> io::Result<(TcpProtocol, TcpProtocol)> { - let (stream_a, stream_b) = tcp::pair().await?; - let a = ProtocolBuilder::new(true).connect(stream_a); - let b = ProtocolBuilder::new(false).connect(stream_b); - Ok((a, b)) +/// Create a connected pair of protocols for testing. +pub fn create_pair() -> (Protocol, Protocol) { + let (stream_a, stream_b) = create_connected_streams(); + + let proto_a = Protocol::new(Box::new(stream_a)); + let proto_b = Protocol::new(Box::new(stream_b)); + + (proto_a, proto_b) } -pub fn next_event( - mut proto: Protocol, -) -> impl Future, io::Result)> -where - IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, -{ - task::spawn(async move { +pub fn next_event(mut proto: Protocol) -> JoinHandle<(Protocol, io::Result)> { + tokio::task::spawn(async move { let e1 = proto.next().await; let e1 = e1.unwrap(); (proto, e1) @@ -56,13 +64,8 @@ pub fn event_channel(event: Event) -> Channel { } /// Drive a protocol stream until the first channel arrives. -pub fn drive_until_channel( - mut proto: Protocol, -) -> JoinHandle, Channel)>> -where - IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, -{ - task::spawn(async move { +pub fn drive_until_channel(mut proto: Protocol) -> JoinHandle> { + tokio::task::spawn(async move { while let Some(event) = proto.next().await { let event = event?; if let Event::Channel(channel) = event { @@ -76,34 +79,15 @@ where }) } -pub mod tcp { - use async_std::net::{TcpListener, TcpStream}; - use async_std::prelude::*; - use async_std::task; - use std::io::{Error, ErrorKind, Result}; - pub async fn pair() -> Result<(TcpStream, TcpStream)> { - let address = "localhost:9999"; - let listener = TcpListener::bind(&address).await?; - let mut incoming = listener.incoming(); - - let connect_task = task::spawn(async move { TcpStream::connect(&address).await }); - - let server_stream = incoming.next().await; - let server_stream = - server_stream.ok_or_else(|| Error::new(ErrorKind::Other, "Stream closed"))?; - let server_stream = server_stream?; - let client_stream = connect_task.await?; - Ok((server_stream, client_stream)) - } -} - -const RETRY_TIMEOUT: u64 = 100_u64; -const NO_RESPONSE_TIMEOUT: u64 = 1000_u64; +#[allow(unused)] pub async fn wait_for_localhost_port(port: u32) { + use async_std::net::TcpStream; + const RETRY_TIMEOUT: u64 = 100_u64; + const NO_RESPONSE_TIMEOUT: u64 = 1000_u64; loop { let timeout = async_std::future::timeout( Duration::from_millis(NO_RESPONSE_TIMEOUT), - TcpStream::connect(format!("localhost:{}", port)), + TcpStream::connect(format!("localhost:{port}")), ) .await; if timeout.is_err() { diff --git a/tests/basic.rs b/tests/basic.rs index 8a99c7e..0618d74 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -1,32 +1,37 @@ -#![allow(dead_code, unused_imports)] - -use async_std::net::TcpStream; -use async_std::prelude::*; -use async_std::task; -use futures_lite::io::{AsyncRead, AsyncWrite}; -use hypercore_protocol::{discovery_key, Channel, Event, Message, Protocol, ProtocolBuilder}; -use hypercore_protocol::{schema::*, DiscoveryKey}; -use std::io; -use test_log::test; +use _util::{create_pair, drive_until_channel, event_channel, event_discovery_key, next_event}; +use futures_lite::StreamExt; +use hypercore_protocol::{DiscoveryKey, Event, Message, discovery_key, schema::*}; +use std::{io, time::Duration}; +use tokio::task; mod _util; -use _util::*; -#[test(async_std::test)] +#[tokio::test] async fn basic_protocol() -> anyhow::Result<()> { - // env_logger::init(); - let (proto_a, proto_b) = create_pair_memory().await?; + let (mut proto_a, mut proto_b) = create_pair(); - let next_a = next_event(proto_a); - let next_b = next_event(proto_b); - let (mut proto_a, event_a) = next_a.await; - let (proto_b, event_b) = next_b.await; + let next_a = tokio::task::spawn(async move { + let e1 = proto_a.next().await; + let e1 = e1.unwrap(); + _ = tokio::time::timeout(Duration::from_millis(200), proto_a.next()).await; + + (proto_a, e1) + }); + + let next_b = tokio::task::spawn(async move { + let e1 = proto_b.next().await; + let e1 = e1.unwrap(); + (proto_b, e1) + }); + + let (proto_a, event_a) = next_a.await?; + let (proto_b, event_b) = next_b.await?; assert!(matches!(event_a, Ok(Event::Handshake(_)))); assert!(matches!(event_b, Ok(Event::Handshake(_)))); - assert_eq!(proto_a.public_key(), proto_b.remote_public_key()); - assert_eq!(proto_b.public_key(), proto_a.remote_public_key()); + assert_eq!(proto_a.public_key(), proto_b.remote_public_key().unwrap()); + assert_eq!(proto_b.public_key(), proto_a.remote_public_key().unwrap()); let key = [3u8; 32]; @@ -35,18 +40,18 @@ async fn basic_protocol() -> anyhow::Result<()> { let next_a = next_event(proto_a); let next_b = next_event(proto_b); - let (mut proto_b, event_b) = next_b.await; + let (proto_b, event_b) = next_b.await?; assert!(matches!(event_b, Ok(Event::DiscoveryKey(_)))); assert_eq!(event_discovery_key(event_b.unwrap()), discovery_key(&key)); proto_b.open(key).await?; let next_b = next_event(proto_b); - let (proto_b, event_b) = next_b.await; + let (proto_b, event_b) = next_b.await?; assert!(matches!(event_b, Ok(Event::Channel(_)))); let mut channel_b = event_channel(event_b.unwrap()); - let (proto_a, event_a) = next_a.await; + let (proto_a, event_a) = next_a.await?; assert!(matches!(event_a, Ok(Event::Channel(_)))); let mut channel_a = event_channel(event_a.unwrap()); @@ -61,15 +66,14 @@ async fn basic_protocol() -> anyhow::Result<()> { let channel_event_b = channel_b.next().await; assert_eq!(channel_event_b, Some(want(0, 10))); - // eprintln!("channel_event_b: {:?}", channel_event_b); let channel_event_a = channel_a.next().await; assert_eq!(channel_event_a, Some(want(10, 5))); channel_a.close().await?; - let (_, event_a) = next_a.await; - let (_, event_b) = next_b.await; + let (_, event_a) = next_a.await?; + let (_, event_b) = next_b.await?; assert!(matches!(event_a, Ok(Event::Close(_)))); assert!(matches!(event_b, Ok(Event::Close(_)))); @@ -78,9 +82,9 @@ async fn basic_protocol() -> anyhow::Result<()> { Ok(()) } -#[test(async_std::test)] +#[tokio::test] async fn open_close_channels() -> anyhow::Result<()> { - let (mut proto_a, mut proto_b) = create_pair_memory().await?; + let (proto_a, proto_b) = create_pair(); let key1 = [0u8; 32]; let key2 = [1u8; 32]; @@ -91,8 +95,9 @@ async fn open_close_channels() -> anyhow::Result<()> { let next_a = drive_until_channel(proto_a); let next_b = drive_until_channel(proto_b); - let (mut proto_a, mut channel_a1) = next_a.await?; - let (mut proto_b, mut channel_b1) = next_b.await?; + let (proto_a, channel_a1) = next_a.await??; + + let (proto_b, channel_b1) = next_b.await??; proto_a.open(key2).await?; proto_b.open(key2).await?; @@ -100,8 +105,8 @@ async fn open_close_channels() -> anyhow::Result<()> { let next_a = drive_until_channel(proto_a); let next_b = drive_until_channel(proto_b); - let (proto_a, mut channel_a2) = next_a.await?; - let (proto_b, mut channel_b2) = next_b.await?; + let (proto_a, mut channel_a2) = next_a.await??; + let (proto_b, mut channel_b2) = next_b.await??; eprintln!( "got channels: {:?}", @@ -119,8 +124,8 @@ async fn open_close_channels() -> anyhow::Result<()> { let next_a = next_event(proto_a); let next_b = next_event(proto_b); - let (mut proto_a, ev_a) = next_a.await; - let (mut proto_b, ev_b) = next_b.await; + let (mut proto_a, ev_a) = next_a.await?; + let (mut proto_b, ev_b) = next_b.await?; let ev_a = ev_a?; let ev_b = ev_b?; eprintln!("next a: {ev_a:?}"); @@ -165,7 +170,6 @@ async fn open_close_channels() -> anyhow::Result<()> { assert_eq!(msg_b, Some(want(0, 10))); eprintln!("all good!"); - Ok(()) } diff --git a/tests/js/.gitignore b/tests/js/.gitignore new file mode 100644 index 0000000..cae21dd --- /dev/null +++ b/tests/js/.gitignore @@ -0,0 +1,2 @@ +yarn.lock +node_modules_for_tests diff --git a/tests/js/mod.rs b/tests/js/mod.rs index 8894b3d..857fd20 100644 --- a/tests/js/mod.rs +++ b/tests/js/mod.rs @@ -1,15 +1,11 @@ use anyhow::Result; use instant::Duration; -use std::fs::{create_dir_all, remove_dir_all, remove_file}; -use std::path::Path; -use std::process::Command; - -#[cfg(feature = "async-std")] -use async_std::{ - process, - task::{self, sleep}, +use std::{ + fs::{create_dir_all, remove_dir_all, remove_file}, + path::Path, + process::Command, }; -#[cfg(feature = "tokio")] + use tokio::{process, task, time::sleep}; use crate::_util::wait_for_localhost_port; @@ -41,9 +37,9 @@ pub fn install() { } pub fn prepare_test_set(test_set: &str) -> (String, String, String) { - let path_result = format!("tests/js/work/{}/result.txt", test_set); - let path_writer = format!("tests/js/work/{}/writer", test_set); - let path_reader = format!("tests/js/work/{}/reader", test_set); + let path_result = format!("tests/js/work/{test_set}/result.txt"); + let path_writer = format!("tests/js/work/{test_set}/writer"); + let path_reader = format!("tests/js/work/{test_set}/reader"); create_dir_all(&path_writer).expect("Unable to create work writer directory"); create_dir_all(&path_reader).expect("Unable to create work reader directory"); (path_result, path_writer, path_reader) @@ -98,28 +94,13 @@ impl JavascriptServer { assert_eq!( Some(0), code, - "node server did not exit successfully, is_writer={}, port={}, data_count={}, data_size={}, data_char={}, test_set={}", - is_writer, - port, - data_count, - data_size, - data_char, - test_set, + "node server did not exit successfully, is_writer={is_writer}, port={port}, data_count={data_count}, data_size={data_size}, data_char={data_char}, test_set={test_set}", ); })); wait_for_localhost_port(port).await; } } -impl Drop for JavascriptServer { - fn drop(&mut self) { - #[cfg(feature = "async-std")] - if let Some(handle) = self.handle.take() { - async_std::task::block_on(handle.cancel()); - } - } -} - pub async fn js_start_server( is_writer: bool, port: u32, diff --git a/tests/js/package.json b/tests/js/package.json index c3a57ff..56fe846 100644 --- a/tests/js/package.json +++ b/tests/js/package.json @@ -2,6 +2,7 @@ "name": "hypercore-protocol-rs-js-interop-tests", "version": "0.0.1", "dependencies": { - "hypercore": "10.31.12" + "hypercore": "10.31.12", + "udx-native": "^1.0.0" } } diff --git a/tests/js_interop.rs b/tests/js_interop.rs index d703734..cb0b962 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -1,44 +1,41 @@ +// These tests require the old ProtocolBuilder API which performed its own Noise handshake. +// They are disabled until updated to work with hyperswarm pre-encrypted connections. +#![cfg(feature = "js_tests")] + +pub mod _util; +#[path = "../src/test_utils.rs"] +mod test_utils; + use _util::wait_for_localhost_port; use anyhow::Result; +use async_compat::CompatExt; use futures::Future; use futures_lite::stream::StreamExt; -use hypercore::SigningKey; use hypercore::{ - Hypercore, HypercoreBuilder, PartialKeypair, RequestBlock, RequestUpgrade, Storage, - VerifyingKey, PUBLIC_KEY_LENGTH, SECRET_KEY_LENGTH, + Hypercore, HypercoreBuilder, PUBLIC_KEY_LENGTH, PartialKeypair, SECRET_KEY_LENGTH, SigningKey, + Storage, VerifyingKey, +}; +use hypercore_protocol::{ + Channel, Event, Message, Protocol, discovery_key, + schema::{Data, Range, Request, Synchronize}, }; +use hypercore_schema::{RequestBlock, RequestUpgrade}; use instant::Duration; -use std::fmt::Debug; -use std::path::Path; -use std::sync::Arc; -use std::sync::Once; - -#[cfg(feature = "tokio")] -use async_compat::CompatExt; -#[cfg(feature = "async-std")] -use async_std::{ - fs::{metadata, File}, - io::{prelude::BufReadExt, BufReader, BufWriter, WriteExt}, - net::{TcpListener, TcpStream}, - sync::Mutex, - task::{self, sleep}, - test as async_test, +use std::{ + fmt::Debug, + path::Path, + sync::{Arc, Once}, }; -use test_log::test; -#[cfg(feature = "tokio")] use tokio::{ - fs::{metadata, File}, + fs::{File, metadata}, io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}, net::{TcpListener, TcpStream}, sync::Mutex, - task, test as async_test, + task, time::sleep, }; +use tracing::instrument; -use hypercore_protocol::schema::*; -use hypercore_protocol::{discovery_key, Channel, Event, Message, ProtocolBuilder}; - -pub mod _util; mod js; use js::{cleanup, install, js_run_client, js_start_server, prepare_test_set}; @@ -49,6 +46,7 @@ fn init() { cleanup(); install(); }); + test_utils::log(); } const TEST_SET_NODE_CLIENT_NODE_SERVER: &str = "ncns"; @@ -59,65 +57,64 @@ const TEST_SET_SERVER_WRITER: &str = "sw"; const TEST_SET_CLIENT_WRITER: &str = "cw"; const TEST_SET_SIMPLE: &str = "simple"; -#[test(async_test)] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncns_simple_server_writer() -> Result<()> { - js_interop_ncns_simple(true, 8101).await?; +#[tokio::test] +#[cfg_attr(not(feature = "js_tests"), ignore)] +async fn ncns_server_writer() -> Result<()> { + ncns(true, 8101).await?; Ok(()) } -#[test(async_test)] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncns_simple_client_writer() -> Result<()> { - js_interop_ncns_simple(false, 8102).await?; +#[tokio::test] +#[cfg_attr(not(feature = "js_tests"), ignore)] +async fn ncns_client_writer() -> Result<()> { + ncns(false, 8102).await?; Ok(()) } -#[test(async_test)] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_rcns_simple_server_writer() -> Result<()> { - js_interop_rcns_simple(true, 8103).await?; +#[tokio::test] +#[cfg_attr(not(feature = "js_tests"), ignore)] +async fn rcns_server_writer() -> Result<()> { + rcns(true, 8103).await?; Ok(()) } -#[test(async_test)] -//#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -#[ignore] // FIXME this tests hangs sporadically -async fn js_interop_rcns_simple_client_writer() -> Result<()> { - js_interop_rcns_simple(false, 8104).await?; +#[tokio::test] +#[cfg_attr(not(feature = "js_tests"), ignore)] +async fn rcns_client_writer() -> Result<()> { + rcns(false, 8104).await?; Ok(()) } -#[test(async_test)] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncrs_simple_server_writer() -> Result<()> { - js_interop_ncrs_simple(true, 8105).await?; +#[tokio::test] +#[cfg_attr(not(feature = "js_tests"), ignore)] +async fn ncrs_server_writer() -> Result<()> { + ncrs(true, 8105).await?; Ok(()) } -#[test(async_test)] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncrs_simple_client_writer() -> Result<()> { - js_interop_ncrs_simple(false, 8106).await?; +#[tokio::test] +#[cfg_attr(not(feature = "js_tests"), ignore)] +async fn ncrs_client_writer() -> Result<()> { + ncrs(false, 8106).await?; Ok(()) } -#[test(async_test)] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_rcrs_simple_server_writer() -> Result<()> { - js_interop_rcrs_simple(true, 8107).await?; +#[tokio::test] +#[cfg_attr(not(feature = "js_tests"), ignore)] +async fn rcrs_server_writer() -> Result<()> { + rcrs(true, 8107).await?; Ok(()) } -#[test(async_test)] -//#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -#[ignore] // FIXME this tests hangs sporadically -async fn js_interop_rcrs_simple_client_writer() -> Result<()> { - js_interop_rcrs_simple(false, 8108).await?; +#[tokio::test] +//#[cfg_attr(not(feature = "js_tests"), ignore)] +//#[ignore] // FIXME this tests hangs sporadically +async fn rcrs_client_writer() -> Result<()> { + rcrs(false, 8108).await?; Ok(()) } -async fn js_interop_ncns_simple(server_writer: bool, port: u32) -> Result<()> { +async fn ncns(server_writer: bool, port: u32) -> Result<()> { init(); let test_set = format!( "{}_{}_{}", @@ -156,7 +153,7 @@ async fn js_interop_ncns_simple(server_writer: bool, port: u32) -> Result<()> { Ok(()) } -async fn js_interop_rcns_simple(server_writer: bool, port: u32) -> Result<()> { +async fn rcns(server_writer: bool, port: u32) -> Result<()> { init(); let test_set = format!( "{}_{}_{}", @@ -201,7 +198,7 @@ async fn js_interop_rcns_simple(server_writer: bool, port: u32) -> Result<()> { Ok(()) } -async fn js_interop_ncrs_simple(server_writer: bool, port: u32) -> Result<()> { +async fn ncrs(server_writer: bool, port: u32) -> Result<()> { init(); let test_set = format!( "{}_{}_{}", @@ -247,7 +244,7 @@ async fn js_interop_ncrs_simple(server_writer: bool, port: u32) -> Result<()> { Ok(()) } -async fn js_interop_rcrs_simple(server_writer: bool, port: u32) -> Result<()> { +async fn rcrs(server_writer: bool, port: u32) -> Result<()> { init(); let test_set = format!( "{}_{}_{}", @@ -322,7 +319,7 @@ async fn assert_result( let expected_value = data_char.to_string().repeat(item_size); let mut line = String::new(); while reader.read_line(&mut line).await? != 0 { - assert_eq!(line, format!("{} {}\n", i, expected_value)); + assert_eq!(line, format!("{i} {expected_value}\n")); i += 1; line = String::new(); } @@ -441,63 +438,50 @@ pub fn get_test_key_pair(include_secret: bool) -> PartialKeypair { PartialKeypair { public, secret } } -#[cfg(feature = "async-std")] +#[instrument(skip_all)] async fn on_replication_connection( stream: TcpStream, is_initiator: bool, hypercore: Arc, ) -> Result<()> { - let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream); - while let Some(event) = protocol.next().await { - let event = event?; - match event { - Event::Handshake(_) => { - if is_initiator { - protocol.open(*hypercore.key()).await?; - } - } - Event::DiscoveryKey(dkey) => { - if hypercore.discovery_key == dkey { - protocol.open(*hypercore.key()).await?; - } else { - panic!("Invalid discovery key"); - } - } - Event::Channel(channel) => { - hypercore.on_replication_peer(channel); - } - Event::Close(_dkey) => { - break; - } - _ => {} - } - } - Ok(()) -} + use hypercore_handshake::{Cipher, state_machine::SecStream}; + use tracing::info; + use uint24le_framing::Uint24LELengthPrefixedFraming; -#[cfg(feature = "tokio")] -async fn on_replication_connection( - stream: TcpStream, - is_initiator: bool, - hypercore: Arc, -) -> Result<()> { - let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream.compat()); + let framed = Uint24LELengthPrefixedFraming::new(stream.compat()); + + let cipher = if is_initiator { + let ss = SecStream::new_initiator_xx(&[])?; + Cipher::new(Some(Box::new(framed)), ss.into()) + } else { + let keypair = hypercore_handshake::state_machine::hc_specific::generate_keypair().unwrap(); + let ss = SecStream::new_responder_xx(&keypair, &[])?; + Cipher::new(Some(Box::new(framed)), ss.into()) + }; + + let mut protocol = Protocol::new(Box::new(cipher)); + let mut channel_opened = false; while let Some(event) = protocol.next().await { let event = event?; match event { Event::Handshake(_) => { - if is_initiator { + info!("Event::Handshake"); + if is_initiator && !channel_opened { protocol.open(*hypercore.key()).await?; + channel_opened = true; } } Event::DiscoveryKey(dkey) => { - if hypercore.discovery_key == dkey { + info!("Event::DiscoveryKey"); + if hypercore.discovery_key == dkey && !channel_opened { protocol.open(*hypercore.key()).await?; + channel_opened = true; } else { panic!("Invalid discovery key"); } } Event::Channel(channel) => { + info!("Event::Channel is_initiator = {is_initiator}"); hypercore.on_replication_peer(channel); } Event::Close(_dkey) => { @@ -647,6 +631,8 @@ async fn on_replication_message( start: info.length, length: peer_state.remote_length - info.length, }), + manifest: false, + priority: 0, }; messages.push(Message::Request(msg)); } @@ -758,6 +744,8 @@ async fn on_replication_message( block: Some(request_block), seek: None, upgrade: None, + manifest: false, + priority: 0, })); } let exit = if synced { @@ -766,8 +754,11 @@ async fn on_replication_message( let mut writer = BufWriter::new(File::create(result_path).await?); for i in 0..new_info.contiguous_length { let value = String::from_utf8(hypercore.get(i).await?.unwrap()).unwrap(); - let line = format!("{} {}\n", i, value); - writer.write(line.as_bytes()).await?; + let line = format!("{i} {value}\n"); + let n_written = writer.write(line.as_bytes()).await?; + if line.len() != n_written { + panic!("Couldn't write all write all bytse"); + } } writer.flush().await?; true @@ -796,7 +787,7 @@ async fn on_replication_message( } } _ => { - panic!("Received unexpected message {:?}", message); + panic!("Received unexpected message {message:?}"); } }; Ok(false) @@ -847,40 +838,6 @@ impl RustServer { } } -impl Drop for RustServer { - fn drop(&mut self) { - #[cfg(feature = "async-std")] - if let Some(handle) = self.handle.take() { - task::block_on(handle.cancel()); - } - } -} - -#[cfg(feature = "async-std")] -pub async fn tcp_server( - port: u32, - onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, - context: C, -) -> Result<()> -where - F: Future> + Send, - C: Clone + Send + 'static, -{ - let listener = TcpListener::bind(&format!("localhost:{}", port)).await?; - let mut incoming = listener.incoming(); - while let Some(Ok(stream)) = incoming.next().await { - let context = context.clone(); - let _peer_addr = stream.peer_addr().unwrap(); - task::spawn(async move { - onconnection(stream, false, context) - .await - .expect("Should return ok"); - }); - } - Ok(()) -} - -#[cfg(feature = "tokio")] pub async fn tcp_server( port: u32, onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, @@ -890,7 +847,7 @@ where F: Future> + Send, C: Clone + Send + 'static, { - let listener = TcpListener::bind(&format!("localhost:{}", port)).await?; + let listener = TcpListener::bind(&format!("localhost:{port}")).await?; while let Ok((stream, _peer_address)) = listener.accept().await { let context = context.clone(); @@ -912,6 +869,6 @@ where F: Future> + Send, C: Clone + Send + 'static, { - let stream = TcpStream::connect(&format!("localhost:{}", port)).await?; + let stream = TcpStream::connect(&format!("localhost:{port}")).await?; onconnection(stream, true, context).await } From fb8b3444fe339bbe0f5d814d3f2d32610aee57e4 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 19 Feb 2026 14:21:41 -0500 Subject: [PATCH 2/3] 7874b96: CHANGELOG.md --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c3f1044..d017ea5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,10 @@ All notable changes to this Rust implementation of hypercore-protocol will be do ### unreleased -* TODO: Add changes here as they happen +BIG CHANGES: +* Encryption and framing of streams has been moved out of this crate into `hypercore_handshake` and `uint24le_framing` respectively. This had big impacts on the public API. Now `Protocol::new` just takes a `impl CipherTrait` argument. +* Remove dependence on `hypercore` instead we use `hypercore_schema`. +* Bumped to edition 2024. ### 0.6.1 From 40d139321d3c49518299e55885a1d8cbf53ff83e Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 19 Feb 2026 16:42:44 -0500 Subject: [PATCH 3/3] Remove ping/keepalive it's handled in hypercore handshake --- src/constants.rs | 3 --- src/protocol.rs | 25 +------------------------ 2 files changed, 1 insertion(+), 27 deletions(-) diff --git a/src/constants.rs b/src/constants.rs index 1efbbed..be6ffac 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -1,8 +1,5 @@ /// Seed for the discovery key hash pub(crate) const DISCOVERY_NS_BUF: &[u8] = b"hypercore"; -/// Default keepalive interval (in seconds) -pub(crate) const DEFAULT_KEEPALIVE: u32 = 10; - /// v10: Protocol name pub(crate) const PROTOCOL_NAME: &str = "hypercore/alpha"; diff --git a/src/protocol.rs b/src/protocol.rs index 3c716a9..74c3d2e 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -1,6 +1,5 @@ use async_channel::{Receiver, Sender}; use futures_lite::stream::Stream; -use futures_timer::Delay; use hypercore_handshake::{CipherTrait, state_machine::PUBLIC_KEYLEN}; use std::{ collections::VecDeque, @@ -9,13 +8,12 @@ use std::{ io::{self, Result}, pin::Pin, task::{Context, Poll}, - time::Duration, }; use tracing::{error, instrument}; use crate::{ channels::{Channel, ChannelMap}, - constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}, + constants::PROTOCOL_NAME, crypto::HandshakeResult, message::{ChannelMessage, Message}, mqueue::MessageIo, @@ -102,8 +100,6 @@ pub struct Protocol { command_tx: CommandTx, outbound_rx: Receiver>, outbound_tx: Sender>, - #[allow(dead_code)] // TODO: Implement keepalive - keepalive: Delay, queued_events: VecDeque, handshake_emitted: bool, } @@ -141,7 +137,6 @@ impl Protocol { command_tx: CommandTx(command_tx), outbound_tx, outbound_rx, - keepalive: Delay::new(Duration::from_secs(DEFAULT_KEEPALIVE as u64)), queued_events: VecDeque::new(), handshake_emitted: false, } @@ -219,9 +214,6 @@ impl Protocol { // Check for commands. return_error!(this.poll_commands(cx)); - // Poll the keepalive timer. - this.poll_keepalive(cx); - // Write everything we can write. return_error!(this.poll_outbound_write(cx)); @@ -244,21 +236,6 @@ impl Protocol { Ok(()) } - /// TODO Poll the keepalive timer and queue a ping message if needed. - fn poll_keepalive(&self, _cx: &mut Context<'_>) { - /* - const KEEPALIVE_DURATION: Duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64); - if Pin::new(&mut self.keepalive).poll(cx).is_ready() { - if let State::Established = self.state { - // 24 bit header for the empty message, hence the 3 - self.write_state - .queue_frame(Frame::RawBatch(vec![vec![0u8; 3]])); - } - self.keepalive.reset(KEEPALIVE_DURATION); - } - */ - } - // just handles Close and LocalSignal fn on_outbound_message(&mut self, message: &ChannelMessage) -> bool { // If message is close, close the local channel.