diff --git a/Cargo.lock b/Cargo.lock
index 09b0c144b78..fa7a9e36a85 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -509,7 +509,10 @@ dependencies = [
  "azure_core_test",
  "azure_identity",
  "azure_storage_blob_test",
+ "bytes",
  "futures",
+ "pin-project",
+ "rand 0.9.2",
  "serde",
  "serde_json",
  "tokio",
diff --git a/eng/dict/rust-custom.txt b/eng/dict/rust-custom.txt
index 174812db947..0e096afd207 100644
--- a/eng/dict/rust-custom.txt
+++ b/eng/dict/rust-custom.txt
@@ -12,6 +12,7 @@ rustflags
 rustls
 rustsec
 turbofish
+uninit
 dylib
 cdylib
 staticlib
diff --git a/sdk/storage/azure_storage_blob/Cargo.toml b/sdk/storage/azure_storage_blob/Cargo.toml
index 56a9b27b0b8..10574dad54f 100644
--- a/sdk/storage/azure_storage_blob/Cargo.toml
+++ b/sdk/storage/azure_storage_blob/Cargo.toml
@@ -19,6 +19,9 @@ default = ["azure_core/default"]
 [dependencies]
 async-trait.workspace = true
 azure_core = { workspace = true, features = ["xml"] }
+bytes.workspace = true
+futures.workspace = true
+pin-project.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 typespec_client_core = { workspace = true, features = ["derive"] }
@@ -35,6 +38,7 @@ azure_core_test = { workspace = true, features = [
 azure_identity.workspace = true
 azure_storage_blob_test.path = "../azure_storage_blob_test"
 futures.workspace = true
+rand.workspace = true
 tokio = { workspace = true, features = ["macros"] }
 tracing.workspace = true
 
diff --git a/sdk/storage/azure_storage_blob/src/lib.rs b/sdk/storage/azure_storage_blob/src/lib.rs
index 4cd17cd46a7..abf4db548b2 100644
--- a/sdk/storage/azure_storage_blob/src/lib.rs
+++ b/sdk/storage/azure_storage_blob/src/lib.rs
@@ -11,7 +11,9 @@ pub mod clients;
 #[allow(unused_imports)]
 mod generated;
 mod parsers;
+mod partitioned_transfer;
 mod pipeline;
+mod streams;
 pub use clients::*;
 pub use parsers::*;
 pub mod models;
diff --git a/sdk/storage/azure_storage_blob/src/partitioned_transfer/mod.rs b/sdk/storage/azure_storage_blob/src/partitioned_transfer/mod.rs
new file mode 100644
index 00000000000..85807166197
--- /dev/null
+++ b/sdk/storage/azure_storage_blob/src/partitioned_transfer/mod.rs
@@ -0,0 +1,250 @@
+use std::{cmp::max, future::Future, num::NonZero, pin::Pin};
+
+use futures::{
+    future::{self},
+    Stream, TryStreamExt,
+};
+
+type AzureResult<T> = azure_core::Result<T>;
+
+/// Executes async operations from a queue with a concurrency limit.
+///
+/// This function consumes a stream (`ops_queue`) of async operation factories (closures returning futures),
+/// and runs up to `parallel` operations concurrently. As operations complete, new ones are started from the queue,
+/// maintaining the concurrency limit. If any operation or queue item returns an error, the function returns early
+/// with that error. When all operations and queue items are complete, returns `Ok(())`.
+///
+/// # Parameters
+/// - `ops_queue`: A stream yielding `Result<impl FnOnce() -> Fut, Err>`. Each item is either a closure producing
+///   a future, or an error. The stream must be `Unpin`.
+/// - `parallel`: The maximum number of operations to run concurrently. Must be non-zero.
+///
+/// # Behavior
+/// - Operations are scheduled as soon as possible, up to the concurrency limit.
+/// - If an error is encountered in the queue or in any operation, the function returns that error immediately.
+/// - When the queue is exhausted, waits for all running operations to complete before returning.
+///
+/// # Errors
+/// Returns the first error encountered from the queue or any operation.
+///
+/// # Type Parameters
+/// - `Fut`: Future type returned by each operation.
+/// - `Err`: Error type for queue or operation failures.
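+///
+/// # Examples
+///
+/// A sketch of intended usage (illustrative only; the operation body and the error type
+/// below are placeholders, not part of this module):
+///
+/// ```rust,ignore
+/// use std::num::NonZero;
+///
+/// // Build a queue of operation factories; each closure starts one async operation when called.
+/// let ops = (0..8).map(|i| {
+///     Ok::<_, azure_core::Error>(move || async move {
+///         // e.g. upload block `i` here
+///         Ok::<(), azure_core::Error>(())
+///     })
+/// });
+///
+/// // Run at most 4 operations at any one time, returning the first error encountered.
+/// run_all_with_concurrency_limit(futures::stream::iter(ops), NonZero::new(4).unwrap()).await?;
+/// ```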
+async fn run_all_with_concurrency_limit<Fut, Err>(
+    mut ops_queue: impl Stream<Item = Result<impl FnOnce() -> Fut, Err>> + Unpin,
+    parallel: NonZero<usize>,
+) -> Result<(), Err>
+where
+    Fut: Future<Output = Result<(), Err>>,
+{
+    let parallel = parallel.get();
+
+    // if no real parallelism, take the simple option of executing ops sequentially.
+    // The "true" implementation can't handle parallel < 2.
+    if parallel == 1 {
+        while let Some(op) = ops_queue.try_next().await? {
+            op().await?;
+        }
+        return Ok(());
+    }
+
+    let first_op = match ops_queue.try_next().await? {
+        Some(item) => item,
+        None => return Ok(()),
+    };
+
+    let mut get_next_completed_op_future = future::select_all(vec![Box::pin(first_op())]);
+    let mut get_next_queue_op_future = ops_queue.try_next();
+    loop {
+        // while at max parallel running ops, focus on just the running ops
+        get_next_completed_op_future = run_down(get_next_completed_op_future, parallel - 1).await?;
+
+        match future::select(get_next_queue_op_future, get_next_completed_op_future).await {
+            future::Either::Left((Err(e), _)) => return Err(e),
+            future::Either::Right(((Err(e), _, _), _)) => return Err(e),
+
+            // Next op in the queue arrived first. Add it to existing running ops.
+            future::Either::Left((Ok(next_op_in_queue), running_ops_fut)) => {
+                get_next_queue_op_future = ops_queue.try_next();
+                get_next_completed_op_future = running_ops_fut;
+
+                match next_op_in_queue {
+                    Some(op) => {
+                        get_next_completed_op_future =
+                            combine_select_all(get_next_completed_op_future, Box::pin(op()));
+                    }
+                    // queue was finished, race is over
+                    None => break,
+                }
+            }
+            // A running op completed first. Start another select_all with remaining running ops.
+            future::Either::Right(((Ok(_), _, remaining_running_ops), next_op_fut)) => {
+                // remaining_running_ops could be empty now.
+                // select panics on empty iter, so we can't race in this case.
+                // forcibly wait for next op in queue and handle it before continuing.
+                if remaining_running_ops.is_empty() {
+                    let next_op = match next_op_fut.await? {
+                        Some(item) => item,
+                        None => return Ok(()),
+                    };
+                    get_next_queue_op_future = ops_queue.try_next();
+                    get_next_completed_op_future = future::select_all(vec![Box::pin(next_op())]);
+                } else {
+                    get_next_queue_op_future = next_op_fut;
+                    get_next_completed_op_future = future::select_all(remaining_running_ops);
+                }
+            }
+        }
+    }
+
+    let _ = future::try_join_all(get_next_completed_op_future.into_inner()).await?;
+    Ok(())
+}
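+
+// Worked example of the scheduling loop above (illustrative only): with `parallel == 3` and
+// five queued ops A..E, ops A, B and C are started first; D starts only once one of them
+// completes, and E once another completes. When the queue reports `None` the loop breaks and
+// `try_join_all` drains whatever is still running.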
+
+/// Loops `future::select_all()` with the existing `SelectAll` until the target number of
+/// remaining inner futures is reached. Will always leave at least one inner future remaining,
+/// for type simplicity (select_all panics on len == 0).
+async fn run_down<Fut, Err>(
+    select_fut: future::SelectAll<Pin<Box<Fut>>>,
+    target_remaining: usize,
+) -> Result<future::SelectAll<Pin<Box<Fut>>>, Err>
+where
+    Fut: Future<Output = Result<(), Err>>,
+{
+    let target_remaining = max(target_remaining, 1);
+    let mut select_vec = select_fut.into_inner();
+    while select_vec.len() > target_remaining {
+        let result;
+        (result, _, select_vec) = future::select_all(select_vec).await;
+        result?;
+    }
+    Ok(future::select_all(select_vec))
+}
+
+/// Adds a pin-boxed future to an existing `SelectAll` of pin-boxed futures.
+fn combine_select_all<Fut>(
+    select_fut: future::SelectAll<Pin<Box<Fut>>>,
+    new_fut: Pin<Box<Fut>>,
+) -> future::SelectAll<Pin<Box<Fut>>>
+where
+    Fut: Future,
+{
+    let mut futures = select_fut.into_inner();
+    futures.push(new_fut);
+    future::select_all(futures)
+}
+
+#[cfg(test)]
+mod tests {
+    use futures::{ready, FutureExt};
+
+    use super::*;
+    use std::{pin::Pin, sync::mpsc::channel, task::Poll, time::Duration};
+
+    #[tokio::test]
+    async fn enforce_concurrency_limit() -> AzureResult<()> {
+        let parallel = 4usize;
+        let num_ops = parallel + 1;
+        let wait_time_millis = 10u64;
+        let op_time_millis = wait_time_millis + 50;
+
+        let (sender, receiver) = channel();
+
+        // setup a series of operations that send a unique number to a channel
+        // we can then assert the expected numbers made it to the channel at expected times
+        let ops = (0..num_ops).map(|i| {
+            let s = sender.clone();
+            Ok(async move || {
+                s.send(i).unwrap();
+                tokio::time::sleep(Duration::from_millis(op_time_millis)).await;
+                AzureResult::<()>::Ok(())
+            })
+        });
+
+        let race = future::select(
+            Box::pin(run_all_with_concurrency_limit(
+                futures::stream::iter(ops),
+                NonZero::new(parallel).unwrap(),
+            )),
+            Box::pin(tokio::time::sleep(Duration::from_millis(wait_time_millis))),
+        )
+        .await;
+        match race {
+            future::Either::Left(_) => panic!("Wrong future won the race."),
+            future::Either::Right((_, run_all_fut)) => {
+                let mut items: Vec<_> = receiver.try_iter().collect();
+                items.sort();
+                assert_eq!(items, (0..parallel).collect::<Vec<_>>());
+
+                run_all_fut.await?;
+                assert_eq!(receiver.try_iter().collect::<Vec<_>>().len(), 1);
+            }
+        }
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn handles_slow_stream() -> AzureResult<()> {
+        let parallel = 10;
+        let num_ops = 5;
+        let op_time_millis = 10;
+        let stream_time_millis = op_time_millis + 10;
+        // setup a series of simple operations; the stream yields them more slowly
+        // than they complete, so the queue itself is the bottleneck
+        let ops = (0..num_ops).map(|_| {
+            Ok(async move || {
+                tokio::time::sleep(Duration::from_millis(op_time_millis)).await;
+                AzureResult::<()>::Ok(())
+            })
+        });
+
+        run_all_with_concurrency_limit(
+            SlowStream::new(ops, Duration::from_millis(stream_time_millis)),
+            NonZero::new(parallel).unwrap(),
+        )
+        .await
+    }
+
+    #[tokio::test]
+    async fn success_when_no_ops() -> AzureResult<()> {
+        let parallel = 4usize;
+
+        // not possible to manually type what we need
+        // make a vec with a concrete element and then remove it to get the desired typing
+        let op = || future::ready::<AzureResult<()>>(Ok(()));
+        let mut ops = vec![Ok(op)];
+        ops.pop();
+
+        run_all_with_concurrency_limit(futures::stream::iter(ops), NonZero::new(parallel).unwrap())
+            .await
+    }
+
+    struct SlowStream<Iter> {
+        sleep: Pin<Box<tokio::time::Sleep>>,
+        interval: Duration,
+        iter: Iter,
+    }
+    impl<Iter> SlowStream<Iter> {
+        fn new(iter: Iter, interval: Duration) -> Self {
+            Self {
+                sleep: Box::pin(tokio::time::sleep(interval)),
+                interval,
+                iter,
+            }
+        }
+    }
+    impl<Iter: Iterator + Unpin> Stream for SlowStream<Iter> {
+        type Item = Iter::Item;
+
+        fn poll_next(
+            self: std::pin::Pin<&mut Self>,
+            cx: &mut std::task::Context<'_>,
+        ) -> std::task::Poll<Option<Self::Item>> {
+            let this = self.get_mut();
+            ready!(this.sleep.poll_unpin(cx));
+            this.sleep = Box::pin(tokio::time::sleep(this.interval));
+            Poll::Ready(this.iter.next())
+        }
+    }
+}
diff --git a/sdk/storage/azure_storage_blob/src/streams/mod.rs b/sdk/storage/azure_storage_blob/src/streams/mod.rs
new file mode 100644
index 00000000000..083f9020968
--- /dev/null
+++ b/sdk/storage/azure_storage_blob/src/streams/mod.rs
@@ -0,0 +1 @@
+pub(crate) mod partitioned_stream;
diff --git a/sdk/storage/azure_storage_blob/src/streams/partitioned_stream.rs b/sdk/storage/azure_storage_blob/src/streams/partitioned_stream.rs
new file mode 100644
index 00000000000..8f7fbc7488c
--- /dev/null
+++ b/sdk/storage/azure_storage_blob/src/streams/partitioned_stream.rs
@@ -0,0 +1,206 @@
+use pin_project::pin_project;
+use std::{mem, num::NonZero, pin::Pin, slice, task::Poll};
+
+use azure_core::stream::SeekableStream;
+use bytes::{Bytes, BytesMut};
+use futures::{ready, stream::FusedStream, AsyncRead, Stream};
+
+type AzureResult<T> = azure_core::Result<T>;
+
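+/// Wraps a `SeekableStream` and yields its contents as `Bytes` partitions of exactly
+/// `partition_len` bytes; only the final partition may be shorter.
+///
+/// A sketch of intended usage (illustrative only; `source` stands for any
+/// `Box<dyn SeekableStream>`, and the 4 MiB partition size is just an example):
+///
+/// ```rust,ignore
+/// use futures::TryStreamExt;
+///
+/// let mut partitions =
+///     PartitionedStream::new(source, NonZero::new(4 * 1024 * 1024).unwrap());
+/// while let Some(partition) = partitions.try_next().await? {
+///     // `partition` is a `Bytes` of 4 MiB, except possibly the last one.
+/// }
+/// ```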
+#[pin_project]
+pub(crate) struct PartitionedStream {
+    #[pin]
+    inner: Box<dyn SeekableStream>,
+    buf: BytesMut,
+    partition_len: usize,
+    total_read: usize,
+    inner_complete: bool,
+}
+
+impl PartitionedStream {
+    pub(crate) fn new(inner: Box<dyn SeekableStream>, partition_len: NonZero<usize>) -> Self {
+        let partition_len = partition_len.get();
+        Self {
+            buf: BytesMut::with_capacity(std::cmp::min(partition_len, inner.len())),
+            inner,
+            partition_len,
+            total_read: 0,
+            inner_complete: false,
+        }
+    }
+}
+
+impl Stream for PartitionedStream {
+    type Item = AzureResult<Bytes>;
+
+    fn poll_next(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let mut this = self.project();
+
+        loop {
+            if *this.inner_complete || this.buf.len() >= *this.partition_len {
+                let ret = mem::replace(
+                    this.buf,
+                    BytesMut::with_capacity(std::cmp::min(
+                        *this.partition_len,
+                        this.inner.len() - *this.total_read,
+                    )),
+                );
+                return if ret.is_empty() {
+                    Poll::Ready(None)
+                } else {
+                    Poll::Ready(Some(Ok(ret.freeze())))
+                };
+            }
+
+            let spare_capacity = this.buf.spare_capacity_mut();
+            let spare_capacity = unsafe {
+                // spare_capacity_mut() gives us the known remaining capacity of BytesMut.
+                // Those bytes are valid reserved memory but have had no values written
+                // to them. Those are the exact bytes we want to write into.
+                // MaybeUninit<u8> can be safely cast into u8, and so this pointer cast
+                // is safe. Since the spare capacity length is safely known, we can
+                // provide those to from_raw_parts without worry.
+                slice::from_raw_parts_mut(
+                    spare_capacity.as_mut_ptr() as *mut u8,
+                    spare_capacity.len(),
+                )
+            };
+            match ready!(this.inner.as_mut().poll_read(cx, spare_capacity)) {
+                Ok(bytes_read) => {
+                    // poll_read() wrote bytes_read-many bytes into the spare capacity.
+                    // those values are therefore initialized and we can add them to
+                    // the existing buffer length
+                    unsafe { this.buf.set_len(this.buf.len() + bytes_read) };
+                    *this.total_read += bytes_read;
+                    *this.inner_complete = bytes_read == 0;
+                }
+                Err(e) => {
+                    return Poll::Ready(Some(Err(e.into())));
+                }
+            }
+        }
+    }
+}
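+
+// For reference, the read-into-spare-capacity pattern used in `poll_next` above, shown in
+// isolation with a blocking reader (an illustrative sketch, not part of this type; `reader`
+// stands for any `std::io::Read`):
+//
+//     let mut buf = bytes::BytesMut::with_capacity(1024);
+//     let spare = buf.spare_capacity_mut();
+//     // Cast `&mut [MaybeUninit<u8>]` to `&mut [u8]` so the reader can fill it.
+//     let spare = unsafe {
+//         std::slice::from_raw_parts_mut(spare.as_mut_ptr() as *mut u8, spare.len())
+//     };
+//     let n = reader.read(spare)?;
+//     // `read` initialized exactly `n` of those bytes, so the longer length is valid.
+//     unsafe { buf.set_len(buf.len() + n) };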
+
+impl FusedStream for PartitionedStream {
+    fn is_terminated(&self) -> bool {
+        self.inner_complete && self.buf.is_empty()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use azure_core::stream::BytesStream;
+    use futures::TryStreamExt;
+
+    use super::*;
+
+    fn get_random_data(len: usize) -> Vec<u8> {
+        let mut data: Vec<u8> = vec![0; len];
+        rand::fill(&mut data[..]);
+        data
+    }
+
+    #[tokio::test]
+    async fn partitions_exact_len() -> AzureResult<()> {
+        for part_count in [2usize, 3, 11, 16] {
+            for part_len in [1024usize, 1000, 9999, 1] {
+                let data = get_random_data(part_len * part_count);
+                let stream = PartitionedStream::new(
+                    Box::new(BytesStream::new(data.clone())),
+                    NonZero::new(part_len).unwrap(),
+                );
+
+                let parts: Vec<_> = stream.try_collect().await?;
+
+                assert_eq!(parts.len(), part_count);
+                for (i, bytes) in parts.iter().enumerate() {
+                    assert_eq!(bytes.len(), part_len);
+                    assert_eq!(bytes[..], data[i * part_len..i * part_len + part_len]);
+                }
+            }
+        }
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn partitions_with_remainder() -> AzureResult<()> {
+        for part_count in [2usize, 3, 11, 16] {
+            for part_len in [1024usize, 1000, 9999] {
+                for dangling_len in [part_len / 2, 100, 128, 99] {
+                    let data = get_random_data(part_len * (part_count - 1) + dangling_len);
+                    let stream = PartitionedStream::new(
+                        Box::new(BytesStream::new(data.clone())),
+                        NonZero::new(part_len).unwrap(),
+                    );
+
+                    let parts: Vec<_> = stream.try_collect().await?;
+
+                    assert_eq!(parts.len(), part_count);
+                    for (i, bytes) in parts[..parts.len()].iter().enumerate() {
+                        if i == parts.len() - 1 {
+                            assert_eq!(bytes.len(), dangling_len);
+                            assert_eq!(bytes[..], data[i * part_len..]);
+                        } else {
+                            assert_eq!(bytes.len(), part_len);
+                            assert_eq!(bytes[..], data[i * part_len..i * part_len + part_len]);
+                        }
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn exactly_one_partition() -> AzureResult<()> {
+        for len in [1024usize, 1000, 9999, 1] {
+            let data = get_random_data(len);
+            let mut stream = PartitionedStream::new(
+                Box::new(BytesStream::new(data.clone())),
+                NonZero::new(len).unwrap(),
+            );
+
+            let single_partition = stream.try_next().await?.unwrap();
+
+            assert_eq!(stream.try_next().await?, None);
+            assert_eq!(single_partition[..], data[..]);
+        }
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn less_than_one_partition() -> AzureResult<()> {
+        let part_len = 99999usize;
+        for len in [1024usize, 1000, 9999, 1] {
+            let data = get_random_data(len);
+            let mut stream = PartitionedStream::new(
+                Box::new(BytesStream::new(data.clone())),
+                NonZero::new(part_len).unwrap(),
+            );
+
+            let single_partition = stream.try_next().await?.unwrap();
+
+            assert!(stream.try_next().await?.is_none());
+            assert_eq!(single_partition[..], data[..]);
+        }
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn successful_empty_stream_when_empty_source_stream() -> AzureResult<()> {
+        for part_len in [1024usize, 1000, 9999, 1] {
+            let data = get_random_data(0);
+            let mut stream = PartitionedStream::new(
+                Box::new(BytesStream::new(data.clone())),
+                NonZero::new(part_len).unwrap(),
+            );
+
+            assert!(stream.try_next().await?.is_none());
+        }
+        Ok(())
+    }
+}