use std::{cmp::max, future::Future, num::NonZero, pin::Pin};

use futures::{future, Stream, TryStreamExt};

type AzureResult<T> = azure_core::Result<T>;

/// Executes async operations from a queue with a concurrency limit.
///
/// This function consumes a stream (`ops_queue`) of async operation factories (closures returning futures)
/// and runs up to `parallel` operations concurrently. As operations complete, new ones are started from the
/// queue, maintaining the concurrency limit. If any operation or queue item returns an error, the function
/// returns early with that error. When all operations and queue items are complete, returns `Ok(())`.
///
/// # Parameters
/// - `ops_queue`: A stream yielding `Result<impl FnOnce() -> Fut, Err>`. Each item is either a closure
///   producing a future, or an error. The stream must be `Unpin`.
/// - `parallel`: The maximum number of operations to run concurrently. Must be non-zero.
///
/// # Behavior
/// - Operations are scheduled as soon as possible, up to the concurrency limit.
/// - If an error is encountered in the queue or in any operation, the function returns that error immediately.
/// - When the queue is exhausted, waits for all running operations to complete before returning.
///
/// # Errors
/// Returns the first error encountered from the queue or any operation.
///
/// # Type Parameters
/// - `Fut`: Future type returned by each operation.
/// - `Err`: Error type for queue or operation failures.
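///
/// # Example
///
/// A minimal usage sketch (illustrative only, not a doc test: the `std::io::Error` error type and
/// the no-op operation bodies are assumptions, not taken from real call sites):
///
/// ```ignore
/// use std::num::NonZero;
/// use futures::stream;
///
/// // A queue of five operation factories, each producing an immediately-ready future.
/// let ops = (0..5).map(|_| Ok::<_, std::io::Error>(|| async { Ok::<(), std::io::Error>(()) }));
///
/// // Run at most two operations at a time; the first error (from the queue or an op) ends the run.
/// run_all_with_concurrency_limit(stream::iter(ops), NonZero::new(2).unwrap()).await?;
/// ```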
async fn run_all_with_concurrency_limit<Fut, Err>(
    mut ops_queue: impl Stream<Item = Result<impl FnOnce() -> Fut, Err>> + Unpin,
    parallel: NonZero<usize>,
) -> Result<(), Err>
where
    Fut: Future<Output = Result<(), Err>>,
{
    let parallel = parallel.get();

    // If there is no real parallelism, take the simple option of executing ops sequentially.
    // The "true" implementation below can't handle parallel < 2.
    if parallel == 1 {
        while let Some(op) = ops_queue.try_next().await? {
            op().await?;
        }
        return Ok(());
    }

    let first_op = match ops_queue.try_next().await? {
        Some(item) => item,
        None => return Ok(()),
    };

    let mut get_next_completed_op_future = future::select_all(vec![Box::pin(first_op())]);
    let mut get_next_queue_op_future = ops_queue.try_next();
    loop {
        // While we're at the maximum number of running ops, just wait on the running ops
        // (don't pull from the queue).
        get_next_completed_op_future = run_down(get_next_completed_op_future, parallel - 1).await?;

        match future::select(get_next_queue_op_future, get_next_completed_op_future).await {
            future::Either::Left((Err(e), _)) => return Err(e),
            future::Either::Right(((Err(e), _, _), _)) => return Err(e),

            // Next op in the queue arrived first. Add it to the existing running ops.
            future::Either::Left((Ok(next_op_in_queue), running_ops_fut)) => {
                get_next_queue_op_future = ops_queue.try_next();
                get_next_completed_op_future = running_ops_fut;

                match next_op_in_queue {
                    Some(op) => {
                        get_next_completed_op_future =
                            combine_select_all(get_next_completed_op_future, Box::pin(op()));
                    }
                    // The queue is finished; the race is over.
                    None => break,
                }
            }
            // A running op completed first. Start another select_all with the remaining running ops.
            future::Either::Right(((Ok(_), _, remaining_running_ops), next_op_fut)) => {
                // remaining_running_ops could be empty now, and select_all panics on an empty
                // iterator, so we can't race in that case. Forcibly wait for the next op in the
                // queue and handle it before continuing.
                if remaining_running_ops.is_empty() {
                    let next_op = match next_op_fut.await? {
                        Some(item) => item,
                        None => return Ok(()),
                    };
                    get_next_queue_op_future = ops_queue.try_next();
                    get_next_completed_op_future = future::select_all(vec![Box::pin(next_op())]);
                } else {
                    get_next_queue_op_future = next_op_fut;
                    get_next_completed_op_future = future::select_all(remaining_running_ops);
                }
            }
        }
    }

    let _ = future::try_join_all(get_next_completed_op_future.into_inner()).await?;
    Ok(())
}

/// Loops `future::select_all()` with the existing `SelectAll` until the target number of remaining
/// inner futures is reached. Always leaves at least one inner future remaining, for type
/// simplicity (`select_all` panics when the iterator is empty).
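///
/// An illustrative sketch (not a doc test; the ready-made futures and the `std::io::Error`
/// error type are assumptions made for the sake of the example):
///
/// ```ignore
/// // Three in-flight futures, run down until only one remains un-awaited.
/// let futs: Vec<_> = (0..3)
///     .map(|_| Box::pin(async { Ok::<(), std::io::Error>(()) }))
///     .collect();
/// let remaining = run_down(future::select_all(futs), 1).await?;
/// assert_eq!(remaining.into_inner().len(), 1);
/// ```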
async fn run_down<Fut, Err>(
    select_fut: future::SelectAll<Pin<Box<Fut>>>,
    target_remaining: usize,
) -> Result<future::SelectAll<Pin<Box<Fut>>>, Err>
where
    Fut: Future<Output = Result<(), Err>>,
{
    let target_remaining = max(target_remaining, 1);
    let mut select_vec = select_fut.into_inner();
    while select_vec.len() > target_remaining {
        let result;
        (result, _, select_vec) = future::select_all(select_vec).await;
        result?;
    }
    Ok(future::select_all(select_vec))
}

/// Adds a pin-boxed future to an existing `SelectAll` of pin-boxed futures.
fn combine_select_all<Fut>(
    select_fut: future::SelectAll<Pin<Box<Fut>>>,
    new_fut: Pin<Box<Fut>>,
) -> future::SelectAll<Pin<Box<Fut>>>
where
    Fut: Future,
{
    let mut futures = select_fut.into_inner();
    futures.push(new_fut);
    future::select_all(futures)
}

#[cfg(test)]
mod tests {
    use futures::{ready, FutureExt};

    use super::*;
    use std::{pin::Pin, sync::mpsc::channel, task::Poll, time::Duration};

    #[tokio::test]
    async fn enforce_concurrency_limit() -> AzureResult<()> {
        let parallel = 4usize;
        let num_ops = parallel + 1;
        let wait_time_millis = 10u64;
        let op_time_millis = wait_time_millis + 50;

        let (sender, receiver) = channel();

        // Set up a series of operations that each send a unique number to a channel;
        // we can then assert that the expected numbers made it to the channel at the expected times.
        let ops = (0..num_ops).map(|i| {
            let s = sender.clone();
            Ok(async move || {
                s.send(i).unwrap();
                tokio::time::sleep(Duration::from_millis(op_time_millis)).await;
                AzureResult::<()>::Ok(())
            })
        });

        let race = future::select(
            Box::pin(run_all_with_concurrency_limit(
                futures::stream::iter(ops),
                NonZero::new(parallel).unwrap(),
            )),
            Box::pin(tokio::time::sleep(Duration::from_millis(wait_time_millis))),
        )
        .await;
        match race {
            future::Either::Left(_) => panic!("Wrong future won the race."),
            future::Either::Right((_, run_all_fut)) => {
                let mut items: Vec<_> = receiver.try_iter().collect();
                items.sort();
                assert_eq!(items, (0..parallel).collect::<Vec<_>>());

                run_all_fut.await?;
                assert_eq!(receiver.try_iter().collect::<Vec<_>>().len(), 1);
            }
        }

        Ok(())
    }

    #[tokio::test]
    async fn handles_slow_stream() -> AzureResult<()> {
        let parallel = 10;
        let num_ops = 5;
        let op_time_millis = 10;
        let stream_time_millis = op_time_millis + 10;
        // Set up a series of operations that finish faster than the stream yields new items,
        // so the limiter must cope with having no running ops while it waits on the queue.
        let ops = (0..num_ops).map(|_| {
            Ok(async move || {
                tokio::time::sleep(Duration::from_millis(op_time_millis)).await;
                AzureResult::<()>::Ok(())
            })
        });

        run_all_with_concurrency_limit(
            SlowStream::new(ops, Duration::from_millis(stream_time_millis)),
            NonZero::new(parallel).unwrap(),
        )
        .await
    }

    #[tokio::test]
    async fn success_when_no_ops() -> AzureResult<()> {
        let parallel = 4usize;

        // It isn't possible to name the closure type explicitly, so build a vec with a
        // concrete element and then remove it to get an empty vec of the desired type.
        let op = || future::ready::<Result<(), azure_core::Error>>(Ok(()));
        let mut ops = vec![Ok(op)];
        ops.pop();

        run_all_with_concurrency_limit(futures::stream::iter(ops), NonZero::new(parallel).unwrap())
            .await
    }

    struct SlowStream<Iter> {
        sleep: Pin<Box<tokio::time::Sleep>>,
        interval: Duration,
        iter: Iter,
    }
    impl<Iter> SlowStream<Iter> {
        fn new(iter: Iter, interval: Duration) -> Self {
            Self {
                sleep: Box::pin(tokio::time::sleep(interval)),
                interval,
                iter,
            }
        }
    }
    impl<Iter: Iterator + Unpin> Stream for SlowStream<Iter> {
        type Item = Iter::Item;

        fn poll_next(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Option<Self::Item>> {
            let this = self.get_mut();
            // Delay each item by `interval` to simulate a slow producer.
            ready!(this.sleep.poll_unpin(cx));
            this.sleep = Box::pin(tokio::time::sleep(this.interval));
            Poll::Ready(this.iter.next())
        }
    }
}