diff --git a/driver-core/src/main/com/mongodb/internal/time/ExponentialBackoff.java b/driver-core/src/main/com/mongodb/internal/time/ExponentialBackoff.java
new file mode 100644
index 0000000000..ed9bba51d7
--- /dev/null
+++ b/driver-core/src/main/com/mongodb/internal/time/ExponentialBackoff.java
@@ -0,0 +1,93 @@
+/*
+ * Copyright 2008-present MongoDB, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mongodb.internal.time;
+
+import com.mongodb.internal.VisibleForTesting;
+
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.function.DoubleSupplier;
+
+import static com.mongodb.internal.VisibleForTesting.AccessModifier.PRIVATE;
+
+/**
+ * Implements exponential backoff with jitter for retry scenarios.
+ * <p>
+ * The delay for a retry is {@code jitter * min(baseMs * growth^retryCount, maxMs)}, rounded to the nearest millisecond.
+ */
+public enum ExponentialBackoff {
+    TRANSACTION(5.0, 500.0, 1.5);
+
+    private final double baseMs, maxMs, growth;
+
+    // TODO remove this global state once https://jira.mongodb.org/browse/JAVA-6060 is done
+    // Volatile so that a supplier installed by one thread is visible to the thread computing the delay.
+    private static volatile DoubleSupplier testJitterSupplier;
+
+    ExponentialBackoff(final double baseMs, final double maxMs, final double growth) {
+        this.baseMs = baseMs;
+        this.maxMs = maxMs;
+        this.growth = growth;
+    }
+
+    /**
+     * Calculate the next delay in milliseconds based on the retry count.
+     * <p>
+     * The jitter is drawn from {@link ThreadLocalRandom} unless a test jitter supplier has been set.
+     *
+     * @param retryCount The number of retries that have occurred.
+     * @return The calculated delay in milliseconds.
+     */
+    public long calculateDelayBeforeNextRetryMs(final int retryCount) {
+        // Read the static field once: a concurrent clearTestJitterSupplier() between the null check
+        // and the call would otherwise cause a NullPointerException.
+        DoubleSupplier jitterSupplier = testJitterSupplier;
+        double jitter = jitterSupplier != null
+                ? jitterSupplier.getAsDouble()
+                : ThreadLocalRandom.current().nextDouble();
+        return calculateDelayBeforeNextRetryMs(retryCount, jitter);
+    }
+
+    /**
+     * Calculate the next delay in milliseconds based on the retry count and a provided jitter.
+     *
+     * @param retryCount The number of retries that have occurred.
+     * @param jitter A double in the range [0, 1) to apply as jitter.
+     * @return The calculated delay in milliseconds.
+     */
+    public long calculateDelayBeforeNextRetryMs(final int retryCount, final double jitter) {
+        double backoff = Math.min(baseMs * Math.pow(growth, retryCount), maxMs);
+        return Math.round(jitter * backoff);
+    }
+
+    /**
+     * Set a custom jitter supplier for testing purposes.
+     *
+     * @param supplier A DoubleSupplier that returns values in [0, 1) range.
+     */
+    @VisibleForTesting(otherwise = PRIVATE)
+    public static void setTestJitterSupplier(final DoubleSupplier supplier) {
+        testJitterSupplier = supplier;
+    }
+
+    /**
+     * Clear the test jitter supplier, reverting to default ThreadLocalRandom behavior.
+     */
+    @VisibleForTesting(otherwise = PRIVATE)
+    public static void clearTestJitterSupplier() {
+        testJitterSupplier = null;
+    }
+}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/ExponentialBackoffTest.java b/driver-core/src/test/unit/com/mongodb/internal/ExponentialBackoffTest.java
new file mode 100644
index 0000000000..6723853248
--- /dev/null
+++ b/driver-core/src/test/unit/com/mongodb/internal/ExponentialBackoffTest.java
@@ -0,0 +1,73 @@
+/*
+ * Copyright 2008-present MongoDB, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mongodb.internal;
+
+import com.mongodb.internal.time.ExponentialBackoff;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class ExponentialBackoffTest {
+
+    @Test
+    void testTransactionRetryBackoff() {
+        // The backoff before jitter starts at 5 ms and grows by a factor of 1.5 per retry: 5, 7.5, 11.25, ...
+        // until it reaches the 500 ms cap. Jitter scales each value down, so only the
+        // range 0..maximum can be asserted for every retry.
+        double[] expectedMaxValues = {5.0, 7.5, 11.25, 16.875, 25.3125, 37.96875, 56.953125, 85.4296875, 128.14453125, 192.21679688, 288.32519531, 432.48779297, 500.0};
+
+        for (int retry = 0; retry < expectedMaxValues.length; retry++) {
+            long upperBoundMs = Math.round(expectedMaxValues[retry]);
+            long delayMs = ExponentialBackoff.TRANSACTION.calculateDelayBeforeNextRetryMs(retry);
+            assertTrue(0 <= delayMs && delayMs <= upperBoundMs, String.format("Retry %d: delay should be 0-%d ms, got: %d", retry, upperBoundMs, delayMs));
+        }
+    }
+
+    @Test
+    void testTransactionRetryBackoffRespectsMaximum() {
+        // The 500 ms cap must hold for every retry count, including ones far beyond where the cap is reached
+        final long capMs = 500;
+
+        for (int retry = 0; retry < 25; retry++) {
+            long delayMs = ExponentialBackoff.TRANSACTION.calculateDelayBeforeNextRetryMs(retry);
+            assertTrue(0 <= delayMs && delayMs <= capMs, String.format("Retry %d: delay should be capped at 500 ms, got: %d ms", retry, delayMs));
+        }
+    }
+
+    @Test
+    void testCustomJitter() {
+        ExponentialBackoff transactionBackoff = ExponentialBackoff.TRANSACTION;
+
+        // A jitter of 1.0 yields the full backoff: 5 ms growing by a factor of 1.5, capped at 500 ms
+        double[] fullBackoffMs = {5.0, 7.5, 11.25, 16.875, 25.3125, 37.96875, 56.953125, 85.4296875, 128.14453125, 192.21679688, 288.32519531, 432.48779297, 500.0};
+        final double fullJitter = 1.0;
+
+        for (int retry = 0; retry < fullBackoffMs.length; retry++) {
+            long expectedMs = Math.round(fullBackoffMs[retry]);
+            long actualMs = transactionBackoff.calculateDelayBeforeNextRetryMs(retry, fullJitter);
+            assertEquals(expectedMs, actualMs, String.format("Retry %d: with jitter=1.0, delay should be %d ms", retry, expectedMs));
+        }
+
+        // A jitter of 0 reduces every delay to 0 ms, whatever the retry count
+        final double noJitter = 0;
+        for (int retry = 0; retry < 10; retry++) {
+            long actualMs = transactionBackoff.calculateDelayBeforeNextRetryMs(retry, noJitter);
+            assertEquals(0, actualMs, "With jitter=0, delay should always be 0 ms");
+        }
+    }
+}
diff --git a/driver-sync/src/main/com/mongodb/client/internal/ClientSessionImpl.java b/driver-sync/src/main/com/mongodb/client/internal/ClientSessionImpl.java
index aa1414dce5..fcaea52aaa 100644
--- a/driver-sync/src/main/com/mongodb/client/internal/ClientSessionImpl.java
+++ b/driver-sync/src/main/com/mongodb/client/internal/ClientSessionImpl.java
@@ -28,6 +28,8 @@
 import com.mongodb.client.ClientSession;
 import com.mongodb.client.TransactionBody;
 import com.mongodb.internal.TimeoutContext;
+import com.mongodb.internal.observability.micrometer.TracingManager;
+import com.mongodb.internal.observability.micrometer.TransactionSpan;
 import com.mongodb.internal.operation.AbortTransactionOperation;
 import com.mongodb.internal.operation.CommitTransactionOperation;
 import com.mongodb.internal.operation.OperationHelper;
@@ -36,8 +38,7 @@
 import com.mongodb.internal.operation.WriteOperation;
 import com.mongodb.internal.session.BaseClientSessionImpl;
 import com.mongodb.internal.session.ServerSessionPool;
-import com.mongodb.internal.observability.micrometer.TracingManager;
-import
com.mongodb.internal.observability.micrometer.TransactionSpan;
+import com.mongodb.internal.time.ExponentialBackoff;
 import com.mongodb.lang.Nullable;
 
 import static com.mongodb.MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL;
@@ -46,6 +47,7 @@
 import static com.mongodb.assertions.Assertions.assertTrue;
 import static com.mongodb.assertions.Assertions.isTrue;
 import static com.mongodb.assertions.Assertions.notNull;
+import static com.mongodb.internal.thread.InterruptionUtil.interruptAndCreateMongoInterruptedException;
 
 final class ClientSessionImpl extends BaseClientSessionImpl implements ClientSession {
 
@@ -251,13 +253,21 @@ public T withTransaction(final TransactionBody transactionBody, final Tra
         notNull("transactionBody", transactionBody);
         long startTime = ClientSessionClock.INSTANCE.now();
         TimeoutContext withTransactionTimeoutContext = createTimeoutContext(options);
+        ExponentialBackoff transactionBackoff = ExponentialBackoff.TRANSACTION;
+        int transactionAttempt = 0;
+        MongoException lastError = null;
 
         try {
             outer:
             while (true) {
+                if (transactionAttempt > 0) {
+                    backoff(transactionBackoff, transactionAttempt, startTime, lastError);
+                }
                 T retVal;
                 try {
                     startTransaction(options, withTransactionTimeoutContext.copyTimeoutContext());
+                    transactionAttempt++;
+
                     if (transactionSpan != null) {
                         transactionSpan.setIsConvenientTransaction();
                     }
@@ -266,14 +276,17 @@ public T withTransaction(final TransactionBody transactionBody, final Tra
                     if (transactionState == TransactionState.IN) {
                         abortTransaction();
                     }
-                    if (e instanceof MongoException && !(e instanceof MongoOperationTimeoutException)) {
-                        MongoException exceptionToHandle = OperationHelper.unwrap((MongoException) e);
-                        if (exceptionToHandle.hasErrorLabel(TRANSIENT_TRANSACTION_ERROR_LABEL)
-                                && ClientSessionClock.INSTANCE.now() - startTime < MAX_RETRY_TIME_LIMIT_MS) {
-                            if (transactionSpan != null) {
-                                transactionSpan.spanFinalizing(false);
+                    if (e instanceof MongoException) {
+                        lastError = (MongoException) e; // Store last error
+                        if (!(e instanceof MongoOperationTimeoutException)) {
+                            MongoException exceptionToHandle = OperationHelper.unwrap((MongoException) e);
+                            if (exceptionToHandle.hasErrorLabel(TRANSIENT_TRANSACTION_ERROR_LABEL)
+                                    && ClientSessionClock.INSTANCE.now() - startTime < MAX_RETRY_TIME_LIMIT_MS) {
+                                if (transactionSpan != null) {
+                                    transactionSpan.spanFinalizing(false);
+                                }
+                                continue;
                             }
-                            continue;
                         }
                     }
                     throw e;
@@ -296,6 +309,7 @@ public T withTransaction(final TransactionBody transactionBody, final Tra
                             if (transactionSpan != null) {
                                 transactionSpan.spanFinalizing(true);
                             }
+                            lastError = e;
                             continue outer;
                         }
                     }
@@ -359,4 +373,31 @@ private TimeoutContext createTimeoutContext(final TransactionOptions transaction
                 TransactionOptions.merge(transactionOptions, getOptions().getDefaultTransactionOptions()),
                 operationExecutor.getTimeoutSettings()));
     }
+
+    /**
+     * Waits out the jittered backoff delay that precedes a transaction retry.
+     *
+     * @param exponentialBackoff the backoff policy used to compute the delay
+     * @param transactionAttempt the number of transaction attempts started so far; the delay is computed for
+     *                           retry index {@code transactionAttempt - 1}
+     * @param startTime the {@link ClientSessionClock} reading taken when {@code withTransaction} began
+     * @param lastError the error that triggered the retry, rethrown if the delay would reach the retry time limit
+     */
+    private static void backoff(final ExponentialBackoff exponentialBackoff, final int transactionAttempt, final long startTime,
+            final MongoException lastError) {
+        long delayMs = exponentialBackoff.calculateDelayBeforeNextRetryMs(transactionAttempt - 1);
+        long elapsedAfterDelayMs = ClientSessionClock.INSTANCE.now() + delayMs - startTime;
+        if (elapsedAfterDelayMs >= MAX_RETRY_TIME_LIMIT_MS) {
+            // Sleeping would take the retry loop to or past its time limit: surface the failure instead
+            throw lastError != null ? lastError : new MongoClientException("Transaction retry timeout exceeded");
+        }
+        if (delayMs <= 0) {
+            return;
+        }
+        try {
+            Thread.sleep(delayMs);
+        } catch (InterruptedException e) {
+            throw interruptAndCreateMongoInterruptedException("Transaction retry interrupted", e);
+        }
+    }
 }
diff --git a/driver-sync/src/test/functional/com/mongodb/client/WithTransactionProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/WithTransactionProseTest.java
index 1afbf61565..e2dce11583 100644
--- a/driver-sync/src/test/functional/com/mongodb/client/WithTransactionProseTest.java
+++ b/driver-sync/src/test/functional/com/mongodb/client/WithTransactionProseTest.java
@@ -22,15 +22,20 @@
 import com.mongodb.TransactionOptions;
 import
com.mongodb.client.internal.ClientSessionClock;
 import com.mongodb.client.model.Sorts;
+import com.mongodb.internal.time.ExponentialBackoff;
+import org.bson.BsonDocument;
 import org.bson.Document;
 import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.DisplayName;
 import org.junit.jupiter.api.Test;
 
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import static com.mongodb.ClusterFixture.TIMEOUT;
 import static com.mongodb.ClusterFixture.isDiscoverableReplicaSet;
 import static com.mongodb.ClusterFixture.isSharded;
+import static com.mongodb.client.Fixture.getPrimary;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -203,6 +208,76 @@ public void testTimeoutMSAndLegacySettings() {
         }
     }
 
+    /**
+     * See
+     * Convenient API Prose Tests.
+     */
+    @DisplayName("Retry Backoff is Enforced")
+    @Test
+    public void testRetryBackoffIsEnforced() throws InterruptedException {
+        // Run with jitter = 0 (no backoff)
+        ExponentialBackoff.setTestJitterSupplier(() -> 0.0);
+
+        BsonDocument failPointDocument = BsonDocument.parse("{'configureFailPoint': 'failCommand', 'mode': {'times': 13}, "
+                + "'data': {'failCommands': ['commitTransaction'], 'errorCode': 251}}");
+
+        long noBackoffTime;
+        try (ClientSession session = client.startSession();
+             FailPoint ignored = FailPoint.enable(failPointDocument, getPrimary())) {
+            long startNanos = System.nanoTime();
+            session.withTransaction(() -> collection.insertOne(session, Document.parse("{}")));
+            noBackoffTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);
+        } finally {
+            // Clear the test jitter supplier to avoid affecting other tests
+            ExponentialBackoff.clearTestJitterSupplier();
+        }
+
+        // Run with jitter = 1 (full backoff)
+        ExponentialBackoff.setTestJitterSupplier(() -> 1.0);
+
+        failPointDocument = BsonDocument.parse("{'configureFailPoint': 'failCommand', 'mode': {'times': 13}, "
+                + "'data': {'failCommands': ['commitTransaction'], 'errorCode': 251}}");
+
+        long withBackoffTime;
+        try (ClientSession session = client.startSession();
+             FailPoint ignored = FailPoint.enable(failPointDocument, getPrimary())) {
+            long startNanos = System.nanoTime();
+            session.withTransaction(() -> collection.insertOne(session, Document.parse("{}")));
+            withBackoffTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);
+        } finally {
+            ExponentialBackoff.clearTestJitterSupplier();
+        }
+
+        long expectedWithBackoffTime = noBackoffTime + 1800;
+        long actualDifference = Math.abs(withBackoffTime - expectedWithBackoffTime);
+
+        assertTrue(actualDifference < 1000, String.format("Expected withBackoffTime to be ~%d ms (noBackoffTime %d ms + 1800 ms), but"
+                + " got %d ms. Difference: %d ms (tolerance: 1000 ms per spec)", expectedWithBackoffTime, noBackoffTime, withBackoffTime,
+                actualDifference));
+    }
+
+    /**
+     * This test is not from the specification.
+     */
+    @Test
+    public void testExponentialBackoffOnTransientError() throws InterruptedException {
+        BsonDocument failPointDocument = BsonDocument.parse("{'configureFailPoint': 'failCommand', 'mode': {'times': 3}, "
+                + "'data': {'failCommands': ['insert'], 'errorCode': 112, "
+                + "'errorLabels': ['TransientTransactionError']}}");
+
+        try (ClientSession session = client.startSession();
+             FailPoint ignored = FailPoint.enable(failPointDocument, getPrimary())) {
+            AtomicInteger attemptsCount = new AtomicInteger(0);
+
+            session.withTransaction(() -> {
+                attemptsCount.incrementAndGet(); // Count the attempt before the operation that might fail
+                return collection.insertOne(session, Document.parse("{}"));
+            });
+
+            assertEquals(4, attemptsCount.get(), "Expected 1 initial attempt + 3 retries");
+        }
+    }
+
     private boolean canRunTests() {
         return isSharded() || isDiscoverableReplicaSet();
     }