diff --git a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/CancellationStrategySpec.scala b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/CancellationStrategySpec.scala
index 05419c522c9..cf6241e100a 100644
--- a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/CancellationStrategySpec.scala
+++ b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/CancellationStrategySpec.scala
@@ -108,6 +108,40 @@ class CancellationStrategySpec extends StreamSpec("""pekko.loglevel = DEBUG
           out2Probe.expectError(SubscriptionWithCancelException.NoMoreElementsNeeded)
         }
       }
+      "BidirectionalGracefulShutdown" should {
+        "complete outputs first then cancel inputs after delay" in new TestSetup(
+          CancellationStrategy.BidirectionalGracefulShutdown(500.millis.dilated)) {
+          out1Probe.cancel()
+          // outputs should be completed immediately
+          out2Probe.expectComplete()
+          // inputs must stay open during the (dilated) grace period and be cancelled only afterwards
+          inProbe.expectNoMessage(200.millis)
+          inProbe.expectCancellationWithCause(SubscriptionWithCancelException.NoMoreElementsNeeded)
+        }
+        "propagate failure to outputs first then cancel inputs after delay" in new TestSetup(
+          CancellationStrategy.BidirectionalGracefulShutdown(500.millis.dilated)) {
+          val theError = TE("Test error")
+          out1Probe.cancel(theError)
+          // outputs should fail immediately with the error
+          out2Probe.expectError(theError)
+          // inputs must stay open during the (dilated) grace period, then be cancelled with the same error
+          inProbe.expectNoMessage(200.millis)
+          inProbe.expectCancellationWithCause(theError)
+        }
+        "prevent further elements from coming through during grace period" in new TestSetup(
+          CancellationStrategy.BidirectionalGracefulShutdown(500.millis.dilated)) {
+          out1Probe.request(1)
+          out2Probe.request(1)
+          out1Probe.cancel()
+          // outputs should be completed immediately
+          out2Probe.expectComplete()
+          // pushing an element during the grace period must not trigger the cancellation early
+          inProbe.sendNext(B(123))
+          inProbe.expectNoMessage(200.millis)
+          // after delay inputs should be cancelled
+          inProbe.expectCancellationWithCause(SubscriptionWithCancelException.NoMoreElementsNeeded)
+        }
+      }
     }
 
     "cancellation races with BidiStacks" should {
@@ -135,6 +169,13 @@ class CancellationStrategySpec extends StreamSpec("""pekko.loglevel = DEBUG
         toStream.expectCancellationWithCause(theError)
         fromStream.expectError(theError)
       }
+      "be prevented by BidirectionalGracefulShutdown strategy" in new RaceTestSetup(
+        CancellationStrategy.BidirectionalGracefulShutdown(500.millis.dilated)) {
+        val theError = TE("Duck meowed")
+        killSwitch.abort(theError)
+        toStream.expectCancellationWithCause(theError)
+        fromStream.expectError(theError)
+      }
 
       class RaceTestSetup(cancellationStrategy: CancellationStrategy.Strategy) {
         val toStream = TestPublisher.probe[A]()
diff --git a/stream/src/main/scala/org/apache/pekko/stream/Attributes.scala b/stream/src/main/scala/org/apache/pekko/stream/Attributes.scala
index 4f6d346c314..3b824e91f53 100644
--- a/stream/src/main/scala/org/apache/pekko/stream/Attributes.scala
+++ b/stream/src/main/scala/org/apache/pekko/stream/Attributes.scala
@@ -565,6 +565,44 @@ object Attributes {
      */
     @ApiMayChange
     def afterDelay(delay: java.time.Duration, strategy: Strategy): Strategy = AfterDelay(delay.toScala, strategy)
+
+    /**
+     * Strategy that ensures graceful shutdown for bidirectional components.
+     *
+     * When `cancelStage` is invoked, this strategy first completes all output ports (regularly or with an error),
+     * then waits for a grace period to allow the completion/error signal to propagate through the counterpart,
+     * and finally cancels all input ports.
+     *
+     * This addresses the race condition in bidirectional components where cancelling the upstream side might
+     * prevent the error from being properly propagated downstream. By completing outputs first and waiting,
+     * the error has a chance to bubble through the counterpart before the upstream is cancelled.
+     *
+     * This strategy is particularly useful in stacks of BidiFlows where different layers are connected
+     * through both inputs and outputs, and error propagation is important for proper diagnostics.
+     *
+     * @param delay grace period between completing the outputs and cancelling the inputs; zero cancels immediately
+     */
+    @ApiMayChange
+    final case class BidirectionalGracefulShutdown(delay: FiniteDuration) extends Strategy
+
+    /**
+     * Java API
+     *
+     * Creates a [[BidirectionalGracefulShutdown]] strategy for gracefully shutting down bidirectional components.
+     *
+     * When `cancelStage` is invoked, this strategy first completes all output ports (regularly or with an error),
+     * then waits for a grace period to allow the completion/error signal to propagate through the counterpart,
+     * and finally cancels all input ports.
+     *
+     * This addresses the race condition in bidirectional components where cancelling the upstream side might
+     * prevent the error from being properly propagated downstream. By completing outputs first and waiting,
+     * the error has a chance to bubble through the counterpart before the upstream is cancelled.
+     *
+     * @param delay grace period between completing the outputs and cancelling the inputs; zero cancels immediately
+     */
+    @ApiMayChange
+    def bidirectionalGracefulShutdown(delay: java.time.Duration): Strategy =
+      BidirectionalGracefulShutdown(delay.toScala)
   }
 
   /**
@@ -629,6 +667,26 @@ object Attributes {
       strategy: CancellationStrategy.Strategy): CancellationStrategy.Strategy =
     CancellationStrategy.AfterDelay(delay, strategy)
 
+  /**
+   * Java API
+   *
+   * Strategy that ensures graceful shutdown for bidirectional components.
+   *
+   * When `cancelStage` is invoked, this strategy first completes all output ports (regularly or with an error),
+   * then waits for a grace period to allow the completion/error signal to propagate through the counterpart,
+   * and finally cancels all input ports.
+   *
+   * This addresses the race condition in bidirectional components where cancelling the upstream side might
+   * prevent the error from being properly propagated downstream. By completing outputs first and waiting,
+   * the error has a chance to bubble through the counterpart before the upstream is cancelled.
+   *
+   * @param delay the grace period to wait after completing outputs before cancelling inputs
+   */
+  @ApiMayChange
+  def cancellationStrategyBidirectionalGracefulShutdown(
+      delay: FiniteDuration): CancellationStrategy.Strategy =
+    CancellationStrategy.BidirectionalGracefulShutdown(delay)
+
   /**
    * Nested materialization cancellation strategy provides a way to configure the cancellation behavior of stages that materialize a nested flow.
    *
diff --git a/stream/src/main/scala/org/apache/pekko/stream/stage/GraphStage.scala b/stream/src/main/scala/org/apache/pekko/stream/stage/GraphStage.scala
index 5e1aec586c6..90587b7b46d 100644
--- a/stream/src/main/scala/org/apache/pekko/stream/stage/GraphStage.scala
+++ b/stream/src/main/scala/org/apache/pekko/stream/stage/GraphStage.scala
@@ -21,7 +21,7 @@ import scala.annotation.nowarn
 import scala.annotation.tailrec
 import scala.collection.{ immutable, mutable }
 import scala.concurrent.{ Future, Promise }
-import scala.concurrent.duration.FiniteDuration
+import scala.concurrent.duration.{ Duration, FiniteDuration }
 
 import org.apache.pekko
 import pekko.{ Done, NotUsed }
@@ -718,9 +718,83 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
       case AfterDelay(_, andThen) =>
         // delay handled at the stage that sends the delay. See `def cancel(in, cause)`.
         internalCancelStage(cause, andThen)
+      case BidirectionalGracefulShutdown(delay) =>
+        // For bidirectional graceful shutdown:
+        // 1. First complete all output ports (regularly or with error)
+        // 2. Wait for grace period to allow error to propagate through counterpart
+        // 3. Then cancel all input ports
+        internalBidirectionalGracefulShutdown(cause, delay)
     }
   }
 
+  /**
+   * Implements bidirectional graceful shutdown for bidirectional components.
+   *
+   * This method first completes (or fails) all output ports, then waits for a grace period
+   * to allow the completion/error signal to propagate through the counterpart, and finally cancels
+   * all input ports. The stage is kept alive only while that delayed cancellation is pending.
+   *
+   * This addresses the race condition in bidirectional components where cancelling
+   * the upstream side might prevent the error from being properly propagated downstream.
+   */
+  private def internalBidirectionalGracefulShutdown(cause: Throwable, delay: FiniteDuration): Unit = {
+    import SubscriptionWithCancelException._
+
+    // Determine if this should be a failure or regular completion based on the cause
+    val isFailure = cause match {
+      case NoMoreElementsNeeded | StageWasCompleted => false
+      case _                                        => true
+    }
+
+    // Step 1: Complete all output ports first
+    var i = inCount // Start from output ports (after input ports in portToConn)
+    while (i < portToConn.length) {
+      if (isFailure)
+        interpreter.fail(portToConn(i), cause)
+      else
+        handlers(i) match {
+          case e: Emitting[Any @unchecked] => e.addFollowUp(new EmittingCompletion[Any](e.out, e.previous))
+          case _                           => interpreter.complete(portToConn(i))
+        }
+      i += 1
+    }
+
+    // Step 2: Cancel the input ports, either right away or after the grace period
+    if (delay == Duration.Zero) {
+      // No grace period: cancel immediately; cancelAllInputPorts leaves keep-going off so the stage can stop
+      cancelAllInputPorts(cause)
+    } else {
+      setKeepGoing(true) // Keep the stage alive only while the delayed cancellation is pending
+      // Install handlers to ignore incoming elements on input ports during grace period
+      var j = 0
+      while (j < inCount) {
+        val connection = portToConn(j)
+        connection.inHandler = EagerTerminateInput
+        j += 1
+      }
+
+      // Schedule the actual cancellation after the delay
+      val callback = getAsyncCallback[Throwable] { cancelCause =>
+        cancelAllInputPorts(cancelCause)
+      }
+      materializer.scheduleOnce(delay, () => callback.invoke(cause))
+    }
+
+    cleanUpSubstreams(if (isFailure) OptionVal.Some(cause) else OptionVal.None)
+  }
+
+  /**
+   * Cancels all input ports with the given cause.
+   */
+  private def cancelAllInputPorts(cause: Throwable): Unit = {
+    var i = 0
+    while (i < inCount) {
+      doCancel(portToConn(i), cause)
+      i += 1
+    }
+    setKeepGoing(false) // Allow stage to complete after cancellation
+  }
+
   /**
    * Automatically invokes [[cancel]] or [[fail]] on all the input or output ports that have been called,
    * then marks the operator as stopped.