diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java index a299e50a..b514348e 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java @@ -80,7 +80,6 @@ public class FanOutKinesisShardSubscription { // Queue is meant for eager retrieval of records from the Kinesis stream. We will always have 2 // record batches available on next read. private final BlockingQueue eventQueue = new LinkedBlockingQueue<>(2); - private final AtomicBoolean subscriptionActive = new AtomicBoolean(false); private final AtomicReference subscriptionException = new AtomicReference<>(); // Store the current starting position for this subscription. Will be updated each time new @@ -108,8 +107,9 @@ public void activateSubscription() { shardId, startingPosition, consumerArn); - if (subscriptionActive.get()) { - LOG.warn("Skipping activation of subscription since it is already active."); + if (shardSubscriber != null + && shardSubscriber.getSubscriptionState() == SubscriptionState.SUBSCRIBED) { + LOG.warn("Skipping activation of subscription since it is active & subscribed."); return; } @@ -166,9 +166,9 @@ public void activateSubscription() { shardId, startingPosition, consumerArn); - subscriptionActive.set(true); // Request first batch of records. shardSubscriber.requestRecords(); + } else { String errorMessage = "Timeout when subscribing to shard " @@ -236,16 +236,43 @@ public SubscribeToShardEvent nextEvent() { throw new KinesisStreamsSourceException( "Subscription encountered unrecoverable exception.", throwable); } + final SubscriptionState state = + Optional.ofNullable(shardSubscriber) + .map(FanOutShardSubscriber::getSubscriptionState) + .orElse(SubscriptionState.NOT_STARTED); - if (!subscriptionActive.get()) { - LOG.debug( - "Subscription to shard {} for consumer {} is not yet active. Skipping.", - shardId, - consumerArn); - return null; + switch (state) { + case NOT_STARTED: + if (LOG.isDebugEnabled()) { + LOG.debug( + "Subscription to shard {} for consumer {} is not yet active. Skipping.", + shardId, + consumerArn); + } + return null; + case COMPLETED: + if (shardSubscriber.isShardEndReached()) { + if (LOG.isInfoEnabled()) { + LOG.info( + "Subscription reached SHARD_END for shard {} for consumer {}.", + shardId, + consumerArn); + } + return null; + } + if (LOG.isInfoEnabled()) { + LOG.info( + "Subscription expired to shard {} for consumer {}. Restarting.", + shardId, + consumerArn); + } + activateSubscription(); + return null; + case SUBSCRIBED: + return eventQueue.poll(); + default: + throw new IllegalStateException("Unknown subscription state: " + state); } - - return eventQueue.poll(); } /** @@ -254,26 +281,48 @@ public SubscribeToShardEvent nextEvent() { */ private class FanOutShardSubscriber implements Subscriber { private final CountDownLatch subscriptionLatch; - private Subscription subscription; + private final AtomicReference subscriptionState = + new AtomicReference<>(SubscriptionState.NOT_STARTED); + private final AtomicBoolean isShardEnd = new AtomicBoolean(false); + private FanOutShardSubscriber(CountDownLatch subscriptionLatch) { this.subscriptionLatch = subscriptionLatch; } + /** + * Fetch the state that the subscriber is in. + * + * @return Subscription state for the subscriber. + */ + public SubscriptionState getSubscriptionState() { + return subscriptionState.get(); + } + + /** + * Boolean whether this subscriber has reached the end of a shard. + * + * @return True if ShardEnd. false otherwise. + */ + public boolean isShardEndReached() { + return isShardEnd.get(); + } + public void requestRecords() { subscription.request(1); } public void cancel() { - if (!subscriptionActive.get()) { - LOG.warn("Trying to cancel inactive subscription. Ignoring."); + if (this.subscriptionState.get() == SubscriptionState.COMPLETED) { + LOG.warn("Subscription is already completed. Ignoring request to cancel."); return; } - subscriptionActive.set(false); + if (subscription != null) { subscription.cancel(); } + this.subscriptionState.set(SubscriptionState.COMPLETED); } @Override @@ -284,6 +333,7 @@ public void onSubscribe(Subscription subscription) { startingPosition, consumerArn); this.subscription = subscription; + this.subscriptionState.set(SubscriptionState.SUBSCRIBED); subscriptionLatch.countDown(); } @@ -300,6 +350,15 @@ public void visit(SubscribeToShardEvent event) { event); eventQueue.put(event); + if (event.continuationSequenceNumber() == null) { + if (LOG.isDebugEnabled()) { + LOG.debug("continuationSequenceNumber is null. " + + "Reached ShardEnd for shard: {}", shardId); + } + isShardEnd.set(true); + return; + } + // Update the starting position in case we have to recreate the // subscription startingPosition = @@ -330,8 +389,14 @@ public void onError(Throwable throwable) { @Override public void onComplete() { LOG.info("Subscription complete - {} ({})", shardId, consumerArn); - cancel(); - activateSubscription(); + this.subscriptionState.set(SubscriptionState.COMPLETED); } } + + /** States that the {@code FanOutShardSubscriber} may be in. */ + private enum SubscriptionState { + NOT_STARTED, + SUBSCRIBED, + COMPLETED + } } diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscriptionTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscriptionTest.java new file mode 100644 index 00000000..b6d66b78 --- /dev/null +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscriptionTest.java @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.connector.kinesis.source.reader.fanout; + +import org.apache.flink.connector.kinesis.source.exception.KinesisStreamsSourceException; +import org.apache.flink.connector.kinesis.source.proxy.AsyncStreamProxy; +import org.apache.flink.connector.kinesis.source.split.StartingPosition; +import org.apache.flink.connector.kinesis.source.util.FakeKinesisFanOutBehaviorsFactory; + +import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponse; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler; + +import java.time.Duration; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.apache.flink.connector.kinesis.source.util.TestUtil.CONSUMER_ARN; +import static org.apache.flink.connector.kinesis.source.util.TestUtil.generateShardId; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link FanOutKinesisShardSubscription}. */ +class FanOutKinesisShardSubscriptionTest { + + private static final String TEST_SHARD_ID = generateShardId(1); + private static final Duration SUBSCRIPTION_TIMEOUT = Duration.ofSeconds(5); + + @Test + void testNextEventReturnsNullBeforeActivation() { + AsyncStreamProxy proxy = FakeKinesisFanOutBehaviorsFactory.boundedShard().build(); + FanOutKinesisShardSubscription subscription = + new FanOutKinesisShardSubscription( + proxy, + CONSUMER_ARN, + TEST_SHARD_ID, + StartingPosition.fromStart(), + SUBSCRIPTION_TIMEOUT); + + assertThat(subscription.nextEvent()).isNull(); + } + + @Test + void testResourceNotFoundExceptionThrown() { + AsyncStreamProxy proxy = + FakeKinesisFanOutBehaviorsFactory.resourceNotFoundWhenObtainingSubscription(); + FanOutKinesisShardSubscription subscription = + new FanOutKinesisShardSubscription( + proxy, + CONSUMER_ARN, + TEST_SHARD_ID, + StartingPosition.fromStart(), + SUBSCRIPTION_TIMEOUT); + + subscription.activateSubscription(); + + // Poll until exception surfaces + assertThatThrownBy( + () -> { + for (int i = 0; i < 200; i++) { + subscription.nextEvent(); + Thread.sleep(50); + } + }) + .isInstanceOf(ResourceNotFoundException.class); + } + + @Test + void testUnrecoverableExceptionWrappedInSourceException() throws Exception { + AsyncStreamProxy proxy = + new AsyncStreamProxy() { + @Override + public CompletableFuture subscribeToShard( + String consumerArn, + String shardId, + StartingPosition startingPosition, + SubscribeToShardResponseHandler responseHandler) { + responseHandler.exceptionOccurred( + new IllegalStateException("unrecoverable")); + return CompletableFuture.completedFuture(null); + } + + @Override + public void close() {} + }; + FanOutKinesisShardSubscription subscription = + new FanOutKinesisShardSubscription( + proxy, + CONSUMER_ARN, + TEST_SHARD_ID, + StartingPosition.fromStart(), + SUBSCRIPTION_TIMEOUT); + + subscription.activateSubscription(); + + assertThatThrownBy( + () -> { + for (int i = 0; i < 200; i++) { + subscription.nextEvent(); + Thread.sleep(50); + } + }) + .isInstanceOf(KinesisStreamsSourceException.class) + .hasMessageContaining("unrecoverable"); + } + + @Test + void testSubscriptionTimeoutTerminatesSubscription() throws Exception { + AsyncStreamProxy proxy = + new AsyncStreamProxy() { + @Override + public CompletableFuture subscribeToShard( + String consumerArn, + String shardId, + StartingPosition startingPosition, + SubscribeToShardResponseHandler responseHandler) { + return new CompletableFuture<>(); + } + + @Override + public void close() {} + }; + FanOutKinesisShardSubscription subscription = + new FanOutKinesisShardSubscription( + proxy, + CONSUMER_ARN, + TEST_SHARD_ID, + StartingPosition.fromStart(), + Duration.ofMillis(200)); + + subscription.activateSubscription(); + + // Wait for timeout to trigger, then poll - should recover + Thread.sleep(500); + SubscribeToShardEvent event = subscription.nextEvent(); + assertThat(event).isNull(); + } + + @Test + void testExpiredSubscriptionResubscribes() throws Exception { + AtomicInteger subscribeCount = new AtomicInteger(0); + AsyncStreamProxy proxy = + new AsyncStreamProxy() { + @Override + public CompletableFuture subscribeToShard( + String consumerArn, + String shardId, + StartingPosition startingPosition, + SubscribeToShardResponseHandler responseHandler) { + subscribeCount.incrementAndGet(); + return CompletableFuture.supplyAsync( + () -> { + responseHandler.responseReceived( + SubscribeToShardResponse.builder().build()); + responseHandler.onEventStream( + subscriber -> { + subscriber.onSubscribe( + new Subscription() { + @Override + public void request(long n) { + // Complete without sending any + // events (simulates 5-min expiry) + subscriber.onComplete(); + } + + @Override + public void cancel() {} + }); + }); + return null; + }); + } + + @Override + public void close() {} + }; + + FanOutKinesisShardSubscription subscription = + new FanOutKinesisShardSubscription( + proxy, + CONSUMER_ARN, + TEST_SHARD_ID, + StartingPosition.fromStart(), + SUBSCRIPTION_TIMEOUT); + + subscription.activateSubscription(); + Thread.sleep(500); + + // nextEvent() should detect COMPLETED without shard-end and trigger resubscription + subscription.nextEvent(); + Thread.sleep(500); + + assertThat(subscribeCount.get()).isEqualTo(2); + } + + @Test + void testShardEndDoesNotResubscribe() throws Exception { + AtomicInteger subscribeCount = new AtomicInteger(0); + AsyncStreamProxy proxy = + new AsyncStreamProxy() { + @Override + public CompletableFuture subscribeToShard( + String consumerArn, + String shardId, + StartingPosition startingPosition, + SubscribeToShardResponseHandler responseHandler) { + subscribeCount.incrementAndGet(); + return CompletableFuture.supplyAsync( + () -> { + responseHandler.responseReceived( + SubscribeToShardResponse.builder().build()); + responseHandler.onEventStream( + subscriber -> { + subscriber.onSubscribe( + new Subscription() { + private boolean sent = false; + + @Override + public void request(long n) { + if (!sent) { + sent = true; + // Send event with null + // continuation (shard end) + subscriber.onNext( + SubscribeToShardEvent + .builder() + .millisBehindLatest( + 0L) + .continuationSequenceNumber( + null) + .build()); + } else { + subscriber.onComplete(); + } + } + + @Override + public void cancel() {} + }); + }); + return null; + }); + } + + @Override + public void close() {} + }; + + FanOutKinesisShardSubscription subscription = + new FanOutKinesisShardSubscription( + proxy, + CONSUMER_ARN, + TEST_SHARD_ID, + StartingPosition.fromStart(), + SUBSCRIPTION_TIMEOUT); + + subscription.activateSubscription(); + Thread.sleep(500); + + // Drain the shard-end event from the queue + subscription.nextEvent(); + Thread.sleep(500); + + // Should not have resubscribed — shard has ended + assertThat(subscribeCount.get()).isEqualTo(1); + assertThat(subscription.nextEvent()).isNull(); + } +}