From d34cc676ff839b69e804c438b9855df73bc15a94 Mon Sep 17 00:00:00 2001 From: agrawal-siddharth Date: Sun, 7 Jun 2026 00:09:25 +0000 Subject: [PATCH] feat: scale up connection worker pool based on latency --- .../bigquery/storage/v1/ConnectionWorker.java | 57 +++++++-- .../storage/v1/ConnectionWorkerPoolTest.java | 50 ++++++++ .../storage/v1/ConnectionWorkerTest.java | 110 ++++++++++++++++-- 3 files changed, 195 insertions(+), 22 deletions(-) diff --git a/java-bigquerystorage/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorker.java b/java-bigquerystorage/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorker.java index 215176e7b46f..14eef0fda569 100644 --- a/java-bigquerystorage/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorker.java +++ b/java-bigquerystorage/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorker.java @@ -1881,12 +1881,25 @@ void setRequestSendQueueTime() { /** Returns the current workload of this worker. */ public Load getLoad() { - return Load.create( - inflightBytes, - inflightRequests, - destinationSet.size(), - maxInflightBytes, - maxInflightRequests); + this.lock.lock(); + try { + Duration timeSinceLastCallback = Duration.ZERO; + if (!inflightRequestQueue.isEmpty()) { + AppendRequestAndResponse head = inflightRequestQueue.peekFirst(); + if (head != null && head.requestSendTimeStamp != null) { + timeSinceLastCallback = Duration.between(head.requestSendTimeStamp, Instant.now()); + } + } + return Load.create( + timeSinceLastCallback, + inflightBytes, + inflightRequests, + destinationSet.size(), + maxInflightBytes, + maxInflightRequests); + } finally { + this.lock.unlock(); + } } /** @@ -1896,11 +1909,15 @@ public Load getLoad() { @AutoValue public abstract static class Load { - // Consider the load on this worker to be overwhelmed when above some percentage of - // in-flight bytes or in-flight requests count. + // Consider the load on this worker to be overwhelmed when above some inflight latency or + // percentage of in-flight bytes or in-flight requests count. + private static Duration overwhelmedTimeSinceLastCallback = Duration.ofSeconds(3); private static double overwhelmedInflightCount = 0.2; private static double overwhelmedInflightBytes = 0.2; + // Time we have spent waiting for a response in the worker. + abstract Duration timeSinceLastCallback(); + // Number of in-flight requests bytes in the worker. abstract long inFlightRequestsBytes(); @@ -1917,12 +1934,14 @@ public abstract static class Load { abstract long maxInflightCount(); static Load create( + Duration timeSinceLastCallback, long inFlightRequestsBytes, long inFlightRequestsCount, long destinationCount, long maxInflightBytes, long maxInflightCount) { return new AutoValue_ConnectionWorker_Load( + timeSinceLastCallback, inFlightRequestsBytes, inFlightRequestsCount, destinationCount, @@ -1934,20 +1953,29 @@ boolean isOverwhelmed() { // Consider only in flight bytes and count for now, as by experiment those two are the most // efficient and has great simplity. return inFlightRequestsCount() > overwhelmedInflightCount * maxInflightCount() - || inFlightRequestsBytes() > overwhelmedInflightBytes * maxInflightBytes(); + || inFlightRequestsBytes() > overwhelmedInflightBytes * maxInflightBytes() + || timeSinceLastCallback().compareTo(overwhelmedTimeSinceLastCallback) > 0; } - // Compares two different load. First compare in flight request bytes split by size 1024 bucket. + // Compares two different load. First compare the timeSinceLastCallback bucketed into 1 second + // intervals. + // Then compare in flight request bytes split by size 1024 bucket. // Then compare the inflight requests count. // Then compare destination count of the two connections. public static final Comparator LOAD_COMPARATOR = - Comparator.comparing((Load key) -> (int) (key.inFlightRequestsBytes() / 1024)) + Comparator.comparing((Load key) -> (int) (key.timeSinceLastCallback().toMillis() / 1000)) + .thenComparing((Load key) -> (int) (key.inFlightRequestsBytes() / 1024)) .thenComparing((Load key) -> (int) (key.inFlightRequestsCount() / 100)) .thenComparing(Load::destinationCount); // Compares two different load without bucket, used in smaller scale unit testing. + // First compare the timeSinceLastCallback. + // Then compare in flight request bytes. + // Then compare the inflight requests count. + // Then compare destination count of the two connections. public static final Comparator TEST_LOAD_COMPARATOR = - Comparator.comparing((Load key) -> (int) key.inFlightRequestsBytes()) + Comparator.comparing(Load::timeSinceLastCallback) + .thenComparing((Load key) -> (int) key.inFlightRequestsBytes()) .thenComparing((Load key) -> (int) key.inFlightRequestsCount()) .thenComparing(Load::destinationCount); @@ -1960,6 +1988,11 @@ public static void setOverwhelmedBytesThreshold(double newThreshold) { public static void setOverwhelmedCountsThreshold(double newThreshold) { overwhelmedInflightCount = newThreshold; } + + @VisibleForTesting + public static void setOverwhelmedTimeSinceLastCallbackThreshold(Duration newThreshold) { + overwhelmedTimeSinceLastCallback = newThreshold; + } } @VisibleForTesting diff --git a/java-bigquerystorage/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPoolTest.java b/java-bigquerystorage/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPoolTest.java index 51fea1232b11..4409ae18c1e9 100644 --- a/java-bigquerystorage/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPoolTest.java +++ b/java-bigquerystorage/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPoolTest.java @@ -89,6 +89,7 @@ void setUp() throws Exception { .build(); ConnectionWorker.Load.setOverwhelmedCountsThreshold(0.5); ConnectionWorker.Load.setOverwhelmedBytesThreshold(0.6); + ConnectionWorker.Load.setOverwhelmedTimeSinceLastCallbackThreshold(Duration.ofSeconds(3)); } @Test @@ -555,6 +556,55 @@ private ProtoRows createProtoRows(String[] messages) { return rowsBuilder.build(); } + @Test + void testSingleTableConnections_overwhelmed_timeSinceLastCallback() throws Exception { + // Set count/bytes thresholds to be very high so they don't trigger. + ConnectionWorker.Load.setOverwhelmedCountsThreshold(0.9); + ConnectionWorker.Load.setOverwhelmedBytesThreshold(0.9); + // Set time threshold to 100ms. + ConnectionWorker.Load.setOverwhelmedTimeSinceLastCallbackThreshold(Duration.ofMillis(100)); + + // We use a pool with max 8 connections. + ConnectionWorkerPool.setOptions( + Settings.builder() + .setMinConnectionsPerRegion(1) // Start with 1 connection to make scaling obvious. + .setMaxConnectionsPerRegion(8) + .build()); + + // We set maxRequests to a large value (100) so it's not overwhelmed by count (threshold 90). + ConnectionWorkerPool connectionWorkerPool = + createConnectionWorkerPool( + /* maxRequests= */ 100, /* maxBytes= */ 1000000, java.time.Duration.ofSeconds(5)); + + // Stuck requests for 500ms (larger than 100ms threshold). + testBigQueryWrite.setResponseSleep(Duration.ofMillis(500)); + + // Send 1 request. It will go to Connection 1. + testBigQueryWrite.addResponse(createAppendResponse(0)); + StreamWriter writer = getTestStreamWriter(TEST_STREAM_1); + + ApiFuture future1 = + sendFooStringTestMessage(writer, connectionWorkerPool, new String[] {"0"}, 0); + + // Wait 200ms. Request 1 is still in flight (needs 500ms). + // Connection 1 timeSinceLastCallback should be ~200ms > 100ms. + // So Connection 1 is now overwhelmed. + Thread.sleep(200); + + // Send Request 2. Since Connection 1 is overwhelmed, it should scale up and create Connection + // 2. + testBigQueryWrite.addResponse(createAppendResponse(1)); + ApiFuture future2 = + sendFooStringTestMessage(writer, connectionWorkerPool, new String[] {"1"}, 1); + + // Wait for both to finish. + future1.get(); + future2.get(); + + // Verify that we created 2 connections. + assertThat(connectionWorkerPool.getCreateConnectionCount()).isEqualTo(2); + } + ConnectionWorkerPool createConnectionWorkerPool( long maxRequests, long maxBytes, java.time.Duration maxRetryDuration) { ConnectionWorkerPool.enableTestingLogic(); diff --git a/java-bigquerystorage/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerTest.java b/java-bigquerystorage/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerTest.java index 44bb25105d12..d40026a1e8d1 100644 --- a/java-bigquerystorage/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerTest.java +++ b/java-bigquerystorage/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerTest.java @@ -94,6 +94,9 @@ void setUp() throws Exception { testBigQueryWrite = new FakeBigQueryWrite(); ConnectionWorker.setMaxInflightQueueWaitTime(300000); ConnectionWorker.setMaxInflightRequestWaitTime(Duration.ofMinutes(10)); + ConnectionWorker.Load.setOverwhelmedCountsThreshold(0.2); + ConnectionWorker.Load.setOverwhelmedBytesThreshold(0.2); + ConnectionWorker.Load.setOverwhelmedTimeSinceLastCallbackThreshold(Duration.ofSeconds(3)); serviceHelper = new MockServiceHelper( UUID.randomUUID().toString(), Arrays.asList(testBigQueryWrite)); @@ -865,29 +868,116 @@ void testLoadCompare_compareLoad() { // In flight bytes bucket is split as per 1024 requests per bucket. // When in flight bytes is in lower bucket, even destination count is higher and request count // is higher, the load is still smaller. - Load load1 = ConnectionWorker.Load.create(1000, 2000, 100, 1000, 10); - Load load2 = ConnectionWorker.Load.create(2000, 1000, 10, 1000, 10); + Load load1 = ConnectionWorker.Load.create(Duration.ZERO, 1000, 2000, 100, 1000, 10); + Load load2 = ConnectionWorker.Load.create(Duration.ZERO, 2000, 1000, 10, 1000, 10); assertThat(Load.LOAD_COMPARATOR.compare(load1, load2)).isLessThan(0); // In flight bytes in the same bucke of request bytes will compare request count. - Load load3 = ConnectionWorker.Load.create(1, 300, 10, 0, 10); - Load load4 = ConnectionWorker.Load.create(10, 1, 10, 0, 10); + Load load3 = ConnectionWorker.Load.create(Duration.ZERO, 1, 300, 10, 0, 10); + Load load4 = ConnectionWorker.Load.create(Duration.ZERO, 10, 1, 10, 0, 10); assertThat(Load.LOAD_COMPARATOR.compare(load3, load4)).isGreaterThan(0); // In flight request and bytes in the same bucket will compare the destination count. - Load load5 = ConnectionWorker.Load.create(200, 1, 10, 1000, 10); - Load load6 = ConnectionWorker.Load.create(100, 10, 10, 1000, 10); + Load load5 = ConnectionWorker.Load.create(Duration.ZERO, 200, 1, 10, 1000, 10); + Load load6 = ConnectionWorker.Load.create(Duration.ZERO, 100, 10, 10, 1000, 10); assertThat(Load.LOAD_COMPARATOR.compare(load5, load6) == 0).isTrue(); + + // timeSinceLastCallback has the highest priority. + // load7 has higher timeSinceLastCallback (2s -> bucket 2) but lower other parameters. + // load8 has lower timeSinceLastCallback (0s -> bucket 0) but higher other parameters. + Load load7 = ConnectionWorker.Load.create(Duration.ofSeconds(2), 0, 0, 0, 10, 10); + Load load8 = ConnectionWorker.Load.create(Duration.ZERO, 10000, 10000, 100, 10, 10); + assertThat(Load.LOAD_COMPARATOR.compare(load7, load8)).isGreaterThan(0); } @Test void testLoadIsOverWhelmed() { - // Only in flight request is considered in current overwhelmed calculation. - Load load1 = ConnectionWorker.Load.create(60, 10, 100, 90, 100); + // In-flight requests, bytes, and timeSinceLastCallback are considered in overwhelmed + // calculation. + + // Overwhelmed by request count + Load load1 = ConnectionWorker.Load.create(Duration.ZERO, 60, 10, 100, 90, 100); assertThat(load1.isOverwhelmed()).isTrue(); - Load load2 = ConnectionWorker.Load.create(1, 1, 100, 100, 100); - assertThat(load2.isOverwhelmed()).isFalse(); + // Not overwhelmed + Load load2 = ConnectionWorker.Load.create(Duration.ZERO, 1, 1, 100, 100, 100); + assertFalse(load2.isOverwhelmed()); + + // Under threshold (3s) for timeSinceLastCallback + Load load3 = ConnectionWorker.Load.create(Duration.ofSeconds(2), 0, 0, 0, 100, 100); + assertFalse(load3.isOverwhelmed()); + + // Over threshold (3s) for timeSinceLastCallback + Load load4 = ConnectionWorker.Load.create(Duration.ofSeconds(4), 0, 0, 0, 100, 100); + assertTrue(load4.isOverwhelmed()); + } + + @Test + void testGetLoad_timeSinceLastCallback() throws Exception { + ProtoSchema schema1 = createProtoSchema("foo"); + StreamWriter sw1 = + StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema1).build(); + try (ConnectionWorker connectionWorker = + new ConnectionWorker( + TEST_STREAM_1, + null, + createProtoSchema("foo"), + 10, + 100000, + Duration.ofSeconds(100), + FlowController.LimitExceededBehavior.Block, + TEST_TRACE_ID, + null, + client.getSettings(), + retrySettings, + /* enableRequestProfiler= */ false, + /* enableOpenTelemetry= */ false, + /*isMultiplexing*/ false)) { + + // Initially empty, should be zero. + assertThat(connectionWorker.getLoad().timeSinceLastCallback()).isEqualTo(Duration.ZERO); + + // Keep response in flight + testBigQueryWrite.setResponseSleep(java.time.Duration.ofSeconds(5)); + + // Send a message + ApiFuture future = + sendTestMessage(connectionWorker, sw1, createFooProtoRows(new String[] {"hello"}), 0); + + // Wait a bit to ensure it is sent and in flight queue + Thread.sleep(500); + + Load load = connectionWorker.getLoad(); + assertThat(load.timeSinceLastCallback()).isGreaterThan(Duration.ZERO); + assertThat(load.timeSinceLastCallback()) + .isLessThan(Duration.ofSeconds(2)); // Should be around 500ms + } + } + + @Test + void testLoadCompare_timeSinceLastCallback() { + // Same bytes, same count, same destination, different timeSinceLastCallback + // Bucketed by 1 second (1000ms). + + // 100ms and 200ms are in the same bucket (0). + Load load1 = ConnectionWorker.Load.create(Duration.ofMillis(100), 0, 0, 0, 0, 0); + Load load2 = ConnectionWorker.Load.create(Duration.ofMillis(200), 0, 0, 0, 0, 0); + assertThat(Load.LOAD_COMPARATOR.compare(load1, load2)).isEqualTo(0); + + // 100ms and 1200ms are in different buckets (0 vs 1). + Load load3 = ConnectionWorker.Load.create(Duration.ofMillis(1200), 0, 0, 0, 0, 0); + assertThat(Load.LOAD_COMPARATOR.compare(load1, load3)).isLessThan(0); + assertThat(Load.LOAD_COMPARATOR.compare(load3, load1)).isGreaterThan(0); + } + + @Test + void testTestLoadCompare_timeSinceLastCallback() { + // TEST_LOAD_COMPARATOR compares timeSinceLastCallback unbucketed. + // 1s and 2s should be different. + Load load1 = ConnectionWorker.Load.create(Duration.ofSeconds(1), 0, 0, 0, 0, 0); + Load load2 = ConnectionWorker.Load.create(Duration.ofSeconds(2), 0, 0, 0, 0, 0); + assertThat(Load.TEST_LOAD_COMPARATOR.compare(load1, load2)).isLessThan(0); + assertThat(Load.TEST_LOAD_COMPARATOR.compare(load2, load1)).isGreaterThan(0); } @Test