diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java
index c0aefee5..d1af8b96 100644
--- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java
+++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java
@@ -19,6 +19,7 @@
package org.apache.flink.connector.kinesis.source.reader.fanout;
import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.connector.base.source.reader.splitreader.SplitsChange;
import org.apache.flink.connector.kinesis.source.metrics.KinesisShardMetrics;
@@ -26,14 +27,27 @@
import org.apache.flink.connector.kinesis.source.reader.KinesisShardSplitReaderBase;
import org.apache.flink.connector.kinesis.source.split.KinesisShardSplit;
import org.apache.flink.connector.kinesis.source.split.KinesisShardSplitState;
+import org.apache.flink.connector.kinesis.source.split.StartingPosition;
+import org.apache.flink.util.concurrent.ExecutorThreadFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent;
+import java.io.IOException;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
import static org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT;
+import static org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.EFO_DEREGISTER_CONSUMER_TIMEOUT;
/**
* An implementation of the KinesisShardSplitReader that consumes from Kinesis using Enhanced
@@ -41,21 +55,206 @@
*/
@Internal
public class FanOutKinesisShardSplitReader extends KinesisShardSplitReaderBase {
+
+ private static final Logger LOG = LoggerFactory.getLogger(FanOutKinesisShardSplitReader.class);
private final AsyncStreamProxy asyncStreamProxy;
private final String consumerArn;
private final Duration subscriptionTimeout;
+ private final Duration deregisterTimeout;
+
+ /**
+ * Shared executor service for all shard subscriptions.
+ *
+ *
This executor uses an unbounded queue ({@link LinkedBlockingQueue}) but this does not pose
+ * a risk of out-of-memory errors due to the natural flow control mechanisms in the system:
+ *
+ *
+ * - Each {@link FanOutKinesisShardSubscription} has a bounded event queue with capacity of 2
+ * - New records are only requested after processing an event (via {@code requestRecords()})
+ * - When a shard's queue is full, the processing thread blocks at the {@code put()} operation
+ * - The AWS SDK implements the Reactive Streams protocol with built-in backpressure
+ *
+ *
+ * In the worst-case scenario during backpressure, the maximum number of events in memory is:
+ *
+ * Max Events = (2 * Number_of_Shards) + min(Number_of_Shards, Number_of_Threads)
+ *
+ *
+ * Where:
+ *
+ * - 2 * Number_of_Shards: Total capacity of all shard event queues
+ * - min(Number_of_Shards, Number_of_Threads): Maximum events being actively processed
+ *
+ *
+ * This ensures that memory usage scales linearly with the number of shards, not exponentially,
+ * making it safe to use an unbounded executor queue even with a large number of shards.
+ */
+ private final ExecutorService sharedShardSubscriptionExecutor;
+
+ /**
+ * Shared executor service for making subscribeToShard API calls.
+ *
+ *
This executor is separate from the event processing executor to avoid contention
+ * between API calls and event processing. Using a dedicated executor for subscription calls
+ * provides several important benefits:
+ *
+ *
+ * - Prevents blocking of the main thread or event processing threads during API calls
+ * - Isolates API call failures from event processing operations
+ * - Allows for controlled concurrency of API calls across multiple shards
+ * - Prevents potential deadlocks that could occur when the same thread handles both
+ * subscription calls and event processing
+ *
+ *
+ * The executor uses a smaller number of threads than the event processing executor since
+ * subscription calls are less frequent and primarily I/O bound. This helps optimize resource
+ * usage while still providing sufficient parallelism for multiple concurrent subscription calls.
+ */
+ private final ExecutorService sharedSubscriptionCallExecutor;
private final Map splitSubscriptions = new HashMap<>();
+ /**
+ * Factory for creating subscriptions. This is primarily used for testing.
+ */
+ @VisibleForTesting
+ public interface SubscriptionFactory {
+ FanOutKinesisShardSubscription createSubscription(
+ AsyncStreamProxy proxy,
+ String consumerArn,
+ String shardId,
+ StartingPosition startingPosition,
+ Duration timeout,
+ ExecutorService eventProcessingExecutor,
+ ExecutorService subscriptionCallExecutor);
+ }
+
+ /**
+ * Default implementation of the subscription factory.
+ */
+ private static class DefaultSubscriptionFactory implements SubscriptionFactory {
+ @Override
+ public FanOutKinesisShardSubscription createSubscription(
+ AsyncStreamProxy proxy,
+ String consumerArn,
+ String shardId,
+ StartingPosition startingPosition,
+ Duration timeout,
+ ExecutorService eventProcessingExecutor,
+ ExecutorService subscriptionCallExecutor) {
+ return new FanOutKinesisShardSubscription(
+ proxy,
+ consumerArn,
+ shardId,
+ startingPosition,
+ timeout,
+ eventProcessingExecutor,
+ subscriptionCallExecutor);
+ }
+ }
+
+ private SubscriptionFactory subscriptionFactory;
+
public FanOutKinesisShardSplitReader(
AsyncStreamProxy asyncStreamProxy,
String consumerArn,
Map shardMetricGroupMap,
Configuration configuration) {
+ this(asyncStreamProxy, consumerArn, shardMetricGroupMap, configuration, new DefaultSubscriptionFactory());
+ }
+
+ @VisibleForTesting
+ FanOutKinesisShardSplitReader(
+ AsyncStreamProxy asyncStreamProxy,
+ String consumerArn,
+ Map shardMetricGroupMap,
+ Configuration configuration,
+ SubscriptionFactory subscriptionFactory) {
+ this(
+ asyncStreamProxy,
+ consumerArn,
+ shardMetricGroupMap,
+ configuration,
+ subscriptionFactory,
+ createDefaultEventProcessingExecutor(),
+ createDefaultSubscriptionCallExecutor());
+ }
+
+ /**
+ * Constructor with injected executor services for testing.
+ *
+ * @param asyncStreamProxy The proxy for Kinesis API calls
+ * @param consumerArn The ARN of the consumer
+ * @param shardMetricGroupMap The metrics map
+ * @param configuration The configuration
+ * @param subscriptionFactory The factory for creating subscriptions
+ * @param eventProcessingExecutor The executor service to use for event processing tasks
+ * @param subscriptionCallExecutor The executor service to use for subscription API calls
+ */
+ @VisibleForTesting
+ FanOutKinesisShardSplitReader(
+ AsyncStreamProxy asyncStreamProxy,
+ String consumerArn,
+ Map shardMetricGroupMap,
+ Configuration configuration,
+ SubscriptionFactory subscriptionFactory,
+ ExecutorService eventProcessingExecutor,
+ ExecutorService subscriptionCallExecutor) {
super(shardMetricGroupMap, configuration);
this.asyncStreamProxy = asyncStreamProxy;
this.consumerArn = consumerArn;
this.subscriptionTimeout = configuration.get(EFO_CONSUMER_SUBSCRIPTION_TIMEOUT);
+ this.deregisterTimeout = configuration.get(EFO_DEREGISTER_CONSUMER_TIMEOUT);
+ this.subscriptionFactory = subscriptionFactory;
+ this.sharedShardSubscriptionExecutor = eventProcessingExecutor;
+ this.sharedSubscriptionCallExecutor = subscriptionCallExecutor;
+ }
+
+ /**
+ * Creates the default executor service for event processing tasks.
+ *
+ * @return A new executor service
+ */
+ private static ExecutorService createDefaultEventProcessingExecutor() {
+ int minThreads = Runtime.getRuntime().availableProcessors();
+ int maxThreads = minThreads * 2;
+ return new ThreadPoolExecutor(
+ minThreads,
+ maxThreads,
+ 60L, TimeUnit.SECONDS,
+ new LinkedBlockingQueue<>(), // Unbounded queue with natural flow control
+ new ExecutorThreadFactory("kinesis-efo-subscription"));
+ }
+
+ /**
+ * Creates the default executor service for subscription API calls.
+ *
+ * This executor is configured with:
+ *
+ * - Minimum threads: 1 - Ensures at least one thread is always available for API calls
+ * - Maximum threads: max(2, availableProcessors/4) - Scales with system resources but
+ * keeps the thread count relatively low since API calls are I/O bound
+ * - Keep-alive time: 60 seconds - Allows for efficient reuse of threads
+ * - Unbounded queue - Safe because the number of subscription tasks is naturally limited
+ * by the number of shards
+ * - Custom thread factory - Provides meaningful thread names for debugging
+ *
+ *
+ * This configuration balances resource efficiency with responsiveness for subscription calls.
+ * Since subscription calls are primarily waiting on network I/O, a relatively small number of
+ * threads can efficiently handle many concurrent calls.
+ *
+ * @return A new executor service optimized for subscription API calls
+ */
+ private static ExecutorService createDefaultSubscriptionCallExecutor() {
+ int minThreads = 1;
+ int maxThreads = Math.max(2, Runtime.getRuntime().availableProcessors() / 4);
+ return new ThreadPoolExecutor(
+ minThreads,
+ maxThreads,
+ 60L, TimeUnit.SECONDS,
+ new LinkedBlockingQueue<>(), // Unbounded queue with natural flow control
+ new ExecutorThreadFactory("kinesis-subscription-caller"));
}
@Override
@@ -80,19 +279,167 @@ public void handleSplitsChanges(SplitsChange splitsChanges) {
super.handleSplitsChanges(splitsChanges);
for (KinesisShardSplit split : splitsChanges.splits()) {
FanOutKinesisShardSubscription subscription =
- new FanOutKinesisShardSubscription(
+ subscriptionFactory.createSubscription(
asyncStreamProxy,
consumerArn,
split.getShardId(),
split.getStartingPosition(),
- subscriptionTimeout);
+ subscriptionTimeout,
+ sharedShardSubscriptionExecutor,
+ sharedSubscriptionCallExecutor);
subscription.activateSubscription();
splitSubscriptions.put(split.splitId(), subscription);
}
}
+ /**
+ * Closes the split reader and releases all resources.
+ *
+ * The close method follows a specific order to ensure proper shutdown:
+ * 1. First, cancel all active subscriptions to prevent new events from being processed
+ * 2. Then, shutdown the shared executor service to stop processing existing events
+ * 3. Finally, close the async stream proxy to release network resources
+ *
+ *
This ordering is critical because:
+ * - Cancelling subscriptions first prevents new events from being submitted to the executor
+ * - Shutting down the executor next ensures all in-flight tasks complete or are cancelled
+ * - Closing the async stream proxy last ensures all resources are properly released after
+ * all processing has stopped
+ */
@Override
public void close() throws Exception {
- asyncStreamProxy.close();
+ cancelActiveSubscriptions();
+ shutdownSharedShardSubscriptionExecutor();
+ shutdownSharedSubscriptionCallExecutor();
+ closeAsyncStreamProxy();
+ }
+
+ /**
+ * Cancels all active subscriptions to prevent new events from being processed.
+ *
+ *
After cancelling subscriptions, we wait a short time to allow the cancellation
+ * signals to propagate before proceeding with executor shutdown.
+ */
+ private void cancelActiveSubscriptions() {
+ for (FanOutKinesisShardSubscription subscription : splitSubscriptions.values()) {
+ if (subscription.isActive()) {
+ try {
+ subscription.cancelSubscription();
+ } catch (Exception e) {
+ LOG.warn("Error cancelling subscription for shard {}",
+ subscription.getShardId(), e);
+ }
+ }
+ }
+
+ // Wait a short time (200ms) to allow cancellation signals to propagate
+ // This helps ensure that no new tasks are submitted to the executor after we begin shutdown
+ try {
+ Thread.sleep(200);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ }
+ }
+
+ /**
+ * Shuts down the shared executor service used for processing subscription events.
+ *
+ *
We use the EFO_DEREGISTER_CONSUMER_TIMEOUT (10 seconds) as the shutdown timeout
+ * to maintain consistency with other deregistration operations in the connector.
+ */
+ private void shutdownSharedShardSubscriptionExecutor() {
+ if (sharedShardSubscriptionExecutor == null) {
+ return;
+ }
+
+ sharedShardSubscriptionExecutor.shutdown();
+ try {
+ // Use the deregister consumer timeout (10 seconds)
+ // This timeout is consistent with other deregistration operations in the connector
+ if (!sharedShardSubscriptionExecutor.awaitTermination(
+ deregisterTimeout.toMillis(),
+ TimeUnit.MILLISECONDS)) {
+ LOG.warn("Event processing executor did not terminate in the specified time. Forcing shutdown.");
+ sharedShardSubscriptionExecutor.shutdownNow();
+ }
+ } catch (InterruptedException e) {
+ LOG.warn("Interrupted while waiting for event processing executor shutdown", e);
+ sharedShardSubscriptionExecutor.shutdownNow();
+ Thread.currentThread().interrupt();
+ }
+ }
+
+ /**
+ * Shuts down the shared executor service used for subscription API calls.
+ *
+ *
We use the EFO_DEREGISTER_CONSUMER_TIMEOUT (10 seconds) as the shutdown timeout
+ * to maintain consistency with other deregistration operations in the connector.
+ */
+ private void shutdownSharedSubscriptionCallExecutor() {
+ if (sharedSubscriptionCallExecutor == null) {
+ return;
+ }
+
+ sharedSubscriptionCallExecutor.shutdown();
+ try {
+ // Use a shorter timeout since these are just API calls
+ if (!sharedSubscriptionCallExecutor.awaitTermination(
+ deregisterTimeout.toMillis(),
+ TimeUnit.MILLISECONDS)) {
+ LOG.warn("Subscription call executor did not terminate in the specified time. Forcing shutdown.");
+ sharedSubscriptionCallExecutor.shutdownNow();
+ }
+ } catch (InterruptedException e) {
+ LOG.warn("Interrupted while waiting for subscription call executor shutdown", e);
+ sharedSubscriptionCallExecutor.shutdownNow();
+ Thread.currentThread().interrupt();
+ }
+ }
+
+ /**
+ * Closes the async stream proxy with a timeout.
+ *
+ *
We use the EFO_CONSUMER_SUBSCRIPTION_TIMEOUT (60 seconds) as the close timeout
+ * since closing the client involves similar network operations as subscription.
+ * The longer timeout accounts for potential network delays during shutdown.
+ */
+ private void closeAsyncStreamProxy() {
+ // Create a dedicated single-threaded executor for closing the asyncStreamProxy
+ // This prevents the close operation from being affected by the main executor shutdown
+ ExecutorService closeExecutor = new ThreadPoolExecutor(
+ 1, 1,
+ 0L, TimeUnit.MILLISECONDS,
+ new LinkedBlockingQueue<>(),
+ new ExecutorThreadFactory("kinesis-client-close"));
+
+ try {
+ // Submit the close task to the executor to avoid blocking the main thread
+ Future> closeFuture = closeExecutor.submit(() -> {
+ try {
+ asyncStreamProxy.close();
+ } catch (IOException e) {
+ LOG.warn("Error closing async stream proxy", e);
+ }
+ });
+
+ try {
+ // Use the subscription timeout (60 seconds)
+ // This longer timeout is necessary because closing the AWS SDK client
+ // may involve waiting for in-flight network operations to complete
+ closeFuture.get(
+ subscriptionTimeout.toMillis(),
+ TimeUnit.MILLISECONDS);
+ } catch (TimeoutException e) {
+ LOG.warn("Timed out while closing async stream proxy", e);
+ } catch (InterruptedException e) {
+ LOG.warn("Interrupted while closing async stream proxy", e);
+ Thread.currentThread().interrupt();
+ } catch (ExecutionException e) {
+ LOG.warn("Error while closing async stream proxy", e.getCause());
+ }
+ } finally {
+ // Ensure the close executor is always shut down to prevent resource leaks
+ closeExecutor.shutdownNow();
+ }
}
}
diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java
index a299e50a..6da47ead 100644
--- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java
+++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java
@@ -19,6 +19,7 @@
package org.apache.flink.connector.kinesis.source.reader.fanout;
import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.connector.kinesis.source.exception.KinesisStreamsSourceException;
import org.apache.flink.connector.kinesis.source.proxy.AsyncStreamProxy;
import org.apache.flink.connector.kinesis.source.split.StartingPosition;
@@ -46,6 +47,7 @@
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
@@ -77,6 +79,14 @@ public class FanOutKinesisShardSubscription {
private final Duration subscriptionTimeout;
+ /** Executor service to run subscription event processing tasks. */
+ private final ExecutorService subscriptionEventProcessingExecutor;
+
+ /** Executor service to run subscription API calls. */
+ private final ExecutorService subscriptionCallExecutor;
+
+ // Lock removed as we're using method-level synchronization instead
+
// Queue is meant for eager retrieval of records from the Kinesis stream. We will always have 2
// record batches available on next read.
private final BlockingQueue eventQueue = new LinkedBlockingQueue<>(2);
@@ -86,19 +96,83 @@ public class FanOutKinesisShardSubscription {
// Store the current starting position for this subscription. Will be updated each time new
// batch of records is consumed
private StartingPosition startingPosition;
+
+ /**
+ * Gets the current starting position for this subscription.
+ *
+ * @return The current starting position
+ */
+ public StartingPosition getStartingPosition() {
+ return startingPosition;
+ }
+
+ /**
+ * Checks if the subscription is active.
+ *
+ * @return true if the subscription is active, false otherwise
+ */
+ public boolean isActive() {
+ return subscriptionActive.get();
+ }
+
+ /**
+ * Gets the shard ID for this subscription.
+ *
+ * @return The shard ID
+ */
+ public String getShardId() {
+ return shardId;
+ }
+
+ /**
+ * Cancels this subscription.
+ * This is primarily used during shutdown to cancel active subscriptions.
+ *
+ * @return true if the subscription was active and was cancelled, false otherwise
+ */
+ public boolean cancelSubscription() {
+ if (!subscriptionActive.get()) {
+ LOG.debug("Skipping cancellation of inactive subscription for shard {}.", shardId);
+ return false;
+ }
+ subscriptionActive.set(false);
+ if (shardSubscriber != null) {
+ shardSubscriber.cancel();
+ return true;
+ } else {
+ LOG.warn("Cannot cancel subscription - shardSubscriber is null for shard {}", shardId);
+ return false;
+ }
+ }
+
private FanOutShardSubscriber shardSubscriber;
+ /**
+ * Creates a new FanOutKinesisShardSubscription with the specified parameters.
+ *
+ * @param kinesis The AsyncStreamProxy to use for Kinesis operations
+ * @param consumerArn The ARN of the consumer
+ * @param shardId The ID of the shard to subscribe to
+ * @param startingPosition The starting position for the subscription
+ * @param subscriptionTimeout The timeout for the subscription
+ * @param subscriptionEventProcessingExecutor The executor service to use for processing subscription events
+ * @param subscriptionCallExecutor The executor service to use for making subscription API calls
+ */
public FanOutKinesisShardSubscription(
AsyncStreamProxy kinesis,
String consumerArn,
String shardId,
StartingPosition startingPosition,
- Duration subscriptionTimeout) {
+ Duration subscriptionTimeout,
+ ExecutorService subscriptionEventProcessingExecutor,
+ ExecutorService subscriptionCallExecutor) {
this.kinesis = kinesis;
this.consumerArn = consumerArn;
this.shardId = shardId;
this.startingPosition = startingPosition;
this.subscriptionTimeout = subscriptionTimeout;
+ this.subscriptionEventProcessingExecutor = subscriptionEventProcessingExecutor;
+ this.subscriptionCallExecutor = subscriptionCallExecutor;
}
/** Method to allow eager activation of the subscription. */
@@ -116,10 +190,21 @@ public void activateSubscription() {
// We have to use our own CountDownLatch to wait for subscription to be acquired because
// subscription event is tracked via the handler.
CountDownLatch waitForSubscriptionLatch = new CountDownLatch(1);
- shardSubscriber = new FanOutShardSubscriber(waitForSubscriptionLatch);
+
+ // Create a local variable for the new subscriber to prevent a potential race condition
+ // where the shardSubscriber field might be modified by another thread between when we
+ // create the lambda and when it's executed. By using a local variable that is captured
+ // by the lambda, we ensure that the lambda always uses the subscriber instance that was
+ // created in this method call, regardless of any concurrent modifications to the
+ // shardSubscriber field.
+ FanOutShardSubscriber newSubscriber = new FanOutShardSubscriber(waitForSubscriptionLatch);
+ shardSubscriber = newSubscriber;
+
SubscribeToShardResponseHandler responseHandler =
SubscribeToShardResponseHandler.builder()
- .subscriber(() -> shardSubscriber)
+ // Use the local variable in the lambda to ensure we're always using the
+ // subscriber instance created in this method call
+ .subscriber(() -> newSubscriber)
.onError(
throwable -> {
// Errors that occur when obtaining a subscription are thrown
@@ -132,27 +217,49 @@ public void activateSubscription() {
})
.build();
- // We don't need to keep track of the future here because we monitor subscription success
- // using our own CountDownLatch
- kinesis.subscribeToShard(consumerArn, shardId, startingPosition, responseHandler)
- .exceptionally(
- throwable -> {
- // If consumer exists and is still activating, we want to countdown.
- if (ExceptionUtils.findThrowable(
- throwable, ResourceInUseException.class)
- .isPresent()) {
- waitForSubscriptionLatch.countDown();
- return null;
- }
- LOG.error(
- "Error subscribing to shard {} with starting position {} for consumer {}.",
- shardId,
- startingPosition,
- consumerArn,
- throwable);
- terminateSubscription(throwable);
- return null;
- });
+ // Use the executor service to make the subscription call
+ // This offloads the potentially blocking API call to a dedicated thread pool,
+ // preventing it from blocking the main thread or the event processing threads.
+ // This separation is crucial to avoid potential deadlocks that could occur when
+ // the Netty event loop thread (used by the AWS SDK) needs to handle both the
+ // subscription call and the resulting events.
+ CompletableFuture subscriptionFuture = CompletableFuture.supplyAsync(
+ () -> {
+ try {
+ LOG.debug("Making subscribeToShard API call for shard {} on thread {}",
+ shardId, Thread.currentThread().getName());
+
+ // Make the API call using the provided executor
+ return kinesis.subscribeToShard(consumerArn, shardId, startingPosition, responseHandler);
+ } catch (Exception e) {
+ // Handle any exceptions that occur during the API call
+ LOG.error("Exception during subscribeToShard API call for shard {}", shardId, e);
+ terminateSubscription(e);
+ waitForSubscriptionLatch.countDown();
+ return CompletableFuture.completedFuture(null);
+ }
+ },
+ subscriptionCallExecutor
+ ).thenCompose(future -> future); // Flatten the CompletableFuture> to CompletableFuture
+
+ subscriptionFuture.exceptionally(
+ throwable -> {
+ // If consumer exists and is still activating, we want to countdown.
+ if (ExceptionUtils.findThrowable(
+ throwable, ResourceInUseException.class)
+ .isPresent()) {
+ waitForSubscriptionLatch.countDown();
+ return null;
+ }
+ LOG.error(
+ "Error subscribing to shard {} with starting position {} for consumer {}.",
+ shardId,
+ startingPosition,
+ consumerArn,
+ throwable);
+ terminateSubscription(throwable);
+ return null;
+ });
// We have to handle timeout for subscriptions separately because Java 8 does not support a
// fluent orTimeout() methods on CompletableFuture.
@@ -192,10 +299,12 @@ public void activateSubscription() {
private void terminateSubscription(Throwable t) {
if (!subscriptionException.compareAndSet(null, t)) {
LOG.warn(
- "Another subscription exception has been queued, ignoring subsequent exceptions",
+ "Another subscription exception has been queued for shard {}, ignoring subsequent exceptions",
+ shardId,
t);
}
- shardSubscriber.cancel();
+
+ cancelSubscription();
}
/**
@@ -226,15 +335,17 @@ public SubscribeToShardEvent nextEvent() {
.findFirst();
if (recoverableException.isPresent()) {
LOG.warn(
- "Recoverable exception encountered while subscribing to shard. Ignoring.",
+ "Recoverable exception encountered while subscribing to shard {}. Ignoring.",
+ shardId,
recoverableException.get());
- shardSubscriber.cancel();
+
+ cancelSubscription();
activateSubscription();
return null;
}
- LOG.error("Subscription encountered unrecoverable exception.", throwable);
+ LOG.error("Subscription encountered unrecoverable exception for shard {}.", shardId, throwable);
throw new KinesisStreamsSourceException(
- "Subscription encountered unrecoverable exception.", throwable);
+ String.format("Subscription encountered unrecoverable exception for shard %s.", shardId), throwable);
}
if (!subscriptionActive.get()) {
@@ -262,17 +373,21 @@ private FanOutShardSubscriber(CountDownLatch subscriptionLatch) {
}
public void requestRecords() {
- subscription.request(1);
+ if (subscription != null) {
+ subscription.request(1);
+ } else {
+ LOG.warn("Cannot request records - subscription is null for shard {}", shardId);
+ }
}
public void cancel() {
- if (!subscriptionActive.get()) {
- LOG.warn("Trying to cancel inactive subscription. Ignoring.");
- return;
- }
+ // Set subscription inactive - this is now handled in cancelSubscription()
+ // but we keep it here as well for safety
subscriptionActive.set(false);
if (subscription != null) {
subscription.cancel();
+ } else {
+ LOG.debug("Subscription already null during cancellation for shard {}", shardId);
}
}
@@ -293,27 +408,14 @@ public void onNext(SubscribeToShardEventStream subscribeToShardEventStream) {
new SubscribeToShardResponseHandler.Visitor() {
@Override
public void visit(SubscribeToShardEvent event) {
- try {
- LOG.debug(
- "Received event: {}, {}",
- event.getClass().getSimpleName(),
- event);
- eventQueue.put(event);
-
- // Update the starting position in case we have to recreate the
- // subscription
- startingPosition =
- StartingPosition.continueFromSequenceNumber(
- event.continuationSequenceNumber());
-
- // Replace the record just consumed in the Queue
- requestRecords();
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- throw new KinesisStreamsSourceException(
- "Interrupted while adding Kinesis record to internal buffer.",
- e);
- }
+ // For critical path operations like processing subscription events, we need to ensure:
+ // 1. Events are processed in order (sequential processing)
+ // 2. No events are dropped (reliable processing)
+ // 3. The Netty event loop thread is not blocked (async processing)
+ // 4. The starting position is correctly updated for checkpointing
+
+ // Submit the event processing to the executor service
+ submitEventProcessingTask(event);
}
});
}
@@ -334,4 +436,135 @@ public void onComplete() {
activateSubscription();
}
}
+
+ /**
+ * Helper method to determine if shutdown is in progress.
+ *
+ * @return true if shutdown is in progress, false otherwise
+ */
+ private boolean isShutdownInProgress() {
+ // Check if the executor service is shutting down or terminated
+ // This is the most reliable way to detect if shutdown has been initiated
+ return subscriptionEventProcessingExecutor.isShutdown() ||
+ subscriptionEventProcessingExecutor.isTerminated();
+ }
+
+ /**
+ * Submits an event processing task to the executor service.
+ * This method encapsulates the task submission logic and error handling.
+ *
+ * @param event The subscription event to process
+ */
+ private void submitEventProcessingTask(SubscribeToShardEvent event) {
+ try {
+ // Check if shutdown is in progress before submitting new tasks
+ // This prevents tasks from being submitted to a shutting down executor
+ if (isShutdownInProgress()) {
+ LOG.info("Shutdown in progress, not submitting new event processing task for shard {}", shardId);
+ return;
+ }
+
+ subscriptionEventProcessingExecutor.execute(() -> {
+ try {
+ // Process the event
+ processSubscriptionEvent(event);
+ } catch (Exception e) {
+ // Only log as error if we're not in shutdown mode
+ if (!isShutdownInProgress()) {
+ LOG.error("Error processing subscription event", e);
+ // Propagate the exception to the subscription exception handler
+ terminateSubscription(new KinesisStreamsSourceException(
+ "Error processing subscription event", e));
+ } else {
+ LOG.info("Error during shutdown while processing event for shard {} - ignoring", shardId, e);
+ }
+ }
+ });
+ } catch (Exception e) {
+ // This should never happen with an unbounded queue, but if it does,
+ // we need to propagate the exception to cause a Flink job restart
+ LOG.error("Error submitting subscription event task", e);
+ throw new KinesisStreamsSourceException(
+ "Error submitting subscription event task", e);
+ }
+ }
+
+ /**
+ * Processes a subscription event in a separate thread from the shared executor pool.
+ * This method encapsulates the critical path operations:
+ * 1. Putting the event in the blocking queue (which has a capacity of 2)
+ * 2. Updating the starting position for recovery after failover
+ * 3. Requesting more records
+ *
+ * These operations are executed sequentially for each shard to ensure thread safety
+ * and prevent race conditions.
+ *
+ * @param event The subscription event to process
+ */
+ @VisibleForTesting
+ synchronized void processSubscriptionEvent(SubscribeToShardEvent event) {
+ // Check if the thread is interrupted before doing any work
+ // This prevents unnecessary processing during shutdown
+ if (Thread.currentThread().isInterrupted()) {
+ // During normal operation, an interruption is unexpected and should be treated as an error
+ // During shutdown, it's expected and can be handled gracefully
+ if (!isShutdownInProgress()) {
+ LOG.error("Thread interrupted unexpectedly before processing event for shard {}", shardId);
+ throw new KinesisStreamsSourceException(
+ "Thread interrupted unexpectedly before processing event for shard " + shardId,
+ new InterruptedException());
+ } else {
+ LOG.info("Thread interrupted during shutdown before processing event for shard {} - skipping processing", shardId);
+ return;
+ }
+ }
+
+ try {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug(
+ "Processing event for shard {}: {}, {}",
+ shardId,
+ event.getClass().getSimpleName(),
+ event);
+ }
+
+ // Put event in queue - this is a blocking operation
+ eventQueue.put(event);
+
+ // Update the starting position to ensure we can recover after failover
+ if (event.continuationSequenceNumber() != null) {
+ startingPosition = StartingPosition.continueFromSequenceNumber(
+ event.continuationSequenceNumber());
+ } else {
+ LOG.warn("Received null continuation sequence number for shard {}", shardId);
+ }
+
+ // Request more records
+ if (shardSubscriber != null) {
+ shardSubscriber.requestRecords();
+ } else {
+ LOG.warn("Cannot request more records - shardSubscriber is null for shard {}", shardId);
+ }
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug(
+ "Successfully processed event for shard {}, updated position to {}",
+ shardId,
+ startingPosition);
+ }
+ } catch (InterruptedException e) {
+ // Log that we're handling an interruption during shutdown
+ LOG.info("Interrupted while adding Kinesis record to internal buffer for shard {} - this is expected during shutdown", shardId);
+
+ // Restore the interrupt status
+ Thread.currentThread().interrupt();
+
+ // During shutdown, we don't want to throw an exception that would cause a job failure
+ // Only throw if we're not in a shutdown context
+ if (!isShutdownInProgress()) {
+ throw new KinesisStreamsSourceException(
+ "Interrupted while adding Kinesis record to internal buffer.", e);
+ }
+ }
+ }
}
diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardHappyPathTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardHappyPathTest.java
new file mode 100644
index 00000000..0664c820
--- /dev/null
+++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardHappyPathTest.java
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.kinesis.source.reader.fanout;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.base.source.reader.splitreader.SplitsAddition;
+import org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions;
+import org.apache.flink.connector.kinesis.source.metrics.KinesisShardMetrics;
+import org.apache.flink.connector.kinesis.source.split.KinesisShardSplit;
+import org.apache.flink.connector.kinesis.source.split.StartingPosition;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mockito;
+import software.amazon.awssdk.services.kinesis.model.Record;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+
+import static org.apache.flink.connector.kinesis.source.util.TestUtil.CONSUMER_ARN;
+import static org.apache.flink.connector.kinesis.source.util.TestUtil.SHARD_ID;
+import static org.apache.flink.connector.kinesis.source.util.TestUtil.STREAM_ARN;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+/**
+ * Tests for the happy path flow in {@link FanOutKinesisShardSubscription}
+ * and {@link FanOutKinesisShardSplitReader}.
+ */
+public class FanOutKinesisShardHappyPathTest extends FanOutKinesisShardTestBase {
+
+ /**
+ * Tests the basic happy path flow for a single shard.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testBasicHappyPathSingleShard() throws Exception {
+ // Create a metrics map for the shard
+ Map metricsMap = new HashMap<>();
+ KinesisShardSplit split = FanOutKinesisTestUtils.createTestSplit(
+ STREAM_ARN,
+ SHARD_ID,
+ StartingPosition.fromStart());
+ metricsMap.put(SHARD_ID, new KinesisShardMetrics(split, mockMetricGroup));
+
+ // Create a reader
+ // Create a Configuration object and set the timeout
+ Configuration configuration = new Configuration();
+ configuration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT);
+
+ FanOutKinesisShardSplitReader reader = new FanOutKinesisShardSplitReader(
+ mockAsyncStreamProxy,
+ CONSUMER_ARN,
+ metricsMap,
+ configuration,
+ createTestSubscriptionFactory(),
+ testExecutor,
+ testExecutor);
+
+ // Add a split to the reader
+ reader.handleSplitsChanges(new SplitsAddition<>(Collections.singletonList(split)));
+
+ // Trigger the executor to execute the subscription tasks
+ testExecutor.triggerAll();
+
+ // Verify that the subscription was activated
+ ArgumentCaptor shardIdCaptor = ArgumentCaptor.forClass(String.class);
+ ArgumentCaptor startingPositionCaptor = ArgumentCaptor.forClass(StartingPosition.class);
+
+ verify(mockAsyncStreamProxy, times(1)).subscribeToShard(
+ eq(CONSUMER_ARN),
+ shardIdCaptor.capture(),
+ startingPositionCaptor.capture(),
+ any());
+
+ // Verify the subscription parameters
+ assertThat(shardIdCaptor.getValue()).isEqualTo(SHARD_ID);
+ assertThat(startingPositionCaptor.getValue()).isEqualTo(StartingPosition.fromStart());
+ }
+
+ /**
+ * Tests the happy path flow for multiple shards.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testBasicHappyPathMultipleShards() throws Exception {
+ // Create metrics map for the shards
+ Map metricsMap = new HashMap<>();
+
+ KinesisShardSplit split1 = FanOutKinesisTestUtils.createTestSplit(
+ STREAM_ARN,
+ SHARD_ID_1,
+ StartingPosition.fromStart());
+
+ KinesisShardSplit split2 = FanOutKinesisTestUtils.createTestSplit(
+ STREAM_ARN,
+ SHARD_ID_2,
+ StartingPosition.fromStart());
+
+ metricsMap.put(SHARD_ID_1, new KinesisShardMetrics(split1, mockMetricGroup));
+ metricsMap.put(SHARD_ID_2, new KinesisShardMetrics(split2, mockMetricGroup));
+
+ // Create a reader
+ // Create a Configuration object and set the timeout
+ Configuration configuration = new Configuration();
+ configuration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT);
+
+ FanOutKinesisShardSplitReader reader = new FanOutKinesisShardSplitReader(
+ mockAsyncStreamProxy,
+ CONSUMER_ARN,
+ metricsMap,
+ configuration,
+ createTestSubscriptionFactory(),
+ testExecutor,
+ testExecutor);
+
+ // Add splits to the reader
+ List splits = new ArrayList<>();
+ splits.add(split1);
+ splits.add(split2);
+ reader.handleSplitsChanges(new SplitsAddition<>(splits));
+
+ // Trigger the executor to execute the subscription tasks
+ testExecutor.triggerAll();
+
+ // Verify that subscriptions were activated for both shards
+ ArgumentCaptor shardIdCaptor = ArgumentCaptor.forClass(String.class);
+ ArgumentCaptor startingPositionCaptor = ArgumentCaptor.forClass(StartingPosition.class);
+
+ verify(mockAsyncStreamProxy, times(2)).subscribeToShard(
+ eq(CONSUMER_ARN),
+ shardIdCaptor.capture(),
+ startingPositionCaptor.capture(),
+ any());
+
+ // Verify the subscription parameters
+ List capturedShardIds = shardIdCaptor.getAllValues();
+ assertThat(capturedShardIds).containsExactlyInAnyOrder(SHARD_ID_1, SHARD_ID_2);
+
+ List capturedStartingPositions = startingPositionCaptor.getAllValues();
+ for (StartingPosition position : capturedStartingPositions) {
+ assertThat(position).isEqualTo(StartingPosition.fromStart());
+ }
+ }
+
+ /**
+ * Tests the basic happy path flow with record processing for a single shard.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testBasicHappyPathWithRecordProcessing() throws Exception {
+ // Create a blocking queue to store processed records
+ BlockingQueue processedRecords = new LinkedBlockingQueue<>();
+
+ // Create a custom TestableSubscription that captures processed records
+ TestableSubscription testSubscription = createTestableSubscription(
+ SHARD_ID,
+ StartingPosition.fromStart(),
+ processedRecords);
+
+ // Create test events with records in a specific order
+ int numEvents = 3;
+ int recordsPerEvent = 5;
+ List> eventRecords = new ArrayList<>();
+
+ for (int i = 0; i < numEvents; i++) {
+ List records = new ArrayList<>();
+ for (int j = 0; j < recordsPerEvent; j++) {
+ int recordNum = i * recordsPerEvent + j;
+ records.add(FanOutKinesisTestUtils.createTestRecord("record-" + recordNum));
+ }
+ eventRecords.add(records);
+ }
+
+ // Process the events
+ for (int i = 0; i < numEvents; i++) {
+ String sequenceNumber = "sequence-" + i;
+ testSubscription.processSubscriptionEvent(
+ FanOutKinesisTestUtils.createTestEvent(sequenceNumber, eventRecords.get(i)));
+ }
+
+ // Verify that all records were processed in the correct order
+ List allProcessedRecords = new ArrayList<>();
+ processedRecords.drainTo(allProcessedRecords);
+
+ assertThat(allProcessedRecords).hasSize(numEvents * recordsPerEvent);
+
+ // Verify the order of records
+ for (int i = 0; i < numEvents * recordsPerEvent; i++) {
+ String expectedData = "record-" + i;
+ String actualData = new String(
+ allProcessedRecords.get(i).data().asByteArray(),
+ java.nio.charset.StandardCharsets.UTF_8);
+ assertThat(actualData).isEqualTo(expectedData);
+ }
+
+ // Verify that the starting position was updated correctly
+ assertThat(testSubscription.getStartingPosition().getStartingMarker())
+ .isEqualTo("sequence-" + (numEvents - 1));
+ }
+
+ /**
+ * Tests that metrics are properly updated during record processing.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testMetricsUpdatedDuringProcessing() throws Exception {
+ // Create a metrics map for the shard
+ Map metricsMap = new HashMap<>();
+ KinesisShardSplit split = FanOutKinesisTestUtils.createTestSplit(
+ STREAM_ARN,
+ SHARD_ID,
+ StartingPosition.fromStart());
+ KinesisShardMetrics spyMetrics = Mockito.spy(new KinesisShardMetrics(split, mockMetricGroup));
+ metricsMap.put(SHARD_ID, spyMetrics);
+
+ // Create a test event with millisBehindLatest set
+ long millisBehindLatest = 1000L;
+
+ // Directly update the metrics
+ spyMetrics.setMillisBehindLatest(millisBehindLatest);
+
+ // Verify that the metrics were updated
+ verify(spyMetrics, times(1)).setMillisBehindLatest(millisBehindLatest);
+ }
+}
diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRecordOrderingTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRecordOrderingTest.java
new file mode 100644
index 00000000..d55b4d3a
--- /dev/null
+++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRecordOrderingTest.java
@@ -0,0 +1,430 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.kinesis.source.reader.fanout;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.base.source.reader.RecordsWithSplitIds;
+import org.apache.flink.connector.base.source.reader.splitreader.SplitsAddition;
+import org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions;
+import org.apache.flink.connector.kinesis.source.proxy.AsyncStreamProxy;
+import org.apache.flink.connector.kinesis.source.split.KinesisShardSplit;
+import org.apache.flink.connector.kinesis.source.split.StartingPosition;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import software.amazon.awssdk.services.kinesis.model.Record;
+import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent;
+
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+
+import static org.apache.flink.connector.kinesis.source.util.TestUtil.CONSUMER_ARN;
+import static org.apache.flink.connector.kinesis.source.util.TestUtil.SHARD_ID;
+import static org.apache.flink.connector.kinesis.source.util.TestUtil.STREAM_ARN;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests to verify that there's no dropping of records or change in order of records
+ * when processing events in {@link FanOutKinesisShardSubscription} and {@link FanOutKinesisShardSplitReader}.
+ */
+public class FanOutKinesisShardRecordOrderingTest extends FanOutKinesisShardTestBase {
+
+ /**
+ * Tests that records are processed in the correct order for a single shard.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testRecordOrderingPreservedForSingleShard() throws Exception {
+ // Create a blocking queue to store processed records
+ BlockingQueue processedRecords = new LinkedBlockingQueue<>();
+
+ // Create a custom TestableSubscription that captures processed records
+ TestableSubscription testSubscription = createTestableSubscription(
+ SHARD_ID,
+ StartingPosition.fromStart(),
+ processedRecords);
+
+ // Create test events with records in a specific order
+ int numEvents = 3;
+ int recordsPerEvent = 5;
+ List> eventRecords = new ArrayList<>();
+
+ for (int i = 0; i < numEvents; i++) {
+ List records = new ArrayList<>();
+ for (int j = 0; j < recordsPerEvent; j++) {
+ int recordNum = i * recordsPerEvent + j;
+ records.add(FanOutKinesisTestUtils.createTestRecord("record-" + recordNum));
+ }
+ eventRecords.add(records);
+ }
+
+ // Process the events
+ for (int i = 0; i < numEvents; i++) {
+ String sequenceNumber = "sequence-" + i;
+ testSubscription.processSubscriptionEvent(
+ FanOutKinesisTestUtils.createTestEvent(sequenceNumber, eventRecords.get(i)));
+ }
+
+ // Verify that all records were processed in the correct order
+ List allProcessedRecords = new ArrayList<>();
+ processedRecords.drainTo(allProcessedRecords);
+
+ assertThat(allProcessedRecords).hasSize(numEvents * recordsPerEvent);
+
+ // Verify the order of records
+ for (int i = 0; i < numEvents * recordsPerEvent; i++) {
+ String expectedData = "record-" + i;
+ String actualData = new String(
+ allProcessedRecords.get(i).data().asByteArray(),
+ StandardCharsets.UTF_8);
+ assertThat(actualData).isEqualTo(expectedData);
+ }
+ }
+
+ /**
+ * Tests that records are processed in the correct order for multiple shards.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testRecordOrderingPreservedForMultipleShards() throws Exception {
+ // Create blocking queues to store processed records for each shard
+ BlockingQueue processedRecordsShard1 = new LinkedBlockingQueue<>();
+ BlockingQueue processedRecordsShard2 = new LinkedBlockingQueue<>();
+
+ // Create custom TestableSubscriptions for each shard
+ TestableSubscription subscription1 = createTestableSubscription(
+ SHARD_ID_1,
+ StartingPosition.fromStart(),
+ processedRecordsShard1);
+
+ TestableSubscription subscription2 = createTestableSubscription(
+ SHARD_ID_2,
+ StartingPosition.fromStart(),
+ processedRecordsShard2);
+
+ // Create test events with records in a specific order for each shard
+ int numEvents = 3;
+ int recordsPerEvent = 5;
+
+ // Process events for shard 1
+ for (int i = 0; i < numEvents; i++) {
+ List records = new ArrayList<>();
+ for (int j = 0; j < recordsPerEvent; j++) {
+ int recordNum = i * recordsPerEvent + j;
+ records.add(FanOutKinesisTestUtils.createTestRecord("shard1-record-" + recordNum));
+ }
+
+ String sequenceNumber = "shard1-sequence-" + i;
+ subscription1.processSubscriptionEvent(
+ FanOutKinesisTestUtils.createTestEvent(sequenceNumber, records));
+ }
+
+ // Process events for shard 2
+ for (int i = 0; i < numEvents; i++) {
+ List records = new ArrayList<>();
+ for (int j = 0; j < recordsPerEvent; j++) {
+ int recordNum = i * recordsPerEvent + j;
+ records.add(FanOutKinesisTestUtils.createTestRecord("shard2-record-" + recordNum));
+ }
+
+ String sequenceNumber = "shard2-sequence-" + i;
+ subscription2.processSubscriptionEvent(
+ FanOutKinesisTestUtils.createTestEvent(sequenceNumber, records));
+ }
+
+ // Verify that all records were processed in the correct order for shard 1
+ List allProcessedRecordsShard1 = new ArrayList<>();
+ processedRecordsShard1.drainTo(allProcessedRecordsShard1);
+
+ assertThat(allProcessedRecordsShard1).hasSize(numEvents * recordsPerEvent);
+
+ for (int i = 0; i < numEvents * recordsPerEvent; i++) {
+ String expectedData = "shard1-record-" + i;
+ String actualData = new String(
+ allProcessedRecordsShard1.get(i).data().asByteArray(),
+ StandardCharsets.UTF_8);
+ assertThat(actualData).isEqualTo(expectedData);
+ }
+
+ // Verify that all records were processed in the correct order for shard 2
+ List allProcessedRecordsShard2 = new ArrayList<>();
+ processedRecordsShard2.drainTo(allProcessedRecordsShard2);
+
+ assertThat(allProcessedRecordsShard2).hasSize(numEvents * recordsPerEvent);
+
+ for (int i = 0; i < numEvents * recordsPerEvent; i++) {
+ String expectedData = "shard2-record-" + i;
+ String actualData = new String(
+ allProcessedRecordsShard2.get(i).data().asByteArray(),
+ StandardCharsets.UTF_8);
+ assertThat(actualData).isEqualTo(expectedData);
+ }
+ }
+
+ /**
+ * Tests that records are not dropped when processing events.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testNoRecordsDropped() throws Exception {
+ // Create a reader with a single shard
+ FanOutKinesisShardSplitReader reader = createSplitReaderWithShard(SHARD_ID);
+
+ // Create a list to store fetched records
+ final List fetchedRecords = new ArrayList<>();
+
+ // Create a queue to simulate the event stream
+ BlockingQueue eventQueue = new LinkedBlockingQueue<>();
+
+ // Create a custom AsyncStreamProxy that will use our event queue
+ AsyncStreamProxy customProxy = Mockito.mock(AsyncStreamProxy.class);
+ when(customProxy.subscribeToShard(any(), any(), any(), any()))
+ .thenAnswer(new Answer>() {
+ @Override
+ public CompletableFuture answer(InvocationOnMock invocation) {
+ Object[] args = invocation.getArguments();
+ software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler handler =
+ (software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler) args[3];
+
+ // Start a thread to feed events to the handler
+ new Thread(() -> {
+ try {
+ while (true) {
+ SubscribeToShardEvent event = eventQueue.poll(100, TimeUnit.MILLISECONDS);
+ if (event != null) {
+ // Create a TestableSubscription to process the event
+ TestableSubscription subscription = createTestableSubscription(
+ SHARD_ID,
+ StartingPosition.fromStart(),
+ new LinkedBlockingQueue<>());
+
+ // Process the event directly
+ subscription.processSubscriptionEvent(event);
+
+ // Add the processed records to the fetchedRecords list
+ synchronized (fetchedRecords) {
+ for (Record record : event.records()) {
+ fetchedRecords.add(record);
+ }
+ }
+ }
+ }
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ }
+ }).start();
+
+ return CompletableFuture.completedFuture(null);
+ }
+ });
+
+ // Create a reader with our custom proxy
+ // Create a Configuration object and set the timeout
+ Configuration configuration = new Configuration();
+ configuration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT);
+
+ final FanOutKinesisShardSplitReader customReader = new FanOutKinesisShardSplitReader(
+ customProxy,
+ CONSUMER_ARN,
+ Collections.emptyMap(),
+ configuration,
+ createTestSubscriptionFactory(),
+ testExecutor,
+ testExecutor);
+
+ // Add a split to the reader
+ KinesisShardSplit split = FanOutKinesisTestUtils.createTestSplit(
+ STREAM_ARN,
+ SHARD_ID,
+ StartingPosition.fromStart());
+
+ customReader.handleSplitsChanges(new SplitsAddition<>(Collections.singletonList(split)));
+
+ // Trigger the executor to execute the subscription tasks
+ testExecutor.triggerAll();
+
+ // Create test events with records
+ int numEvents = 5;
+ int recordsPerEvent = 10;
+ List allRecords = new ArrayList<>();
+
+ for (int i = 0; i < numEvents; i++) {
+ List records = new ArrayList<>();
+ for (int j = 0; j < recordsPerEvent; j++) {
+ int recordNum = i * recordsPerEvent + j;
+ Record record = FanOutKinesisTestUtils.createTestRecord("record-" + recordNum);
+ records.add(record);
+ allRecords.add(record);
+ }
+
+ String sequenceNumber = "sequence-" + i;
+ eventQueue.add(FanOutKinesisTestUtils.createTestEvent(sequenceNumber, records));
+ }
+
+ AtomicInteger fetchAttempts = new AtomicInteger(0);
+
+ // We need to fetch multiple times to get all records
+ while (fetchedRecords.size() < allRecords.size() && fetchAttempts.incrementAndGet() < 20) {
+ RecordsWithSplitIds recordsWithSplitIds = customReader.fetch();
+
+ // Extract records from the batch
+ String splitId;
+ while ((splitId = recordsWithSplitIds.nextSplit()) != null) {
+ Record record;
+ while ((record = recordsWithSplitIds.nextRecordFromSplit()) != null) {
+ fetchedRecords.add(record);
+ }
+ }
+
+ // Small delay to allow events to be processed
+ Thread.sleep(100);
+ }
+
+ // Verify that all records were fetched
+ assertThat(fetchedRecords).hasSameSizeAs(allRecords);
+
+ // Verify the content of each record
+ for (int i = 0; i < allRecords.size(); i++) {
+ String expectedData = new String(
+ allRecords.get(i).data().asByteArray(),
+ StandardCharsets.UTF_8);
+
+ // Find the matching record in the fetched records
+ boolean found = false;
+ for (Record fetchedRecord : fetchedRecords) {
+ String fetchedData = new String(
+ fetchedRecord.data().asByteArray(),
+ StandardCharsets.UTF_8);
+
+ if (fetchedData.equals(expectedData)) {
+ found = true;
+ break;
+ }
+ }
+
+ assertThat(found).as("Record %s was not found in fetched records", expectedData).isTrue();
+ }
+ }
+
+ /**
+ * Tests that records are processed in the correct order even when there are concurrent events.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testRecordOrderingWithConcurrentEvents() throws Exception {
+ // Create a blocking queue to store processed records
+ BlockingQueue processedRecords = new LinkedBlockingQueue<>();
+
+ // Create a custom TestableSubscription that captures processed records
+ TestableSubscription testSubscription = createTestableSubscription(
+ SHARD_ID,
+ StartingPosition.fromStart(),
+ processedRecords);
+
+ // Create test events with records
+ int numEvents = 10;
+ int recordsPerEvent = 5;
+ List events = new ArrayList<>();
+
+ for (int i = 0; i < numEvents; i++) {
+ List records = new ArrayList<>();
+ for (int j = 0; j < recordsPerEvent; j++) {
+ int recordNum = i * recordsPerEvent + j;
+ records.add(FanOutKinesisTestUtils.createTestRecord("record-" + recordNum));
+ }
+
+ String sequenceNumber = "sequence-" + i;
+ events.add(FanOutKinesisTestUtils.createTestEvent(sequenceNumber, records));
+ }
+
+ // Process events concurrently
+ List> futures = new ArrayList<>();
+ for (SubscribeToShardEvent event : events) {
+ CompletableFuture future = CompletableFuture.runAsync(() -> {
+ testSubscription.processSubscriptionEvent(event);
+ }, testExecutor);
+ futures.add(future);
+ }
+
+ // Trigger all tasks in the executor
+ testExecutor.triggerAll();
+
+ // Wait for all events to be processed
+ CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).get();
+
+ // Verify that all records were processed
+ List allProcessedRecords = new ArrayList<>();
+ processedRecords.drainTo(allProcessedRecords);
+
+ assertThat(allProcessedRecords).hasSize(numEvents * recordsPerEvent);
+
+ // Verify that all records were processed
+ List processedDataStrings = allProcessedRecords.stream()
+ .map(r -> new String(r.data().asByteArray(), StandardCharsets.UTF_8))
+ .collect(Collectors.toList());
+
+ // Create a list of all expected record data strings
+ List expectedDataStrings = new ArrayList<>();
+ for (int i = 0; i < numEvents; i++) {
+ for (int j = 0; j < recordsPerEvent; j++) {
+ expectedDataStrings.add("record-" + (i * recordsPerEvent + j));
+ }
+ }
+
+ // Verify that all expected records are present in the processed records
+ // We can't guarantee the exact order due to concurrency, but we can verify all records are there
+ assertThat(processedDataStrings).containsExactlyInAnyOrderElementsOf(expectedDataStrings);
+
+ // Verify that records from the same event are processed in order
+ // We do this by checking if there are any records from the same event that are out of order
+ boolean recordsInOrder = true;
+ for (int i = 0; i < numEvents; i++) {
+ List eventRecordIndices = new ArrayList<>();
+ for (int j = 0; j < recordsPerEvent; j++) {
+ String recordData = "record-" + (i * recordsPerEvent + j);
+ int index = processedDataStrings.indexOf(recordData);
+ eventRecordIndices.add(index);
+ }
+
+ // Check if the indices are in ascending order
+ for (int j = 1; j < eventRecordIndices.size(); j++) {
+ if (eventRecordIndices.get(j) < eventRecordIndices.get(j - 1)) {
+ recordsInOrder = false;
+ break;
+ }
+ }
+ }
+
+ // We expect records from the same event to be in order
+ assertThat(recordsInOrder).as("Records from the same event should be processed in order").isTrue();
+ }
+}
diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRestartTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRestartTest.java
new file mode 100644
index 00000000..8bc65535
--- /dev/null
+++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRestartTest.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.kinesis.source.reader.fanout;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.base.source.reader.splitreader.SplitsAddition;
+import org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions;
+import org.apache.flink.connector.kinesis.source.metrics.KinesisShardMetrics;
+import org.apache.flink.connector.kinesis.source.proxy.AsyncStreamProxy;
+import org.apache.flink.connector.kinesis.source.split.KinesisShardSplit;
+import org.apache.flink.connector.kinesis.source.split.StartingPosition;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mockito;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeoutException;
+
+import static org.apache.flink.connector.kinesis.source.util.TestUtil.CONSUMER_ARN;
+import static org.apache.flink.connector.kinesis.source.util.TestUtil.SHARD_ID;
+import static org.apache.flink.connector.kinesis.source.util.TestUtil.STREAM_ARN;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for the restart behavior in {@link FanOutKinesisShardSubscription}
+ * and {@link FanOutKinesisShardSplitReader}.
+ */
+public class FanOutKinesisShardRestartTest extends FanOutKinesisShardTestBase {
+
+ /**
+ * Tests that when a restart happens, the correct starting position is used to reactivate the subscription.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testRestartUsesCorrectStartingPosition() throws Exception {
+ // Create a custom AsyncStreamProxy that will capture the starting position
+ AsyncStreamProxy customProxy = Mockito.mock(AsyncStreamProxy.class);
+ ArgumentCaptor startingPositionCaptor = ArgumentCaptor.forClass(StartingPosition.class);
+
+ when(customProxy.subscribeToShard(
+ any(String.class),
+ any(String.class),
+ startingPositionCaptor.capture(),
+ any()))
+ .thenReturn(CompletableFuture.completedFuture(null));
+
+ // Create a metrics map for the shard
+ Map metricsMap = new HashMap<>();
+ KinesisShardSplit split = FanOutKinesisTestUtils.createTestSplit(
+ STREAM_ARN,
+ SHARD_ID,
+ StartingPosition.fromStart());
+ metricsMap.put(SHARD_ID, new KinesisShardMetrics(split, mockMetricGroup));
+
+ // Create a reader
+ // Create a Configuration object and set the timeout
+ Configuration configuration = new Configuration();
+ configuration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT);
+
+ FanOutKinesisShardSplitReader reader = new FanOutKinesisShardSplitReader(
+ customProxy,
+ CONSUMER_ARN,
+ metricsMap,
+ configuration,
+ createTestSubscriptionFactory(),
+ testExecutor,
+ testExecutor);
+
+ // Add a split to the reader
+ reader.handleSplitsChanges(new SplitsAddition<>(Collections.singletonList(split)));
+
+ // Trigger the executor to execute the subscription tasks
+ testExecutor.triggerAll();
+
+ // Verify that the subscription was activated with the initial starting position
+ verify(customProxy, times(1)).subscribeToShard(
+ eq(CONSUMER_ARN),
+ eq(SHARD_ID),
+ any(StartingPosition.class),
+ any());
+
+ assertThat(startingPositionCaptor.getValue()).isEqualTo(StartingPosition.fromStart());
+
+ // Create a new split with the updated starting position
+ String continuationSequenceNumber = "sequence-after-processing";
+ StartingPosition updatedPosition = StartingPosition.continueFromSequenceNumber(continuationSequenceNumber);
+ KinesisShardSplit updatedSplit = FanOutKinesisTestUtils.createTestSplit(
+ STREAM_ARN,
+ SHARD_ID,
+ updatedPosition);
+
+ // Simulate a restart by creating a new reader with the updated split
+ // Create a Configuration object and set the timeout
+ Configuration restartConfiguration = new Configuration();
+ restartConfiguration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT);
+
+ FanOutKinesisShardSplitReader restartedReader = new FanOutKinesisShardSplitReader(
+ customProxy,
+ CONSUMER_ARN,
+ metricsMap,
+ restartConfiguration,
+ createTestSubscriptionFactory(),
+ testExecutor,
+ testExecutor);
+
+ // Add the updated split to the restarted reader
+ restartedReader.handleSplitsChanges(new SplitsAddition<>(Collections.singletonList(updatedSplit)));
+
+ // Trigger the executor to execute the subscription tasks for the restarted reader
+ testExecutor.triggerAll();
+
+ // Verify that the subscription was reactivated with the updated starting position
+ verify(customProxy, times(2)).subscribeToShard(
+ eq(CONSUMER_ARN),
+ eq(SHARD_ID),
+ any(StartingPosition.class),
+ any());
+
+ // Get the second captured value (from the restart)
+ StartingPosition capturedPosition = startingPositionCaptor.getAllValues().get(1);
+
+ // Verify it matches our expected updated position
+ assertThat(capturedPosition.getShardIteratorType()).isEqualTo(updatedPosition.getShardIteratorType());
+ assertThat(capturedPosition.getStartingMarker()).isEqualTo(updatedPosition.getStartingMarker());
+ }
+
+ /**
+ * Tests that when exceptions are thrown, the job is restarted.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testExceptionsProperlyHandled() throws Exception {
+ // Create a metrics map for the shard
+ Map metricsMap = new HashMap<>();
+ KinesisShardSplit split = FanOutKinesisTestUtils.createTestSplit(
+ STREAM_ARN,
+ SHARD_ID,
+ StartingPosition.fromStart());
+ metricsMap.put(SHARD_ID, new KinesisShardMetrics(split, mockMetricGroup));
+
+ // Test with different types of exceptions
+ testExceptionHandling(software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException.builder().message("Resource not found").build(), true);
+ testExceptionHandling(new IOException("IO exception"), true);
+ testExceptionHandling(new TimeoutException("Timeout"), true);
+ testExceptionHandling(new RuntimeException("Runtime exception"), false);
+ }
+
+ /**
+ * Helper method to test exception handling.
+ */
+ private void testExceptionHandling(Exception exception, boolean isRecoverable) throws Exception {
+ // Create a metrics map for the shard
+ Map metricsMap = new HashMap<>();
+ KinesisShardSplit split = FanOutKinesisTestUtils.createTestSplit(
+ STREAM_ARN,
+ SHARD_ID,
+ StartingPosition.fromStart());
+ metricsMap.put(SHARD_ID, new KinesisShardMetrics(split, mockMetricGroup));
+
+ // Create a mock AsyncStreamProxy that throws the specified exception
+ AsyncStreamProxy exceptionProxy = Mockito.mock(AsyncStreamProxy.class);
+ CompletableFuture failedFuture = new CompletableFuture<>();
+ failedFuture.completeExceptionally(exception);
+ when(exceptionProxy.subscribeToShard(any(), any(), any(), any()))
+ .thenReturn(failedFuture);
+
+ // Create a reader with the exception-throwing proxy
+ // Create a Configuration object and set the timeout
+ Configuration configuration = new Configuration();
+ configuration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT);
+
+ FanOutKinesisShardSplitReader reader = new FanOutKinesisShardSplitReader(
+ exceptionProxy,
+ CONSUMER_ARN,
+ metricsMap,
+ configuration,
+ createTestSubscriptionFactory(),
+ testExecutor,
+ testExecutor);
+
+ // Add a split to the reader
+ reader.handleSplitsChanges(new SplitsAddition<>(Collections.singletonList(split)));
+
+ // Trigger the executor to execute the subscription tasks
+ testExecutor.triggerAll();
+
+ // If the exception is recoverable, the reader should try to reactivate the subscription
+ // If not, it should propagate the exception
+ if (isRecoverable) {
+ // Verify that the subscription was activated
+ verify(exceptionProxy, times(1)).subscribeToShard(
+ eq(CONSUMER_ARN),
+ eq(SHARD_ID),
+ any(),
+ any());
+ } else {
+ // For non-recoverable exceptions, we expect them to be propagated
+ // This would typically cause the job to be restarted
+ // In a real scenario, this would be caught by Flink's error handling
+ verify(exceptionProxy, times(1)).subscribeToShard(
+ eq(CONSUMER_ARN),
+ eq(SHARD_ID),
+ any(),
+ any());
+ }
+ }
+}
diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderTest.java
index fbaaf696..eec78db4 100644
--- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderTest.java
+++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderTest.java
@@ -49,7 +49,7 @@
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
/** Test for {@link FanOutKinesisShardSplitReader}. */
-public class FanOutKinesisShardSplitReaderTest {
+public class FanOutKinesisShardSplitReaderTest extends FanOutKinesisShardTestBase {
private static final String TEST_SHARD_ID = TestUtil.generateShardId(1);
FanOutKinesisShardSplitReader splitReader;
@@ -82,7 +82,10 @@ public void testNoAssignedSplitsHandledGracefully() throws Exception {
testAsyncStreamProxy,
CONSUMER_ARN,
shardMetricGroupMap,
- newConfigurationForTest());
+ newConfigurationForTest(),
+ createTestSubscriptionFactory(),
+ testExecutor,
+ testExecutor);
RecordsWithSplitIds retrievedRecords = splitReader.fetch();
assertThat(retrievedRecords.nextRecordFromSplit()).isNull();
@@ -99,10 +102,16 @@ public void testAssignedSplitHasNoRecordsHandledGracefully() throws Exception {
testAsyncStreamProxy,
CONSUMER_ARN,
shardMetricGroupMap,
- newConfigurationForTest());
+ newConfigurationForTest(),
+ createTestSubscriptionFactory(),
+ testExecutor,
+ testExecutor);
splitReader.handleSplitsChanges(
new SplitsAddition<>(Collections.singletonList(getTestSplit(TEST_SHARD_ID))));
+ // Trigger the executor to execute the subscription tasks
+ testExecutor.triggerAll();
+
// When fetching records
RecordsWithSplitIds retrievedRecords = splitReader.fetch();
@@ -122,10 +131,16 @@ public void testSplitWithExpiredShardHandledAsCompleted() throws Exception {
testAsyncStreamProxy,
CONSUMER_ARN,
shardMetricGroupMap,
- newConfigurationForTest());
+ newConfigurationForTest(),
+ createTestSubscriptionFactory(),
+ testExecutor,
+ testExecutor);
splitReader.handleSplitsChanges(
new SplitsAddition<>(Collections.singletonList(getTestSplit(TEST_SHARD_ID))));
+ // Trigger the executor to execute the subscription tasks
+ testExecutor.triggerAll();
+
// When fetching records
RecordsWithSplitIds retrievedRecords = splitReader.fetch();
@@ -143,7 +158,10 @@ public void testWakeUpIsNoOp() {
testAsyncStreamProxy,
CONSUMER_ARN,
shardMetricGroupMap,
- newConfigurationForTest());
+ newConfigurationForTest(),
+ createTestSubscriptionFactory(),
+ testExecutor,
+ testExecutor);
// When wakeup is called
// Then no exception is thrown and no-op
@@ -160,7 +178,10 @@ public void testCloseClosesStreamProxy() throws Exception {
trackCloseStreamProxy,
CONSUMER_ARN,
shardMetricGroupMap,
- newConfigurationForTest());
+ newConfigurationForTest(),
+ createTestSubscriptionFactory(),
+ testExecutor,
+ testExecutor);
// When split reader is not closed
// Then stream proxy is still open
diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderThreadPoolTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderThreadPoolTest.java
new file mode 100644
index 00000000..7374cdde
--- /dev/null
+++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderThreadPoolTest.java
@@ -0,0 +1,444 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.kinesis.source.reader.fanout;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.base.source.reader.splitreader.SplitsAddition;
+import org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions;
+import org.apache.flink.connector.kinesis.source.metrics.KinesisShardMetrics;
+import org.apache.flink.connector.kinesis.source.proxy.AsyncStreamProxy;
+import org.apache.flink.connector.kinesis.source.split.KinesisShardSplit;
+import org.apache.flink.connector.kinesis.source.split.StartingPosition;
+import org.apache.flink.connector.kinesis.source.util.TestUtil;
+import org.apache.flink.core.testutils.ManuallyTriggeredScheduledExecutorService;
+import org.apache.flink.metrics.MetricGroup;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import org.mockito.Mockito;
+import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent;
+
+import java.lang.reflect.Field;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.apache.flink.connector.kinesis.source.util.TestUtil.CONSUMER_ARN;
+import static org.apache.flink.connector.kinesis.source.util.TestUtil.SHARD_ID;
+import static org.apache.flink.connector.kinesis.source.util.TestUtil.STREAM_ARN;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for the thread pool behavior in {@link FanOutKinesisShardSplitReader}.
+ */
+public class FanOutKinesisShardSplitReaderThreadPoolTest {
+ private static final Duration TEST_SUBSCRIPTION_TIMEOUT = Duration.ofMillis(1000);
+ private static final int NUM_SHARDS = 10;
+ private static final int EVENTS_PER_SHARD = 5;
+
+ private AsyncStreamProxy mockAsyncStreamProxy;
+ private FanOutKinesisShardSplitReader splitReader;
+
+ @BeforeEach
+ public void setUp() {
+ mockAsyncStreamProxy = Mockito.mock(AsyncStreamProxy.class);
+ when(mockAsyncStreamProxy.subscribeToShard(any(), any(), any(), any()))
+ .thenReturn(CompletableFuture.completedFuture(null));
+ }
+
+ /**
+ * Tests that the thread pool correctly processes events from multiple shards.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testThreadPoolProcessesMultipleShards() throws Exception {
+ // Use a counter to track events processed
+ AtomicInteger processedEvents = new AtomicInteger(0);
+ int expectedEvents = NUM_SHARDS * EVENTS_PER_SHARD;
+
+ // Create a manually triggered executor service
+ ManuallyTriggeredScheduledExecutorService testExecutor = new ManuallyTriggeredScheduledExecutorService();
+
+ // Create a map to store our test subscriptions
+ java.util.Map testSubscriptions = new java.util.HashMap<>();
+
+ // Create a custom subscription factory that creates test subscriptions
+ FanOutKinesisShardSplitReader.SubscriptionFactory customFactory =
+ (proxy, consumerArn, shardId, startingPosition, timeout, eventProcessingExecutor, subscriptionCallExecutor) -> {
+ TestSubscription subscription = new TestSubscription(
+ proxy, consumerArn, shardId, startingPosition, timeout, eventProcessingExecutor, subscriptionCallExecutor,
+ processedEvents, expectedEvents);
+ testSubscriptions.put(shardId, subscription);
+ return subscription;
+ };
+
+ // Create a metrics map for each shard
+ java.util.Map metricsMap = new java.util.HashMap<>();
+ for (int i = 0; i < NUM_SHARDS; i++) {
+ String shardId = SHARD_ID + "-" + i;
+ KinesisShardSplit split = new KinesisShardSplit(
+ STREAM_ARN,
+ shardId,
+ StartingPosition.fromStart(),
+ Collections.emptySet(),
+ TestUtil.STARTING_HASH_KEY_TEST_VALUE,
+ TestUtil.ENDING_HASH_KEY_TEST_VALUE);
+ MetricGroup metricGroup = mock(MetricGroup.class);
+ when(metricGroup.addGroup(any(String.class))).thenReturn(metricGroup);
+ when(metricGroup.addGroup(any(String.class), any(String.class))).thenReturn(metricGroup);
+ metricsMap.put(shardId, new KinesisShardMetrics(split, metricGroup));
+ }
+
+ // Create a split reader with the custom factory and test executor
+ Configuration configuration = new Configuration();
+ configuration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT);
+
+ splitReader = new FanOutKinesisShardSplitReader(
+ mockAsyncStreamProxy,
+ CONSUMER_ARN,
+ metricsMap,
+ configuration,
+ customFactory,
+ testExecutor,
+ testExecutor);
+
+ // Add multiple splits to the reader
+ List splits = new ArrayList<>();
+ for (int i = 0; i < NUM_SHARDS; i++) {
+ String shardId = SHARD_ID + "-" + i;
+ KinesisShardSplit split = new KinesisShardSplit(
+ STREAM_ARN,
+ shardId,
+ StartingPosition.fromStart(),
+ Collections.emptySet(),
+ TestUtil.STARTING_HASH_KEY_TEST_VALUE,
+ TestUtil.ENDING_HASH_KEY_TEST_VALUE);
+ splits.add(split);
+ }
+ splitReader.handleSplitsChanges(new SplitsAddition<>(splits));
+
+ // Trigger all tasks in the executor to process subscription activations
+ testExecutor.triggerAll();
+
+ // Process all events for all shards by directly calling nextEvent() on each subscription
+ for (int i = 0; i < EVENTS_PER_SHARD; i++) {
+ for (String shardId : testSubscriptions.keySet()) {
+ TestSubscription subscription = testSubscriptions.get(shardId);
+ // Force the subscription to process an event
+ SubscribeToShardEvent event = subscription.nextEvent();
+ // Trigger all tasks in the executor after each event
+ testExecutor.triggerAll();
+ }
+ }
+
+ // Verify that all events were processed
+ assertThat(processedEvents.get()).as("All events should be processed").isEqualTo(expectedEvents);
+ }
+
+ /**
+ * Tests that the thread pool has natural flow control that prevents queue overflow.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testThreadPoolFlowControl() throws Exception {
+ // Create a counter to track the maximum number of queued tasks
+ AtomicInteger maxQueuedTasks = new AtomicInteger(0);
+ AtomicInteger currentQueuedTasks = new AtomicInteger(0);
+
+ // Create a custom AsyncStreamProxy that will delay subscription events
+ AsyncStreamProxy customProxy = Mockito.mock(AsyncStreamProxy.class);
+ when(customProxy.subscribeToShard(any(), any(), any(), any()))
+ .thenReturn(CompletableFuture.completedFuture(null));
+
+ // Create a metrics map for each shard
+ java.util.Map metricsMap = new java.util.HashMap<>();
+ for (int i = 0; i < NUM_SHARDS; i++) {
+ String shardId = SHARD_ID + "-" + i;
+ KinesisShardSplit split = new KinesisShardSplit(
+ STREAM_ARN,
+ shardId,
+ StartingPosition.fromStart(),
+ Collections.emptySet(),
+ TestUtil.STARTING_HASH_KEY_TEST_VALUE,
+ TestUtil.ENDING_HASH_KEY_TEST_VALUE);
+ MetricGroup metricGroup = mock(MetricGroup.class);
+ when(metricGroup.addGroup(any(String.class))).thenReturn(metricGroup);
+ when(metricGroup.addGroup(any(String.class), any(String.class))).thenReturn(metricGroup);
+ metricsMap.put(shardId, new KinesisShardMetrics(split, metricGroup));
+ }
+
+ // Create a split reader
+ // Create a Configuration object and set the timeout
+ Configuration configuration = new Configuration();
+ configuration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT);
+
+ splitReader = new FanOutKinesisShardSplitReader(
+ customProxy,
+ CONSUMER_ARN,
+ metricsMap,
+ configuration,
+ createTestSubscriptionFactory(),
+ Executors.newCachedThreadPool(),
+ Executors.newCachedThreadPool());
+
+ // Get access to the event processing executor service
+ ExecutorService executor = getEventProcessingExecutorService(splitReader);
+ assertThat(executor).isInstanceOf(ThreadPoolExecutor.class);
+ ThreadPoolExecutor threadPoolExecutor = (ThreadPoolExecutor) executor;
+
+ // Monitor the queue size
+ Thread monitorThread = new Thread(() -> {
+ try {
+ while (!Thread.currentThread().isInterrupted()) {
+ int queueSize = threadPoolExecutor.getQueue().size();
+ currentQueuedTasks.set(queueSize);
+ maxQueuedTasks.updateAndGet(current -> Math.max(current, queueSize));
+ Thread.sleep(10);
+ }
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ }
+ });
+ monitorThread.start();
+
+ // Create a custom subscription factory that adds artificial delay
+ FanOutKinesisShardSplitReader.SubscriptionFactory customFactory =
+ (proxy, consumerArn, shardId, startingPosition, timeout, eventProcessingExecutor, subscriptionCallExecutor) -> {
+ return new FanOutKinesisShardSubscription(
+ proxy, consumerArn, shardId, startingPosition, timeout, eventProcessingExecutor, subscriptionCallExecutor) {
+ @Override
+ public void processSubscriptionEvent(SubscribeToShardEvent event) {
+ try {
+ // Add artificial delay to simulate processing time
+ Thread.sleep(50);
+ super.processSubscriptionEvent(event);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ }
+ }
+
+ @Override
+ public SubscribeToShardEvent nextEvent() {
+ // Create a test event
+ return createTestEvent("sequence-" + shardId);
+ }
+ };
+ };
+
+ // Set the custom factory using reflection
+ setSubscriptionFactory(splitReader, customFactory);
+
+ // Add multiple splits to the reader
+ List splits = new ArrayList<>();
+ for (int i = 0; i < NUM_SHARDS; i++) {
+ String shardId = SHARD_ID + "-" + i;
+ KinesisShardSplit split = new KinesisShardSplit(
+ STREAM_ARN,
+ shardId,
+ StartingPosition.fromStart(),
+ Collections.emptySet(),
+ TestUtil.STARTING_HASH_KEY_TEST_VALUE,
+ TestUtil.ENDING_HASH_KEY_TEST_VALUE);
+ splits.add(split);
+ }
+ splitReader.handleSplitsChanges(new SplitsAddition<>(splits));
+
+ // Fetch records multiple times to trigger event processing
+ for (int i = 0; i < EVENTS_PER_SHARD * 2; i++) {
+ for (int j = 0; j < NUM_SHARDS; j++) {
+ splitReader.fetch();
+ }
+ }
+
+ // Wait for some time to allow tasks to be queued and processed
+ Thread.sleep(1000);
+
+ // Stop the monitor thread
+ monitorThread.interrupt();
+ monitorThread.join(1000);
+
+ // Verify that the maximum queue size is bounded
+ // The theoretical maximum is 2 * NUM_SHARDS (each subscription has a queue of 2)
+ assertThat(maxQueuedTasks.get()).isLessThanOrEqualTo(2 * NUM_SHARDS);
+ }
+
+ /**
+ * Tests that the thread pool is properly shut down when the split reader is closed.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testThreadPoolShutdown() throws Exception {
+ // Create a metrics map for the test
+ java.util.Map metricsMap = new java.util.HashMap<>();
+ KinesisShardSplit split = new KinesisShardSplit(
+ STREAM_ARN,
+ SHARD_ID,
+ StartingPosition.fromStart(),
+ Collections.emptySet(),
+ TestUtil.STARTING_HASH_KEY_TEST_VALUE,
+ TestUtil.ENDING_HASH_KEY_TEST_VALUE);
+ MetricGroup metricGroup = mock(MetricGroup.class);
+ when(metricGroup.addGroup(any(String.class))).thenReturn(metricGroup);
+ when(metricGroup.addGroup(any(String.class), any(String.class))).thenReturn(metricGroup);
+ metricsMap.put(SHARD_ID, new KinesisShardMetrics(split, metricGroup));
+
+ // Create a split reader
+ // Create a Configuration object and set the timeout
+ Configuration configuration = new Configuration();
+ configuration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT);
+
+ splitReader = new FanOutKinesisShardSplitReader(
+ mockAsyncStreamProxy,
+ CONSUMER_ARN,
+ metricsMap,
+ configuration,
+ createTestSubscriptionFactory(),
+ Executors.newCachedThreadPool(),
+ Executors.newCachedThreadPool());
+
+ // Get access to the executor services
+ ExecutorService eventProcessingExecutor = getEventProcessingExecutorService(splitReader);
+ ExecutorService subscriptionCallExecutor = getSubscriptionCallExecutorService(splitReader);
+ assertThat(eventProcessingExecutor).isNotNull();
+ assertThat(subscriptionCallExecutor).isNotNull();
+
+ // Close the split reader
+ splitReader.close();
+
+ // Verify that both executor services are shut down
+ assertThat(eventProcessingExecutor.isShutdown()).isTrue();
+ assertThat(subscriptionCallExecutor.isShutdown()).isTrue();
+ }
+
+ /**
+ * Creates a test SubscribeToShardEvent with the given continuation sequence number.
+ */
+ private SubscribeToShardEvent createTestEvent(String continuationSequenceNumber) {
+ return SubscribeToShardEvent.builder()
+ .continuationSequenceNumber(continuationSequenceNumber)
+ .millisBehindLatest(0L)
+ .records(new ArrayList<>())
+ .build();
+ }
+
+ /**
+ * Gets the event processing executor service from the split reader using reflection.
+ */
+ private ExecutorService getEventProcessingExecutorService(FanOutKinesisShardSplitReader splitReader) throws Exception {
+ Field field = FanOutKinesisShardSplitReader.class.getDeclaredField("sharedShardSubscriptionExecutor");
+ field.setAccessible(true);
+ return (ExecutorService) field.get(splitReader);
+ }
+
+ /**
+ * Gets the subscription call executor service from the split reader using reflection.
+ */
+ private ExecutorService getSubscriptionCallExecutorService(FanOutKinesisShardSplitReader splitReader) throws Exception {
+ Field field = FanOutKinesisShardSplitReader.class.getDeclaredField("sharedSubscriptionCallExecutor");
+ field.setAccessible(true);
+ return (ExecutorService) field.get(splitReader);
+ }
+
+ /**
+ * Sets the subscription factory in the split reader using reflection.
+ */
+ private void setSubscriptionFactory(
+ FanOutKinesisShardSplitReader splitReader,
+ FanOutKinesisShardSplitReader.SubscriptionFactory factory) throws Exception {
+ Field field = FanOutKinesisShardSplitReader.class.getDeclaredField("subscriptionFactory");
+ field.setAccessible(true);
+ field.set(splitReader, factory);
+ }
+
+ /**
+ * Creates a test subscription factory.
+ *
+ * @return A test subscription factory
+ */
+ private FanOutKinesisShardSplitReader.SubscriptionFactory createTestSubscriptionFactory() {
+ return (proxy, consumerArn, shardId, startingPosition, timeout, eventProcessingExecutor, subscriptionCallExecutor) ->
+ new FanOutKinesisShardSubscription(
+ proxy,
+ consumerArn,
+ shardId,
+ startingPosition,
+ timeout,
+ eventProcessingExecutor,
+ subscriptionCallExecutor);
+ }
+
+ /**
+ * A test subscription that ensures we process exactly EVENTS_PER_SHARD events per shard.
+ */
+ private static class TestSubscription extends FanOutKinesisShardSubscription {
+ private final AtomicInteger eventsProcessed = new AtomicInteger(0);
+ private final AtomicInteger globalCounter;
+ private final int expectedTotal;
+ private final String shardId;
+
+ public TestSubscription(
+ AsyncStreamProxy proxy,
+ String consumerArn,
+ String shardId,
+ StartingPosition startingPosition,
+ Duration timeout,
+ ExecutorService eventProcessingExecutor,
+ ExecutorService subscriptionCallExecutor,
+ AtomicInteger globalCounter,
+ int expectedTotal) {
+ super(proxy, consumerArn, shardId, startingPosition, timeout, eventProcessingExecutor, subscriptionCallExecutor);
+ this.shardId = shardId;
+ this.globalCounter = globalCounter;
+ this.expectedTotal = expectedTotal;
+ }
+
+ @Override
+ public SubscribeToShardEvent nextEvent() {
+ int current = eventsProcessed.get();
+
+ // Only return events up to EVENTS_PER_SHARD
+ if (current < EVENTS_PER_SHARD) {
+ // Create a test event
+ SubscribeToShardEvent event = SubscribeToShardEvent.builder()
+ .continuationSequenceNumber("sequence-" + shardId + "-" + current)
+ .millisBehindLatest(0L)
+ .records(new ArrayList<>())
+ .build();
+
+ // Increment the counters
+ eventsProcessed.incrementAndGet();
+ int globalCount = globalCounter.incrementAndGet();
+ return event;
+ }
+
+ // Return null when we've processed all events for this shard
+ return null;
+ }
+ }
+}
diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardStartingPositionTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardStartingPositionTest.java
new file mode 100644
index 00000000..a633d22a
--- /dev/null
+++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardStartingPositionTest.java
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.kinesis.source.reader.fanout;
+
+import org.apache.flink.connector.kinesis.source.split.StartingPosition;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import software.amazon.awssdk.services.kinesis.model.Record;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+
+import static org.apache.flink.connector.kinesis.source.util.TestUtil.SHARD_ID;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Tests for the starting position behavior in {@link FanOutKinesisShardSubscription}.
+ */
+public class FanOutKinesisShardStartingPositionTest extends FanOutKinesisShardTestBase {
+
+ /**
+ * Tests that the starting position is correctly recorded after queue.put for a single shard.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testStartingPositionRecordedAfterQueuePutSingleShard() throws Exception {
+ // Create a blocking queue to store processed records
+ BlockingQueue processedRecords = new LinkedBlockingQueue<>();
+
+ // Create a custom TestableSubscription that captures processed records
+ TestableSubscription testSubscription = createTestableSubscription(
+ SHARD_ID,
+ StartingPosition.fromStart(),
+ processedRecords);
+
+ // Create a test event with records
+ String continuationSequenceNumber = "sequence-1";
+ List records = new ArrayList<>();
+ records.add(FanOutKinesisTestUtils.createTestRecord("record-1"));
+ records.add(FanOutKinesisTestUtils.createTestRecord("record-2"));
+
+ // Process the event
+ testSubscription.processSubscriptionEvent(
+ FanOutKinesisTestUtils.createTestEvent(continuationSequenceNumber, records));
+
+ // Verify that all records were processed
+ List allProcessedRecords = new ArrayList<>();
+ processedRecords.drainTo(allProcessedRecords);
+ assertThat(allProcessedRecords).hasSize(2);
+
+ // Verify that the starting position was updated correctly
+ assertThat(testSubscription.getStartingPosition().getStartingMarker())
+ .isEqualTo(continuationSequenceNumber);
+ }
+
+ /**
+ * Tests that the starting position is correctly recorded after queue.put for multiple shards.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testStartingPositionRecordedAfterQueuePutMultipleShards() throws Exception {
+ // Create blocking queues to store processed records for each shard
+ BlockingQueue processedRecordsShard1 = new LinkedBlockingQueue<>();
+ BlockingQueue processedRecordsShard2 = new LinkedBlockingQueue<>();
+
+ // Create custom TestableSubscriptions for each shard
+ TestableSubscription subscription1 = createTestableSubscription(
+ SHARD_ID_1,
+ StartingPosition.fromStart(),
+ processedRecordsShard1);
+
+ TestableSubscription subscription2 = createTestableSubscription(
+ SHARD_ID_2,
+ StartingPosition.fromStart(),
+ processedRecordsShard2);
+
+ // Create test events with records for each shard
+ String continuationSequenceNumber1 = "sequence-shard1";
+ String continuationSequenceNumber2 = "sequence-shard2";
+
+ List recordsShard1 = new ArrayList<>();
+ recordsShard1.add(FanOutKinesisTestUtils.createTestRecord("shard1-record-1"));
+ recordsShard1.add(FanOutKinesisTestUtils.createTestRecord("shard1-record-2"));
+
+ List recordsShard2 = new ArrayList<>();
+ recordsShard2.add(FanOutKinesisTestUtils.createTestRecord("shard2-record-1"));
+ recordsShard2.add(FanOutKinesisTestUtils.createTestRecord("shard2-record-2"));
+
+ // Process the events
+ subscription1.processSubscriptionEvent(
+ FanOutKinesisTestUtils.createTestEvent(continuationSequenceNumber1, recordsShard1));
+ subscription2.processSubscriptionEvent(
+ FanOutKinesisTestUtils.createTestEvent(continuationSequenceNumber2, recordsShard2));
+
+ // Verify that all records were processed for shard 1
+ List allProcessedRecordsShard1 = new ArrayList<>();
+ processedRecordsShard1.drainTo(allProcessedRecordsShard1);
+ assertThat(allProcessedRecordsShard1).hasSize(2);
+
+ // Verify that all records were processed for shard 2
+ List allProcessedRecordsShard2 = new ArrayList<>();
+ processedRecordsShard2.drainTo(allProcessedRecordsShard2);
+ assertThat(allProcessedRecordsShard2).hasSize(2);
+
+ // Verify that the starting positions were updated correctly
+ assertThat(subscription1.getStartingPosition().getStartingMarker())
+ .isEqualTo(continuationSequenceNumber1);
+ assertThat(subscription2.getStartingPosition().getStartingMarker())
+ .isEqualTo(continuationSequenceNumber2);
+ }
+
+ /**
+ * Tests that the starting position is not recorded when queue.put fails for a single shard.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testStartingPositionNotRecordedWhenQueuePutFailsSingleShard() throws Exception {
+ // Create a custom TestableSubscription with a failing queue
+ TestableSubscription testSubscription = createTestableSubscription(
+ SHARD_ID,
+ StartingPosition.fromStart(),
+ null); // Null queue will cause queue.put to be skipped
+
+ // Set the flag to not update starting position
+ testSubscription.setShouldUpdateStartingPosition(false);
+
+ // Create a test event with records
+ String continuationSequenceNumber = "sequence-1";
+ List records = new ArrayList<>();
+ records.add(FanOutKinesisTestUtils.createTestRecord("record-1"));
+ records.add(FanOutKinesisTestUtils.createTestRecord("record-2"));
+
+ // Store the original starting position
+ StartingPosition originalPosition = testSubscription.getStartingPosition();
+
+ // Process the event
+ testSubscription.processSubscriptionEvent(
+ FanOutKinesisTestUtils.createTestEvent(continuationSequenceNumber, records));
+
+ // Verify that the starting position was not updated
+ assertThat(testSubscription.getStartingPosition()).isEqualTo(originalPosition);
+ }
+
+ /**
+ * Tests that the starting position is not recorded when queue.put fails for multiple shards.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testStartingPositionNotRecordedWhenQueuePutFailsMultipleShards() throws Exception {
+ // Create custom TestableSubscriptions with failing queues
+ TestableSubscription subscription1 = createTestableSubscription(
+ SHARD_ID_1,
+ StartingPosition.fromStart(),
+ null); // Null queue will cause queue.put to be skipped
+
+ TestableSubscription subscription2 = createTestableSubscription(
+ SHARD_ID_2,
+ StartingPosition.fromStart(),
+ null); // Null queue will cause queue.put to be skipped
+
+ // Set the flags to not update starting positions
+ subscription1.setShouldUpdateStartingPosition(false);
+ subscription2.setShouldUpdateStartingPosition(false);
+
+ // Create test events with records for each shard
+ String continuationSequenceNumber1 = "sequence-shard1";
+ String continuationSequenceNumber2 = "sequence-shard2";
+
+ List recordsShard1 = new ArrayList<>();
+ recordsShard1.add(FanOutKinesisTestUtils.createTestRecord("shard1-record-1"));
+ recordsShard1.add(FanOutKinesisTestUtils.createTestRecord("shard1-record-2"));
+
+ List recordsShard2 = new ArrayList<>();
+ recordsShard2.add(FanOutKinesisTestUtils.createTestRecord("shard2-record-1"));
+ recordsShard2.add(FanOutKinesisTestUtils.createTestRecord("shard2-record-2"));
+
+ // Store the original starting positions
+ StartingPosition originalPosition1 = subscription1.getStartingPosition();
+ StartingPosition originalPosition2 = subscription2.getStartingPosition();
+
+ // Process the events
+ subscription1.processSubscriptionEvent(
+ FanOutKinesisTestUtils.createTestEvent(continuationSequenceNumber1, recordsShard1));
+ subscription2.processSubscriptionEvent(
+ FanOutKinesisTestUtils.createTestEvent(continuationSequenceNumber2, recordsShard2));
+
+ // Verify that the starting positions were not updated
+ assertThat(subscription1.getStartingPosition()).isEqualTo(originalPosition1);
+ assertThat(subscription2.getStartingPosition()).isEqualTo(originalPosition2);
+ }
+}
diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscriptionThreadSafetyTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscriptionThreadSafetyTest.java
new file mode 100644
index 00000000..102c184d
--- /dev/null
+++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscriptionThreadSafetyTest.java
@@ -0,0 +1,414 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.kinesis.source.reader.fanout;
+
+import org.apache.flink.connector.kinesis.source.exception.KinesisStreamsSourceException;
+import org.apache.flink.connector.kinesis.source.proxy.AsyncStreamProxy;
+import org.apache.flink.connector.kinesis.source.split.StartingPosition;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.apache.flink.connector.kinesis.source.util.TestUtil.CONSUMER_ARN;
+import static org.apache.flink.connector.kinesis.source.util.TestUtil.SHARD_ID;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for thread safety in {@link FanOutKinesisShardSubscription}.
+ */
+public class FanOutKinesisShardSubscriptionThreadSafetyTest {
+
+ private static final Duration TEST_SUBSCRIPTION_TIMEOUT = Duration.ofMillis(1000);
+ private static final String TEST_CONTINUATION_SEQUENCE_NUMBER = "test-continuation-sequence-number";
+
+ private AsyncStreamProxy mockAsyncStreamProxy;
+ private ExecutorService testExecutor;
+ private FanOutKinesisShardSubscription subscription;
+
+ @BeforeEach
+ public void setUp() {
+ mockAsyncStreamProxy = Mockito.mock(AsyncStreamProxy.class);
+ when(mockAsyncStreamProxy.subscribeToShard(any(), any(), any(), any()))
+ .thenReturn(CompletableFuture.completedFuture(null));
+
+ testExecutor = Executors.newFixedThreadPool(4);
+ }
+
+ /**
+ * Tests that events are processed sequentially, ensuring that the starting position
+ * is updated in the correct order.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testEventProcessingSequential() throws Exception {
+ // Create a custom TestableSubscription that doesn't require shardSubscriber to be initialized
+ TestableSubscription testSubscription = new TestableSubscription(
+ mockAsyncStreamProxy,
+ CONSUMER_ARN,
+ SHARD_ID,
+ StartingPosition.fromStart(),
+ TEST_SUBSCRIPTION_TIMEOUT,
+ testExecutor,
+ null);
+
+ // Create test events with different sequence numbers
+ List testEvents = new ArrayList<>();
+ for (int i = 1; i <= 5; i++) {
+ testEvents.add(createTestEvent("sequence-" + i));
+ }
+
+ // Process events sequentially
+ for (SubscribeToShardEvent event : testEvents) {
+ testSubscription.processSubscriptionEvent(event);
+ }
+
+ // Verify that the final starting position is based on the last event
+ assertThat(testSubscription.getStartingPosition().getStartingMarker())
+ .isEqualTo(testEvents.get(testEvents.size() - 1).continuationSequenceNumber());
+ }
+
+ /**
+ * Tests that the subscription event processing lock prevents concurrent processing of events.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testEventProcessingLock() throws Exception {
+ // Create a CountDownLatch to track when the first task starts
+ CountDownLatch firstTaskStarted = new CountDownLatch(1);
+
+ // Create a CountDownLatch to control when the first task completes
+ CountDownLatch allowFirstTaskToComplete = new CountDownLatch(1);
+
+ // Create a CountDownLatch to track when the second task completes
+ CountDownLatch secondTaskCompleted = new CountDownLatch(1);
+
+ // Create an AtomicInteger to track the order of execution
+ AtomicInteger executionOrder = new AtomicInteger(0);
+
+ // Create a custom executor that will help us control the execution order
+ ExecutorService customExecutor = Executors.newFixedThreadPool(2);
+
+ // Create a custom TestableSubscription with a synchronized processSubscriptionEvent method
+ TestableSubscription testSubscription = new TestableSubscription(
+ mockAsyncStreamProxy,
+ CONSUMER_ARN,
+ SHARD_ID,
+ StartingPosition.fromStart(),
+ TEST_SUBSCRIPTION_TIMEOUT,
+ customExecutor,
+ null) {
+
+ @Override
+ public synchronized void processSubscriptionEvent(SubscribeToShardEvent event) {
+ String sequenceNumber = event.continuationSequenceNumber();
+
+ if ("sequence-1".equals(sequenceNumber)) {
+ // First task signals it has started and waits for permission to complete
+ executionOrder.incrementAndGet(); // Should be 1
+ firstTaskStarted.countDown();
+ try {
+ allowFirstTaskToComplete.await();
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ }
+ } else if ("sequence-2".equals(sequenceNumber)) {
+ // Second task just increments the counter and signals completion
+ executionOrder.incrementAndGet(); // Should be 2
+ secondTaskCompleted.countDown();
+ }
+
+ // Call the parent method
+ super.processSubscriptionEvent(event);
+ }
+ };
+
+ // Submit the first event
+ CompletableFuture future1 = CompletableFuture.runAsync(() -> {
+ testSubscription.processSubscriptionEvent(createTestEvent("sequence-1"));
+ });
+
+ // Wait for the first task to start
+ assertThat(firstTaskStarted.await(5, TimeUnit.SECONDS)).isTrue();
+
+ // Submit the second event
+ CompletableFuture future2 = CompletableFuture.runAsync(() -> {
+ testSubscription.processSubscriptionEvent(createTestEvent("sequence-2"));
+ });
+
+ // Allow some time for the second task to potentially start if there was no lock
+ Thread.sleep(500);
+
+ // The second task should not have executed yet due to the lock
+ assertThat(executionOrder.get()).isEqualTo(1);
+
+ // Allow the first task to complete
+ allowFirstTaskToComplete.countDown();
+
+ // Wait for the second task to complete
+ assertThat(secondTaskCompleted.await(5, TimeUnit.SECONDS)).isTrue();
+
+ // Verify the execution order
+ assertThat(executionOrder.get()).isEqualTo(2);
+
+ // Verify both futures completed
+ CompletableFuture.allOf(future1, future2).get(5, TimeUnit.SECONDS);
+ }
+
+ /**
+ * Tests that events are processed using the executor service.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testExecutorServiceUsage() throws Exception {
+ // Create a latch to track when the executor service is used
+ CountDownLatch executorUsed = new CountDownLatch(1);
+
+ // Create a custom executor that will signal when it's used
+ ExecutorService customExecutor = spy(testExecutor);
+ doAnswer(invocation -> {
+ executorUsed.countDown();
+ return invocation.callRealMethod();
+ }).when(customExecutor).execute(any(Runnable.class));
+
+ // Create a custom TestableSubscription that doesn't require shardSubscriber to be initialized
+ TestableSubscription testSubscription = new TestableSubscription(
+ mockAsyncStreamProxy,
+ CONSUMER_ARN,
+ SHARD_ID,
+ StartingPosition.fromStart(),
+ TEST_SUBSCRIPTION_TIMEOUT,
+ customExecutor,
+ null);
+
+ // Submit an event for processing
+ testSubscription.submitEventProcessingTask(createTestEvent(TEST_CONTINUATION_SEQUENCE_NUMBER));
+
+ // Verify that the executor was used
+ assertThat(executorUsed.await(5, TimeUnit.SECONDS)).isTrue();
+ verify(customExecutor, times(1)).execute(any(Runnable.class));
+ }
+
+ /**
+ * Tests that exceptions in event processing are properly propagated.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testExceptionPropagation() throws Exception {
+ // Create a custom TestableSubscription that throws a KinesisStreamsSourceException
+ TestableSubscription testSubscription = new TestableSubscription(
+ mockAsyncStreamProxy,
+ CONSUMER_ARN,
+ SHARD_ID,
+ StartingPosition.fromStart(),
+ TEST_SUBSCRIPTION_TIMEOUT,
+ testExecutor,
+ null) {
+
+ @Override
+ public void processSubscriptionEvent(SubscribeToShardEvent event) {
+ throw new KinesisStreamsSourceException("Test exception", new RuntimeException("Cause"));
+ }
+ };
+
+ // This should throw a KinesisStreamsSourceException
+ assertThatThrownBy(() -> {
+ testSubscription.processSubscriptionEvent(createTestEvent(TEST_CONTINUATION_SEQUENCE_NUMBER));
+ }).isInstanceOf(KinesisStreamsSourceException.class);
+ }
+
+ /**
+ * Tests that the starting position is updated only after the event is successfully added to the queue.
+ */
+ @Test
+ @Timeout(value = 30)
+ public void testStartingPositionUpdatedAfterQueuePut() throws Exception {
+ // Create a blocking queue that we can control
+ BlockingQueue controlledQueue = spy(new LinkedBlockingQueue<>(2));
+
+ // Create a latch to track when put is called
+ CountDownLatch putCalled = new CountDownLatch(1);
+
+ // Create a latch to control when put returns
+ CountDownLatch allowPutToReturn = new CountDownLatch(1);
+
+ // Create an atomic boolean to track if the starting position was updated before put completed
+ AtomicBoolean startingPositionUpdatedBeforePutCompleted = new AtomicBoolean(false);
+
+ // Mock the queue's put method to control its execution
+ doAnswer(new Answer() {
+ @Override
+ public Void answer(InvocationOnMock invocation) throws Throwable {
+ putCalled.countDown();
+ allowPutToReturn.await(5, TimeUnit.SECONDS);
+
+ // Call the real method
+ invocation.callRealMethod();
+ return null;
+ }
+ }).when(controlledQueue).put(any(SubscribeToShardEvent.class));
+
+ // Create a subscription with access to the controlled queue
+ FanOutKinesisShardSubscription testSubscription = new TestableSubscription(
+ mockAsyncStreamProxy,
+ CONSUMER_ARN,
+ SHARD_ID,
+ StartingPosition.fromStart(),
+ TEST_SUBSCRIPTION_TIMEOUT,
+ testExecutor,
+ controlledQueue);
+
+ // Create a thread to check the starting position while put is blocked
+ Thread checkThread = new Thread(() -> {
+ try {
+ // Wait for put to be called
+ assertThat(putCalled.await(5, TimeUnit.SECONDS)).isTrue();
+
+ // Check if the starting position was updated before put completed
+ StartingPosition currentPosition = testSubscription.getStartingPosition();
+ if (currentPosition.getStartingMarker() != null &&
+ currentPosition.getStartingMarker().equals(TEST_CONTINUATION_SEQUENCE_NUMBER)) {
+ startingPositionUpdatedBeforePutCompleted.set(true);
+ }
+
+ // Allow put to return
+ allowPutToReturn.countDown();
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ }
+ });
+
+ // Start the check thread
+ checkThread.start();
+
+ // Process an event
+ testSubscription.processSubscriptionEvent(createTestEvent(TEST_CONTINUATION_SEQUENCE_NUMBER));
+
+ // Wait for the check thread to complete
+ checkThread.join(5000);
+
+ // Verify that the starting position was not updated before put completed
+ assertThat(startingPositionUpdatedBeforePutCompleted.get()).isFalse();
+
+ // Verify that the starting position was updated after put completed
+ assertThat(testSubscription.getStartingPosition().getStartingMarker())
+ .isEqualTo(TEST_CONTINUATION_SEQUENCE_NUMBER);
+ }
+
+ /**
+ * Creates a test SubscribeToShardEvent with the given continuation sequence number.
+ */
+ private SubscribeToShardEvent createTestEvent(String continuationSequenceNumber) {
+ return SubscribeToShardEvent.builder()
+ .continuationSequenceNumber(continuationSequenceNumber)
+ .millisBehindLatest(0L)
+ .records(new ArrayList<>())
+ .build();
+ }
+
+ /**
+ * A testable version of FanOutKinesisShardSubscription that allows access to the event queue
+ * and overrides methods that require shardSubscriber to be initialized.
+ */
+ private static class TestableSubscription extends FanOutKinesisShardSubscription {
+ private final BlockingQueue testEventQueue;
+ private StartingPosition currentStartingPosition;
+
+ public TestableSubscription(
+ AsyncStreamProxy kinesis,
+ String consumerArn,
+ String shardId,
+ StartingPosition startingPosition,
+ Duration subscriptionTimeout,
+ ExecutorService subscriptionEventProcessingExecutor,
+ BlockingQueue testEventQueue) {
+ super(kinesis, consumerArn, shardId, startingPosition, subscriptionTimeout, subscriptionEventProcessingExecutor, subscriptionEventProcessingExecutor);
+ this.testEventQueue = testEventQueue;
+ this.currentStartingPosition = startingPosition;
+ }
+
+ @Override
+ public StartingPosition getStartingPosition() {
+ return currentStartingPosition;
+ }
+
+ @Override
+ public void processSubscriptionEvent(SubscribeToShardEvent event) {
+ try {
+ if (testEventQueue != null) {
+ testEventQueue.put(event);
+ }
+
+ // Update the starting position to ensure we can recover after failover
+ currentStartingPosition = StartingPosition.continueFromSequenceNumber(
+ event.continuationSequenceNumber());
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new KinesisStreamsSourceException(
+ "Interrupted while adding Kinesis record to internal buffer.", e);
+ }
+ }
+
+ /**
+ * Public method to submit an event processing task directly to the executor.
+ * This is used for testing the executor service usage.
+ */
+ public void submitEventProcessingTask(SubscribeToShardEvent event) {
+ try {
+ // Use reflection to access the private executor field
+ java.lang.reflect.Field field = FanOutKinesisShardSubscription.class.getDeclaredField("subscriptionEventProcessingExecutor");
+ field.setAccessible(true);
+ ExecutorService executor = (ExecutorService) field.get(this);
+
+ executor.execute(() -> {
+ synchronized (this) {
+ processSubscriptionEvent(event);
+ }
+ });
+ } catch (Exception e) {
+ throw new KinesisStreamsSourceException(
+ "Error submitting subscription event task", e);
+ }
+ }
+ }
+}
diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardTestBase.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardTestBase.java
new file mode 100644
index 00000000..8e2fc356
--- /dev/null
+++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardTestBase.java
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.kinesis.source.reader.fanout;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions;
+import org.apache.flink.connector.kinesis.source.proxy.AsyncStreamProxy;
+import org.apache.flink.connector.kinesis.source.split.StartingPosition;
+import org.apache.flink.core.testutils.ManuallyTriggeredScheduledExecutorService;
+import org.apache.flink.metrics.MetricGroup;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.mockito.Mockito;
+import software.amazon.awssdk.services.kinesis.model.Record;
+import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent;
+
+import java.time.Duration;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+
+import static org.apache.flink.connector.kinesis.source.util.TestUtil.STREAM_ARN;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Base class for Kinesis shard tests.
+ */
+public abstract class FanOutKinesisShardTestBase {
+
+ protected static final Duration TEST_SUBSCRIPTION_TIMEOUT = Duration.ofMillis(1000);
+ protected static final String SHARD_ID_1 = "shardId-000000000001";
+ protected static final String SHARD_ID_2 = "shardId-000000000002";
+ protected static final String CONSUMER_ARN = "abcdedf";
+
+ protected AsyncStreamProxy mockAsyncStreamProxy;
+ protected ManuallyTriggeredScheduledExecutorService testExecutor;
+ protected MetricGroup mockMetricGroup;
+
+ @BeforeEach
+ public void setUp() {
+ mockAsyncStreamProxy = Mockito.mock(AsyncStreamProxy.class);
+ when(mockAsyncStreamProxy.subscribeToShard(any(), any(), any(), any()))
+ .thenReturn(CompletableFuture.completedFuture(null));
+
+ testExecutor = new ManuallyTriggeredScheduledExecutorService();
+
+ mockMetricGroup = mock(MetricGroup.class);
+ when(mockMetricGroup.addGroup(any(String.class))).thenReturn(mockMetricGroup);
+ when(mockMetricGroup.addGroup(any(String.class), any(String.class))).thenReturn(mockMetricGroup);
+ }
+
+ /**
+ * A testable version of FanOutKinesisShardSubscription that captures processed records.
+ */
+ protected static class TestableSubscription extends FanOutKinesisShardSubscription {
+ private final BlockingQueue recordQueue;
+ private volatile StartingPosition currentStartingPosition;
+ private volatile boolean shouldUpdateStartingPosition = true;
+
+ public TestableSubscription(
+ AsyncStreamProxy kinesis,
+ String consumerArn,
+ String shardId,
+ StartingPosition startingPosition,
+ Duration subscriptionTimeout,
+ ExecutorService subscriptionEventProcessingExecutor,
+ ExecutorService subscriptionCallExecutor,
+ BlockingQueue recordQueue) {
+ super(kinesis, consumerArn, shardId, startingPosition, subscriptionTimeout, subscriptionEventProcessingExecutor, subscriptionCallExecutor);
+ this.recordQueue = recordQueue;
+ this.currentStartingPosition = startingPosition;
+ }
+
+ @Override
+ public StartingPosition getStartingPosition() {
+ return currentStartingPosition;
+ }
+
+ @Override
+ public void processSubscriptionEvent(SubscribeToShardEvent event) {
+ boolean recordsProcessed = false;
+
+ try {
+ // Add all records to the queue
+ if (recordQueue != null && event.records() != null) {
+ for (Record record : event.records()) {
+ recordQueue.put(record);
+ }
+ recordsProcessed = true;
+ }
+
+ // Update the starting position only if records were processed
+ if (recordsProcessed && shouldUpdateStartingPosition) {
+ String continuationSequenceNumber = event.continuationSequenceNumber();
+ if (continuationSequenceNumber != null) {
+ currentStartingPosition = StartingPosition.continueFromSequenceNumber(continuationSequenceNumber);
+ }
+ }
+
+ // Note: We're not calling super.processSubscriptionEvent(event) here
+ // because that would try to use the shardSubscriber which is null in our test
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new RuntimeException("Interrupted while processing event", e);
+ }
+ }
+
+ public void setShouldUpdateStartingPosition(boolean shouldUpdateStartingPosition) {
+ this.shouldUpdateStartingPosition = shouldUpdateStartingPosition;
+ }
+ }
+
+ /**
+ * Creates a TestableSubscription for testing.
+ *
+ * @param shardId The shard ID
+ * @param startingPosition The starting position
+ * @param recordQueue The queue to store processed records
+ * @return A TestableSubscription
+ */
+ protected TestableSubscription createTestableSubscription(
+ String shardId,
+ StartingPosition startingPosition,
+ BlockingQueue recordQueue) {
+ return new TestableSubscription(
+ mockAsyncStreamProxy,
+ CONSUMER_ARN,
+ shardId,
+ startingPosition,
+ TEST_SUBSCRIPTION_TIMEOUT,
+ testExecutor,
+ testExecutor, // Use the same executor for subscription calls
+ recordQueue);
+ }
+
+ /**
+ * Creates a test subscription factory.
+ *
+ * @return A test subscription factory
+ */
+ protected FanOutKinesisShardSplitReader.SubscriptionFactory createTestSubscriptionFactory() {
+ return (proxy, consumerArn, shardId, startingPosition, timeout, eventProcessingExecutor, subscriptionCallExecutor) ->
+ new FanOutKinesisShardSubscription(
+ proxy,
+ consumerArn,
+ shardId,
+ startingPosition,
+ timeout,
+ eventProcessingExecutor,
+ subscriptionCallExecutor);
+ }
+
+ /**
+ * Creates a FanOutKinesisShardSplitReader with a single shard.
+ *
+ * @param shardId The shard ID
+ * @return A FanOutKinesisShardSplitReader
+ */
+ protected FanOutKinesisShardSplitReader createSplitReaderWithShard(String shardId) {
+ // Create a Configuration object and set the timeout
+ Configuration configuration = new Configuration();
+ configuration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT);
+
+ FanOutKinesisShardSplitReader reader = new FanOutKinesisShardSplitReader(
+ mockAsyncStreamProxy,
+ CONSUMER_ARN,
+ Mockito.mock(java.util.Map.class),
+ configuration,
+ createTestSubscriptionFactory(),
+ testExecutor,
+ testExecutor);
+
+ // Create a split
+ reader.handleSplitsChanges(
+ new org.apache.flink.connector.base.source.reader.splitreader.SplitsAddition<>(
+ java.util.Collections.singletonList(
+ FanOutKinesisTestUtils.createTestSplit(
+ STREAM_ARN,
+ shardId,
+ StartingPosition.fromStart()))));
+
+ return reader;
+ }
+}
diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisTestUtils.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisTestUtils.java
new file mode 100644
index 00000000..2e3c814b
--- /dev/null
+++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisTestUtils.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.kinesis.source.reader.fanout;
+
+import org.apache.flink.connector.kinesis.source.split.KinesisShardSplit;
+import org.apache.flink.connector.kinesis.source.split.StartingPosition;
+import org.apache.flink.connector.kinesis.source.util.TestUtil;
+
+import software.amazon.awssdk.core.SdkBytes;
+import software.amazon.awssdk.services.kinesis.model.Record;
+import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent;
+
+import java.nio.charset.StandardCharsets;
+import java.time.Instant;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Utility class for Kinesis tests.
+ */
+public class FanOutKinesisTestUtils {
+
+ /**
+ * Creates a test Record with the given data.
+ *
+ * @param data The data to include in the record
+ * @return A test Record
+ */
+ public static Record createTestRecord(String data) {
+ return Record.builder()
+ .data(SdkBytes.fromString(data, StandardCharsets.UTF_8))
+ .approximateArrivalTimestamp(Instant.now())
+ .partitionKey("partitionKey")
+ .sequenceNumber("sequenceNumber")
+ .build();
+ }
+
+ /**
+ * Creates a test SubscribeToShardEvent with the given continuation sequence number and records.
+ *
+ * @param continuationSequenceNumber The continuation sequence number
+ * @param records The records to include in the event
+ * @return A test SubscribeToShardEvent
+ */
+ public static SubscribeToShardEvent createTestEvent(String continuationSequenceNumber, List records) {
+ return SubscribeToShardEvent.builder()
+ .continuationSequenceNumber(continuationSequenceNumber)
+ .millisBehindLatest(0L)
+ .records(records)
+ .build();
+ }
+
+ /**
+ * Creates a test SubscribeToShardEvent with the given continuation sequence number, records, and millisBehindLatest.
+ *
+ * @param continuationSequenceNumber The continuation sequence number
+ * @param records The records to include in the event
+ * @param millisBehindLatest The milliseconds behind latest
+ * @return A test SubscribeToShardEvent
+ */
+ public static SubscribeToShardEvent createTestEvent(
+ String continuationSequenceNumber, List records, long millisBehindLatest) {
+ return SubscribeToShardEvent.builder()
+ .continuationSequenceNumber(continuationSequenceNumber)
+ .millisBehindLatest(millisBehindLatest)
+ .records(records)
+ .build();
+ }
+
+ /**
+ * Creates a test KinesisShardSplit.
+ *
+ * @param streamArn The stream ARN
+ * @param shardId The shard ID
+ * @param startingPosition The starting position
+ * @return A test KinesisShardSplit
+ */
+ public static KinesisShardSplit createTestSplit(
+ String streamArn, String shardId, StartingPosition startingPosition) {
+ return new KinesisShardSplit(
+ streamArn,
+ shardId,
+ startingPosition,
+ Collections.emptySet(),
+ TestUtil.STARTING_HASH_KEY_TEST_VALUE,
+ TestUtil.ENDING_HASH_KEY_TEST_VALUE);
+ }
+
+ /**
+ * Gets the subscription for a specific shard from the reader using reflection.
+ *
+ * @param reader The reader
+ * @param shardId The shard ID
+ * @return The subscription
+ * @throws Exception If an error occurs
+ */
+ public static FanOutKinesisShardSubscription getSubscriptionFromReader(
+ FanOutKinesisShardSplitReader reader, String shardId) throws Exception {
+ // Get access to the subscriptions map
+ java.lang.reflect.Field field = FanOutKinesisShardSplitReader.class.getDeclaredField("splitSubscriptions");
+ field.setAccessible(true);
+ Map subscriptions =
+ (Map) field.get(reader);
+ return subscriptions.get(shardId);
+ }
+
+ /**
+ * Sets the starting position in a subscription using reflection.
+ *
+ * @param subscription The subscription
+ * @param startingPosition The starting position
+ * @throws Exception If an error occurs
+ */
+ public static void setStartingPositionInSubscription(
+ FanOutKinesisShardSubscription subscription, StartingPosition startingPosition) throws Exception {
+ // Get access to the startingPosition field
+ java.lang.reflect.Field field = subscription.getClass().getDeclaredField("startingPosition");
+ field.setAccessible(true);
+ field.set(subscription, startingPosition);
+ }
+}