diff --git a/core/shared/src/main/scala/fs2/concurrent/Topic.scala b/core/shared/src/main/scala/fs2/concurrent/Topic.scala index b069bcfe6f..0b17437adb 100644 --- a/core/shared/src/main/scala/fs2/concurrent/Topic.scala +++ b/core/shared/src/main/scala/fs2/concurrent/Topic.scala @@ -152,20 +152,47 @@ object Topic { ( F.ref(State.initial[F, A]), SignallingRef[F, Int](0), + F.deferred[Unit], F.deferred[Unit] - ).mapN { case (state, subscriberCount, signalClosure) => + ).mapN { case (state, subscriberCount, signalClosure, publishersFinished) => new Topic[F, A] { def foreach[B](lm: LongMap[B])(f: B => F[Unit]) = - lm.foldLeft(F.unit) { case (op, (_, b)) => op >> f(b) } + lm.foldLeft(F.unit) { case (op, (_, b)) => f(b) >> op } def publish1(a: A): F[Either[Topic.Closed, Unit]] = - state.get.flatMap { - case State.Closed() => - Topic.closed.pure[F] - case State.Active(subs, _) => - foreach(subs)(_.send(a).void) - .as(Topic.rightUnit) + state.flatModify { + case s @ State.Active(subs, _, n, false) => + val inc = n + 1 + val newState = s.copy(publishing = inc) + + val sends = subs.foldLeft(F.pure(true)) { case (acc, (_, chan)) => + chan.send(a).map(_.isRight).map2(acc)(_ && _) + } + + val action = sends.flatMap { allSucceeded => + state + .flatModify { + case s @ State.Active(subs, _, n, closing) => + val dec = n - 1 + if (dec == 0 && closing) { + val closeAction = foreach(subs)(_.close.void) + (State.Closed(), closeAction >> publishersFinished.complete(()).void) + } else { + (s.copy(publishing = dec), F.unit) + } + case s @ State.Closed() => (s, F.unit) + } + .map { _ => + if (allSucceeded) Topic.rightUnit else Topic.closed + } + } + (newState, action) + + case s @ State.Active(_, _, _, true) => + (s, Topic.closed.pure[F]) + case s @ State.Closed() => + (s, Topic.closed.pure[F]) } def subscribeAwait(maxQueued: Int): Resource[F, Stream[F, A]] = @@ -181,18 +208,20 @@ object Topic { def subscribeAwaitImpl(chan: Channel[F, A]): Resource[F, Stream[F, A]] = { val subscribe: F[Option[Long]] = state.flatModify { - case State.Active(subs, nextId) => - val newState = State.Active(subs.updated(nextId, chan), nextId + 1) + case s @ State.Active(subs, nextId, _, false) => + val newState = s.copy(subscribers = subs.updated(nextId, chan), nextId = nextId + 1) val action = subscriberCount.update(_ + 1) val result = Some(nextId) newState -> action.as(result) + case s @ State.Active(_, _, _, true) => + s -> F.pure(None) case closed @ State.Closed() => closed -> F.pure(None) } def unsubscribe(id: Long): F[Unit] = state.flatModify { - case State.Active(subs, nextId) => + case s @ State.Active(subs, _, _, _) => // _After_ we remove the bounded channel for this // subscriber, we need to drain it to unblock to // publish loop which might have already enqueued @@ -202,7 +231,7 @@ object Topic { chan.close >> chan.stream.compile.drain } - State.Active(subs - id, nextId) -> (drainChannel *> subscriberCount.update(_ - 1)) + s.copy(subscribers = subs - id) -> (drainChannel *> subscriberCount.update(_ - 1)) case closed @ State.Closed() => closed -> F.unit @@ -236,9 +265,15 @@ object Topic { def close: F[Either[Topic.Closed, Unit]] = state.flatModify { - case State.Active(subs, _) => - val action = foreach(subs)(_.close.void) *> signalClosure.complete(()) - (State.Closed(), action.as(Topic.rightUnit)) + case s @ State.Active(subs, _, n, false) => + if (n == 0) { + val action = foreach(subs)(_.close.void) *> signalClosure.complete(()) + (State.Closed(), (action >> publishersFinished.complete(())).as(Topic.rightUnit)) + } else { + (s.copy(closing = true), publishersFinished.get.as(Topic.rightUnit)) + } + case s @ State.Active(_, _, _, true) => + (s, publishersFinished.get.as(Topic.rightUnit)) case closed @ State.Closed() => (closed, Topic.closed.pure[F]) } @@ -253,13 +288,15 @@ object Topic { private object State { case class Active[F[_], A]( subscribers: LongMap[Channel[F, A]], - nextId: Long + nextId: Long, + publishing: Long, + closing: Boolean ) extends State[F, A] case class Closed[F[_], A]() extends State[F, A] def initial[F[_], A]: State[F, A] = - Active(LongMap.empty, 1L) + Active(LongMap.empty, 1L, 0L, false) } private final val closed: Either[Closed, Unit] = Left(Closed) diff --git a/core/shared/src/test/scala/fs2/concurrent/TopicSuite.scala b/core/shared/src/test/scala/fs2/concurrent/TopicSuite.scala index f2d889b92d..c037c32277 100644 --- a/core/shared/src/test/scala/fs2/concurrent/TopicSuite.scala +++ b/core/shared/src/test/scala/fs2/concurrent/TopicSuite.scala @@ -218,7 +218,7 @@ class TopicSuite extends Fs2Suite { // https://github.com/typelevel/fs2/issues/3644 test( - "when publish1 returns success, subscribers must receive the event, even if the publish1 races with close".fail + "when publish1 returns success, subscribers must receive the event, even if the publish1 races with close" ) { val check: IO[Unit] = Topic[IO, String]