diff --git a/.env.example b/.env.example new file mode 100644 index 000000000..83ca4e9c7 --- /dev/null +++ b/.env.example @@ -0,0 +1,5 @@ +DB_CONNECTION_STRING=jdbc:aws-wrapper:postgresql://localhost:5432/dbname +CACHE_RW_SERVER_ADDR=localhost:6379 +CACHE_RO_SERVER_ADDR=localhost:6380 +DB_USERNAME=postgres +DB_PASSWORD=admin diff --git a/benchmarks/build.gradle.kts b/benchmarks/build.gradle.kts index 359765fe5..09c2809ae 100644 --- a/benchmarks/build.gradle.kts +++ b/benchmarks/build.gradle.kts @@ -25,6 +25,10 @@ dependencies { implementation("org.mariadb.jdbc:mariadb-java-client:3.5.6") implementation("com.zaxxer:HikariCP:4.0.3") implementation("org.checkerframework:checker-qual:3.49.5") + implementation("io.lettuce:lettuce-core:6.6.0.RELEASE") + implementation("org.apache.commons:commons-pool2:2.11.1") + annotationProcessor("org.openjdk.jmh:jmh-core:1.36") + jmhAnnotationProcessor("org.openjdk.jmh:jmh-generator-annprocess:1.36") testImplementation("org.junit.jupiter:junit-jupiter-api:5.12.2") testImplementation("org.mockito:mockito-inline:4.11.0") // 4.11.0 is the last version compatible with Java 8 diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PgCacheBenchmarks.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PgCacheBenchmarks.java new file mode 100644 index 000000000..8c18d4c44 --- /dev/null +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PgCacheBenchmarks.java @@ -0,0 +1,144 @@ +package software.amazon.jdbc.benchmarks; + +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; +import java.sql.*; +import java.util.*; +import java.util.concurrent.TimeUnit; + +import org.openjdk.jmh.profile.GCProfiler; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; + +/** + * Performance benchmark program against PG. + + * This program runs JMH benchmarks that measure the performance of the remote cache plugin against + * a remote PG database and a remote cache server, for both indexed and non-indexed queries.
+ * + * The database table schema is as follows: + * + * postgres=# CREATE TABLE test (id SERIAL PRIMARY KEY, int_col INTEGER, varchar_col varchar(50) NOT NULL, text_col TEXT, + * num_col DOUBLE PRECISION, date_col date, time_col TIME WITHOUT TIME ZONE, time_tz TIME WITH TIME ZONE, + * ts_col TIMESTAMP WITHOUT TIME ZONE, ts_tz TIMESTAMP WITH TIME ZONE, description TEXT); + * CREATE TABLE + * postgres=# select * from test; + * id | int_col | varchar_col | text_col | num_col | date_col | time_col | time_tz | ts_col | ts_tz | description + * ----+---------+-------------+----------+---------+----------+----------+---------+--------+-------+-------------- + * (0 rows) + * + */ +@State(Scope.Thread) +@Fork(1) +@Warmup(iterations = 1) +@Measurement(iterations = 60, time = 1) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +public class PgCacheBenchmarks { + private static final String DB_CONNECTION_STRING = "jdbc:aws-wrapper:postgresql://db-0.XYZ.us-east-2.rds.amazonaws.com:5432/postgres"; + private static final String CACHE_RW_SERVER_ADDR = "cache-0.XYZ.us-east-2.rds.amazonaws.com:6379"; + private static final String CACHE_RO_SERVER_ADDR = "cache-0.XYZ.us-east-2.rds.amazonaws.com:6380"; + + private Connection connection; + private int counter; + long startTime; + + public static void main(String[] args) throws RunnerException { + Options opt = new OptionsBuilder() + .include(PgCacheBenchmarks.class.getSimpleName()) + .addProfiler(GCProfiler.class) + .detectJvmArgs() + .build(); + + new Runner(opt).run(); + } + + @Setup(Level.Trial) + public void setup() throws SQLException { + try { + software.amazon.jdbc.Driver.register(); + } catch (IllegalStateException e) { + System.out.println("exception during register() is " + e.getMessage()); + } + Properties properties = new Properties(); + properties.setProperty("wrapperPlugins", "dataRemoteCache"); + properties.setProperty("cacheEndpointAddrRw", CACHE_RW_SERVER_ADDR); + properties.setProperty("cacheEndpointAddrRo", CACHE_RO_SERVER_ADDR); + properties.setProperty("wrapperLogUnclosedConnections", "true"); + counter = 0; + connection = DriverManager.getConnection(DB_CONNECTION_STRING, properties); + startTime = System.currentTimeMillis(); + } + + @TearDown(Level.Trial) + public void tearDown() throws SQLException { + connection.close(); + } + + // Code to warm up the data in the table + public void warmUpDataSet() throws SQLException { + String desc_1KB = "mP48pHrR5vreBo3N6ecmlDgvfEAz0kQEOUQ89U3Rh05BTG9LhB8R0HBFBp5RIqc8vVcrphu89kW1OE2c2xApwpczFMdDAuk2SxOl9OrLvfk9zGYrdfzedcepT8LVeE6NTtYDeP3yo6UFC6AiOeqRBY5NEaNcZ8fuoXVpqOrqAhz910v5XrFxeXUyPDFxuaKFLaHfEFq7BRasUc9nfhP8gblKAGfEEmgYBpUKio27Rfo0xnavfVJQkAA2kME2PT4qZRSqeDkLmn7VBAzT9ghHqe9D4kQLQKjIyIPKqYoS8kW3ShW44VqYENwPSRAXw7UqOJqlKJ4pnmx4sPZO2kI4NYOl1JZXNlbGaSzJR0cOloKiY0z2OmUNvmD0Wju1DC9TT4OY6a6DOfFvk265BfDVxT6ufN68YG9sZuVsl7jq8SZSJg3x2cqlJuAtdSTIoKmJT1a6cEXxVusmdO27kRRp1BfWR4gz4w9HawYf9nBQOq76ObctlNvj0fYUUG3I49s3iP33CL8qZjj9RnyNUus6ieiZgta6L3mZuMRYOgCLyJrAKUYEL9KND7qirCPzVgmJHWIOnVewu8mldYFhroL89yvV3bZx4MGeyPU4KvbCsRgdORCTN0XhuLYUdiehHXnDBfuZ5yyR0saWLh8gjkLV5GkxTeKpOhpoK1o1cMiCDPYqTa64g5JundlW707c9zxc3Xnf2pW7E74YJl5oBu5vWEyPqXtYOtZOjOIRxxDY8QpoW8mpbQXxgB8DjkZZMiUCe0qHZYxvktVZJmHoaYBwpYpXVTZCfq9WajmkIOdIad1VnH5HpaECLRs6loa259yH8qesak2feDiKjfb8p3uj3s7WZUvPJwAWX9PIW1p7x6OiszXQCntOFRC3bQFNz1c98wlCBJnBSxbbYhU057TDNnoaib1h9bH7LAcqD1caE5KwLMAc5HqugkkRzT5NszkdJcpF0SxakdrAQLOKS6sNwDUzBJA76F775vmaqe3XIYecPmGtfoAKMychfEI4vfNr"; + for (int i = 0; i < 400000; i++) { + Statement stmt = 
connection.createStatement(); + String description = "description " + i; + String text = "here is my text data " + i; + String query = "insert into test values (" + i + ", " + i * 10 + ", '" + description + "', '" + text + "', " + (i * 100 + 0.1234) + ", '2024-01-10', '10:00:00', '10:00:00-07', '2025-07-15 10:00:00', '2025-07-15 10:00:00-07'" + ", '" + desc_1KB + "');"; + int rs = stmt.executeUpdate(query); + assert rs == 1; + stmt.close(); + } + } + + private void validateResultSet(ResultSet rs, Blackhole b) throws SQLException { + while (rs.next()) { + b.consume(rs.getInt(1)); + b.consume(rs.getInt(2)); + b.consume(rs.getString(3)); + b.consume(rs.getString(4)); + b.consume(rs.getDouble(5)); + b.consume(rs.getDate(6)); + b.consume(rs.getTime(7)); + b.consume(rs.getTime(8)); + b.consume(rs.getTimestamp(9)); + b.consume(rs.getTimestamp(10)); + b.consume(rs.wasNull()); + } + } + + @Benchmark + public void runBenchmarkPrimaryKeyLookupNoCaching(Blackhole b) throws SQLException { + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM test where id = " + counter)) { + validateResultSet(rs, b); + } + counter++; + } + + @Benchmark + public void runBenchmarkNonIndexedLookupNoCaching(Blackhole b) throws SQLException { + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM test where int_col = " + counter * 10)) { + validateResultSet(rs, b); + } + counter++; + } + + @Benchmark + public void runBenchmarkPrimaryKeyLookupWithCaching(Blackhole b) throws SQLException { + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("/*+ CACHE_PARAM(ttl=172800s) */ SELECT * FROM test where id = " + counter)) { + validateResultSet(rs, b); + } + counter++; + } + + @Benchmark + public void runBenchmarkNonIndexedLookupWithCaching(Blackhole b) throws SQLException { + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("/*+ CACHE_PARAM(ttl=172800s) */ SELECT * FROM test where int_col = " + counter * 10)) { + validateResultSet(rs, b); + } + counter++; + } +} diff --git a/docs/using-the-jdbc-driver/UsingTheJdbcDriver.md b/docs/using-the-jdbc-driver/UsingTheJdbcDriver.md index dea51924f..b8069fc24 100644 --- a/docs/using-the-jdbc-driver/UsingTheJdbcDriver.md +++ b/docs/using-the-jdbc-driver/UsingTheJdbcDriver.md @@ -220,7 +220,7 @@ The AWS JDBC Driver has several built-in plugins that are available to use. Plea [^2]: Federated Identity and Okta rely on IAM. Due to [^1], RDS Multi-AZ Clusters are not supported. > [!NOTE]\ -> To see information logged by plugins such as `DataCacheConnectionPlugin` and `LogQueryConnectionPlugin`, see the [Logging](#logging) section. +> To see information logged by plugins such as `DataLocalCacheConnectionPlugin` and `LogQueryConnectionPlugin`, see the [Logging](#logging) section. In addition to the built-in plugins, you can also create custom plugins more suitable for your needs. For more information, see [Custom Plugins](../development-guide/LoadablePlugins.md#using-custom-plugins).
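The benchmarks and examples in this change all follow the same usage pattern: enable the `dataRemoteCache` plugin, point it at a cache endpoint pair, and opt individual queries into caching with a `CACHE_PARAM` hint. Below is a minimal, self-contained sketch of that pattern for trying the plugin outside of JMH. The connection string, endpoint addresses, and table are placeholders; the plugin name, the `cacheEndpointAddrRw`/`cacheEndpointAddrRo` properties, and the hint syntax are taken from the code in this PR.

```java
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Properties;

public class RemoteCacheQuickStart {
  public static void main(String[] args) throws SQLException {
    try {
      software.amazon.jdbc.Driver.register();
    } catch (IllegalStateException e) {
      // Driver was already registered; safe to ignore here.
    }

    Properties props = new Properties();
    props.setProperty("wrapperPlugins", "dataRemoteCache");     // enable the remote cache plugin
    props.setProperty("cacheEndpointAddrRw", "localhost:6379"); // read-write cache endpoint (placeholder)
    props.setProperty("cacheEndpointAddrRo", "localhost:6380"); // read-only cache endpoint (placeholder)

    // Only queries carrying the CACHE_PARAM hint are served from the cache;
    // the ttl value controls how long the cached result set stays valid.
    String sql = "/*+ CACHE_PARAM(ttl=300s) */ SELECT * FROM test WHERE id = 1";

    try (Connection conn = DriverManager.getConnection(
            "jdbc:aws-wrapper:postgresql://localhost:5432/postgres", props);
         Statement stmt = conn.createStatement();
         ResultSet rs = stmt.executeQuery(sql)) {
      while (rs.next()) {
        System.out.println(rs.getInt("id"));
      }
    }
  }
}
```

Queries issued without the hint bypass the cache entirely, which is what the `*NoCaching` benchmark variants above measure against the `*WithCaching` variants.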
diff --git a/examples/AWSDriverExample/build.gradle.kts b/examples/AWSDriverExample/build.gradle.kts index 08fee3f19..ae4b9ab61 100644 --- a/examples/AWSDriverExample/build.gradle.kts +++ b/examples/AWSDriverExample/build.gradle.kts @@ -29,6 +29,8 @@ dependencies { implementation("com.amazonaws:aws-xray-recorder-sdk-core:2.18.2") implementation("org.jsoup:jsoup:1.21.1") implementation("com.mchange:c3p0:0.11.0") + implementation("io.lettuce:lettuce-core:6.6.0.RELEASE") + implementation("org.apache.commons:commons-pool2:2.11.1") } tasks.withType { diff --git a/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java b/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java new file mode 100644 index 000000000..d20620138 --- /dev/null +++ b/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java @@ -0,0 +1,186 @@ +package software.amazon; + +import software.amazon.util.EnvLoader; +import java.sql.*; +import java.util.*; +import java.util.logging.Logger; + +public class DatabaseConnectionWithCacheExample { + + private static final EnvLoader env = new EnvLoader(); + + private static final String DB_CONNECTION_STRING = env.get("DB_CONNECTION_STRING"); + private static final String CACHE_RW_SERVER_ADDR = env.get("CACHE_RW_SERVER_ADDR"); + private static final String CACHE_RO_SERVER_ADDR = env.get("CACHE_RO_SERVER_ADDR"); + // If the cache server is authenticated with IAM + private static final String CACHE_NAME = env.get("CACHE_NAME"); + // Both IAM and traditional auth use the same CACHE_USERNAME + private static final String CACHE_USERNAME = env.get("CACHE_USERNAME"); // e.g., "iam-user-01" / "username" + private static final String CACHE_IAM_REGION = env.get("CACHE_IAM_REGION"); // e.g., "us-west-2" + private static final String CACHE_USE_SSL = env.get("CACHE_USE_SSL"); + // If the cache server is authenticated with traditional username password + // private static final String CACHE_PASSWORD = env.get("CACHE_PASSWORD"); + private static final String USERNAME = env.get("DB_USERNAME"); + private static final String PASSWORD = env.get("DB_PASSWORD"); + private static final int THREAD_COUNT = 8; // Use 8 threads + private static final long TEST_DURATION_MS = 16000; // Test duration of 16 seconds + private static final String CACHE_CONNECTION_TIMEOUT = env.get("CACHE_CONNECTION_TIMEOUT"); // Connection timeout in milliseconds + private static final String CACHE_CONNECTION_POOL_SIZE = env.get("CACHE_CONNECTION_POOL_SIZE"); // Connection pool size + // Failure handling configurations + private static final String FAIL_WHEN_CACHE_DOWN = env.get("FAIL_WHEN_CACHE_DOWN"); + private static final String CACHE_IN_FLIGHT_WRITE_SIZE_LIMIT = env.get("CACHE_IN_FLIGHT_WRITE_SIZE_LIMIT"); + private static final String CACHE_HEALTH_CHECK_IN_HEALTHY_STATE = env.get("CACHE_HEALTH_CHECK_IN_HEALTHY_STATE"); + + // If multiple cache endpoints are configured + private static final String CACHE_RW_SERVER_ADDR2 = env.get("CACHE_RW_SERVER_ADDR2"); + private static final String CACHE_RO_SERVER_ADDR2 = env.get("CACHE_RO_SERVER_ADDR2"); + private static final String CACHE_NAME2 = env.get("CACHE_NAME2"); + // Both IAM and traditional auth use the same CACHE_USERNAME + private static final String CACHE_USERNAME2 = env.get("CACHE_USERNAME2"); // e.g., "iam-user-01" / "username" + private static final String CACHE_IAM_REGION2 = env.get("CACHE_IAM_REGION2"); + + public static void main(String[] args) throws 
SQLException { + final Properties properties = new Properties(); + final Logger LOGGER = Logger.getLogger(DatabaseConnectionWithCacheExample.class.getName()); + + // Configuring connection properties for the underlying JDBC driver. + properties.setProperty("user", USERNAME); + properties.setProperty("password", PASSWORD); + + // Configuring connection properties for the JDBC Wrapper. + properties.setProperty("wrapperPlugins", "dataRemoteCache"); + properties.setProperty("cacheEndpointAddrRw", CACHE_RW_SERVER_ADDR); + properties.setProperty("cacheEndpointAddrRo", CACHE_RO_SERVER_ADDR); + // If the cache server is authenticated with IAM + properties.setProperty("cacheName", CACHE_NAME); + properties.setProperty("cacheUsername", CACHE_USERNAME); + properties.setProperty("cacheIamRegion", CACHE_IAM_REGION); + // If the cache server is authenticated with traditional username password + // properties.setProperty("cachePassword", CACHE_PASSWORD); + properties.setProperty("cacheUseSSL", CACHE_USE_SSL); // "true" or "false" + properties.setProperty("wrapperLogUnclosedConnections", "true"); + properties.setProperty("cacheConnectionTimeout", CACHE_CONNECTION_TIMEOUT); + properties.setProperty("cacheConnectionPoolSize", CACHE_CONNECTION_POOL_SIZE); + properties.setProperty("failWhenCacheDown", FAIL_WHEN_CACHE_DOWN); + properties.setProperty("cacheInFlightWriteSizeLimitBytes", CACHE_IN_FLIGHT_WRITE_SIZE_LIMIT); + properties.setProperty("cacheHealthCheckInHealthyState", CACHE_HEALTH_CHECK_IN_HEALTHY_STATE); + + String queryStr = "/*+ CACHE_PARAM(ttl=300s) */ select * from cinemas"; + + // Create threads for concurrent connection testing + Thread[] threads = new Thread[THREAD_COUNT]; + for (int t = 0; t < THREAD_COUNT; t++) { + // Each thread uses a single connection for multiple queries + threads[t] = new Thread(() -> { + try (Connection conn = DriverManager.getConnection(DB_CONNECTION_STRING, properties)) { + long endTime = System.currentTimeMillis() + TEST_DURATION_MS; + try (Statement stmt = conn.createStatement()) { + while (System.currentTimeMillis() < endTime) { + try (ResultSet rs = stmt.executeQuery(queryStr)) { + System.out.println("Executed the SQL query with result set: " + rs); + } + } + } + } catch (Exception e) { + LOGGER.warning("Error: " + e.getMessage()); + } + }); + threads[t].start(); + } + // Wait for all threads to complete + for (Thread thread : threads) { + try { + thread.join(); + } catch (InterruptedException e) { + LOGGER.warning("Thread interrupted: " + e.getMessage()); + } + } + + // Multi cache endpoint example. + runMultiEndPointExample(); + } + + /* + * Multi Cache Endpoint Example + * This example demonstrates how to use multiple cache endpoints. + * It creates THREAD_COUNT threads that alternate between two cache endpoints. + * Each endpoint is configured in its own Properties object. 
+ * */ + public static void runMultiEndPointExample() throws SQLException { + final Logger LOGGER = Logger.getLogger(DatabaseConnectionWithCacheExample.class.getName()); + String queryStr = "/*+ CACHE_PARAM(ttl=300s) */ select * from cinemas"; + + Properties properties1 = new Properties(); + properties1.setProperty("user", USERNAME); + properties1.setProperty("password", PASSWORD); + properties1.setProperty("wrapperPlugins", "dataRemoteCache"); + properties1.setProperty("cacheEndpointAddrRw", CACHE_RW_SERVER_ADDR); + properties1.setProperty("cacheEndpointAddrRo", CACHE_RO_SERVER_ADDR); + properties1.setProperty("cacheUseSSL", CACHE_USE_SSL); + properties1.setProperty("wrapperLogUnclosedConnections", "true"); + properties1.setProperty("cacheConnectionTimeout", CACHE_CONNECTION_TIMEOUT); + properties1.setProperty("cacheConnectionPoolSize", CACHE_CONNECTION_POOL_SIZE); + // If the cache server is authenticated with IAM + properties1.setProperty("cacheName", CACHE_NAME); + properties1.setProperty("cacheUsername", CACHE_USERNAME); + properties1.setProperty("cacheIamRegion", CACHE_IAM_REGION); + // If the cache server is authenticated with traditional username password + // properties1.setProperty("cachePassword", CACHE_PASSWORD); + properties1.setProperty("failWhenCacheDown", FAIL_WHEN_CACHE_DOWN); + properties1.setProperty("cacheInFlightWriteSizeLimitBytes", CACHE_IN_FLIGHT_WRITE_SIZE_LIMIT); + properties1.setProperty("cacheHealthCheckInHealthyState", CACHE_HEALTH_CHECK_IN_HEALTHY_STATE); + + Properties properties2 = new Properties(); + properties2.setProperty("user", USERNAME); + properties2.setProperty("password", PASSWORD); + properties2.setProperty("wrapperPlugins", "dataRemoteCache"); + properties2.setProperty("cacheEndpointAddrRw", CACHE_RW_SERVER_ADDR2); + properties2.setProperty("cacheEndpointAddrRo", CACHE_RO_SERVER_ADDR2); + properties2.setProperty("cacheUseSSL", CACHE_USE_SSL); + properties2.setProperty("wrapperLogUnclosedConnections", "true"); + properties2.setProperty("cacheConnectionTimeout", CACHE_CONNECTION_TIMEOUT); + properties2.setProperty("cacheConnectionPoolSize", CACHE_CONNECTION_POOL_SIZE); + // If the cache server is authenticated with IAM + properties2.setProperty("cacheName", CACHE_NAME2); + properties2.setProperty("cacheUsername", CACHE_USERNAME2); + properties2.setProperty("cacheIamRegion", CACHE_IAM_REGION2); + // If the cache server is authenticated with traditional username password + // properties2.setProperty("cachePassword", CACHE_PASSWORD2); + properties2.setProperty("failWhenCacheDown", FAIL_WHEN_CACHE_DOWN); + properties2.setProperty("cacheInFlightWriteSizeLimitBytes", CACHE_IN_FLIGHT_WRITE_SIZE_LIMIT); + properties2.setProperty("cacheHealthCheckInHealthyState", CACHE_HEALTH_CHECK_IN_HEALTHY_STATE); + + // Create threads with different cache endpoints + Thread[] threads = new Thread[THREAD_COUNT]; + + for (int t = 0; t < THREAD_COUNT; t++) { + final Properties threadProps = (t % 2 == 0) ? 
properties1 : properties2; + final int threadNum = t + 1; + + threads[t] = new Thread(() -> { + try (Connection conn = DriverManager.getConnection(DB_CONNECTION_STRING, threadProps)) { + long endTime = System.currentTimeMillis() + TEST_DURATION_MS; + try (Statement stmt = conn.createStatement()) { + while (System.currentTimeMillis() < endTime) { + try (ResultSet rs = stmt.executeQuery(queryStr)) { + System.out.println("Thread " + threadNum + " executed SQL query with result set: " + rs); + } + } + } + } catch (Exception e) { + LOGGER.warning("Thread " + threadNum + " error: " + e.getMessage()); + } + }); + threads[t].start(); + } + // Wait for all threads to complete + for (Thread thread : threads) { + try { + thread.join(); + } catch (InterruptedException e) { + LOGGER.warning("Thread interrupted: " + e.getMessage()); + } + } + } +} diff --git a/examples/AWSDriverExample/src/main/java/software/amazon/util/EnvLoader.java b/examples/AWSDriverExample/src/main/java/software/amazon/util/EnvLoader.java new file mode 100644 index 000000000..7b12d91f5 --- /dev/null +++ b/examples/AWSDriverExample/src/main/java/software/amazon/util/EnvLoader.java @@ -0,0 +1,83 @@ +package software.amazon.util; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + +/** + * A simple utility class to load environment variables from a .env file. + */ +public class EnvLoader { + private final Map<String, String> envVars = new HashMap<>(); + + /** + * Loads environment variables from a .env file in the current directory. + */ + public EnvLoader() { + this(Paths.get(".env")); + } + + /** + * Loads environment variables from the specified file path. + * + * @param envPath Path to the .env file + */ + public EnvLoader(Path envPath) { + if (Files.exists(envPath)) { + try (BufferedReader reader = new BufferedReader(new FileReader(envPath.toFile()))) { + String line; + while ((line = reader.readLine()) != null) { + parseLine(line); + } + } catch (IOException e) { + System.err.println("Error reading .env file: " + e.getMessage()); + } + } + } + + private void parseLine(String line) { + line = line.trim(); + // Skip empty lines and comments + if (line.isEmpty() || line.startsWith("#")) { + return; + } + + // Split on the first equals sign + int delimiterPos = line.indexOf('='); + if (delimiterPos > 0) { + String key = line.substring(0, delimiterPos).trim(); + String value = line.substring(delimiterPos + 1).trim(); + + // Remove quotes if present + if ((value.startsWith("\"") && value.endsWith("\"")) || + (value.startsWith("'") && value.endsWith("'"))) { + value = value.substring(1, value.length() - 1); + } + + envVars.put(key, value); + } + } + + /** + * Gets the value of an environment variable. 
+ * + * @param key The name of the environment variable + * @return The value of the environment variable, or null if not found + */ + public String get(String key) { + // First check the loaded .env file + String value = envVars.get(key); + + // If not found, check system environment variables + if (value == null) { + value = System.getenv(key); + } + + return value; + } +} diff --git a/examples/AWSDriverExample/src/main/resources/logback.xml b/examples/AWSDriverExample/src/main/resources/logback.xml new file mode 100644 index 000000000..e03eaf554 --- /dev/null +++ b/examples/AWSDriverExample/src/main/resources/logback.xml @@ -0,0 +1,6 @@ + + + + + + diff --git a/wrapper/build.gradle.kts b/wrapper/build.gradle.kts index 0e6fb29f1..7de1e94d1 100644 --- a/wrapper/build.gradle.kts +++ b/wrapper/build.gradle.kts @@ -44,8 +44,10 @@ dependencies { optionalImplementation("com.mchange:c3p0:0.11.0") optionalImplementation("org.apache.httpcomponents:httpclient:4.5.14") optionalImplementation("com.fasterxml.jackson.core:jackson-databind:2.19.0") + optionalImplementation("org.apache.commons:commons-pool2:2.11.1") optionalImplementation("org.jsoup:jsoup:1.21.1") optionalImplementation("com.amazonaws:aws-xray-recorder-sdk-core:2.18.2") + optionalImplementation("io.lettuce:lettuce-core:6.6.0.RELEASE") optionalImplementation("io.opentelemetry:opentelemetry-api:1.52.0") optionalImplementation("io.opentelemetry:opentelemetry-sdk:1.52.0") optionalImplementation("io.opentelemetry:opentelemetry-sdk-metrics:1.52.0") @@ -98,10 +100,12 @@ dependencies { testImplementation("org.slf4j:slf4j-simple:2.0.17") testImplementation("com.fasterxml.jackson.core:jackson-databind:2.19.0") testImplementation("com.amazonaws:aws-xray-recorder-sdk-core:2.18.2") + testImplementation("io.lettuce:lettuce-core:6.6.0.RELEASE") testImplementation("io.opentelemetry:opentelemetry-api:1.52.0") testImplementation("io.opentelemetry:opentelemetry-sdk:1.52.0") testImplementation("io.opentelemetry:opentelemetry-sdk-metrics:1.52.0") testImplementation("io.opentelemetry:opentelemetry-exporter-otlp:1.52.0") + testImplementation("org.apache.commons:commons-pool2:2.11.1") testImplementation("org.jsoup:jsoup:1.21.1") testImplementation("de.vandermeer:asciitable:0.3.2") testImplementation("org.hibernate:hibernate-core:5.6.15.Final") // the latest version compatible with Java 8 @@ -208,7 +212,7 @@ if (useJacoco) { "software/amazon/jdbc/wrapper/*", "software/amazon/jdbc/util/*", "software/amazon/jdbc/profile/*", - "software/amazon/jdbc/plugin/DataCacheConnectionPlugin*" + "software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPlugin*" ) } })) @@ -223,7 +227,7 @@ if (useJacoco) { "software/amazon/jdbc/wrapper/*", "software/amazon/jdbc/util/*", "software/amazon/jdbc/profile/*", - "software/amazon/jdbc/plugin/DataCacheConnectionPlugin*" + "software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPlugin*" ) } })) diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java index 952b00936..ef07e611c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java @@ -33,7 +33,8 @@ import software.amazon.jdbc.plugin.AuroraInitialConnectionStrategyPluginFactory; import software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPluginFactory; import software.amazon.jdbc.plugin.ConnectTimeConnectionPluginFactory; -import 
software.amazon.jdbc.plugin.DataCacheConnectionPluginFactory; +import software.amazon.jdbc.plugin.cache.DataLocalCacheConnectionPluginFactory; +import software.amazon.jdbc.plugin.cache.DataRemoteCachePluginFactory; import software.amazon.jdbc.plugin.DefaultConnectionPlugin; import software.amazon.jdbc.plugin.DriverMetaDataConnectionPluginFactory; import software.amazon.jdbc.plugin.ExecutionTimeConnectionPluginFactory; @@ -68,7 +69,8 @@ public class ConnectionPluginChainBuilder { { put("executionTime", new ExecutionTimeConnectionPluginFactory()); put("logQuery", new LogQueryConnectionPluginFactory()); - put("dataCache", new DataCacheConnectionPluginFactory()); + put("dataCache", new DataLocalCacheConnectionPluginFactory()); + put("dataRemoteCache", new DataRemoteCachePluginFactory()); put("customEndpoint", new CustomEndpointPluginFactory()); put("efm", new HostMonitoringConnectionPluginFactory()); put("efm2", new software.amazon.jdbc.plugin.efm2.HostMonitoringConnectionPluginFactory()); @@ -100,7 +102,8 @@ public class ConnectionPluginChainBuilder { new HashMap<Class<? extends ConnectionPluginFactory>, Integer>() { { put(DriverMetaDataConnectionPluginFactory.class, 100); - put(DataCacheConnectionPluginFactory.class, 200); + put(DataLocalCacheConnectionPluginFactory.class, 200); + put(DataRemoteCachePluginFactory.class, 250); put(CustomEndpointPluginFactory.class, 380); put(AuroraInitialConnectionStrategyPluginFactory.class, 390); put(AuroraConnectionTrackerPluginFactory.class, 400); diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java index 2697c5b03..b711f4617 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java @@ -33,7 +33,8 @@ import software.amazon.jdbc.plugin.AuroraConnectionTrackerPlugin; import software.amazon.jdbc.plugin.AuroraInitialConnectionStrategyPlugin; import software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPlugin; -import software.amazon.jdbc.plugin.DataCacheConnectionPlugin; +import software.amazon.jdbc.plugin.cache.DataLocalCacheConnectionPlugin; +import software.amazon.jdbc.plugin.cache.DataRemoteCachePlugin; import software.amazon.jdbc.plugin.DefaultConnectionPlugin; import software.amazon.jdbc.plugin.ExecutionTimeConnectionPlugin; import software.amazon.jdbc.plugin.LogQueryConnectionPlugin; @@ -72,7 +73,8 @@ public class ConnectionPluginManager implements CanReleaseResources, Wrapper { put(ExecutionTimeConnectionPlugin.class, "plugin:executionTime"); put(AuroraConnectionTrackerPlugin.class, "plugin:auroraConnectionTracker"); put(LogQueryConnectionPlugin.class, "plugin:logQuery"); - put(DataCacheConnectionPlugin.class, "plugin:dataCache"); + put(DataLocalCacheConnectionPlugin.class, "plugin:dataCache"); + put(DataRemoteCachePlugin.class, "plugin:dataRemoteCache"); put(HostMonitoringConnectionPlugin.class, "plugin:efm"); put(software.amazon.jdbc.plugin.efm2.HostMonitoringConnectionPlugin.class, "plugin:efm2"); put(FailoverConnectionPlugin.class, "plugin:failover"); diff --git a/wrapper/src/main/java/software/amazon/jdbc/Driver.java b/wrapper/src/main/java/software/amazon/jdbc/Driver.java index 7d59e83ff..60db2c8f4 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/Driver.java +++ b/wrapper/src/main/java/software/amazon/jdbc/Driver.java @@ -40,7 +40,7 @@ import software.amazon.jdbc.hostlistprovider.RdsHostListProvider; import 
software.amazon.jdbc.hostlistprovider.monitoring.MonitoringRdsHostListProvider; import software.amazon.jdbc.plugin.AwsSecretsManagerCacheHolder; -import software.amazon.jdbc.plugin.DataCacheConnectionPlugin; +import software.amazon.jdbc.plugin.cache.DataLocalCacheConnectionPlugin; import software.amazon.jdbc.plugin.OpenedConnectionTracker; import software.amazon.jdbc.plugin.customendpoint.CustomEndpointMonitorImpl; import software.amazon.jdbc.plugin.efm.HostMonitorThreadContainer; @@ -430,7 +430,7 @@ public static void clearCaches() { CustomEndpointMonitorImpl.clearCache(); OpenedConnectionTracker.clearCache(); AwsSecretsManagerCacheHolder.clearCache(); - DataCacheConnectionPlugin.clearCache(); + DataLocalCacheConnectionPlugin.clearCache(); FederatedAuthCacheHolder.clearCache(); OktaAuthCacheHolder.clearCache(); IamAuthCacheHolder.clearCache(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPlugin.java deleted file mode 100644 index a7cf53a9c..000000000 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPlugin.java +++ /dev/null @@ -1,1239 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package software.amazon.jdbc.plugin; - -import java.io.InputStream; -import java.io.Reader; -import java.math.BigDecimal; -import java.net.URL; -import java.sql.Array; -import java.sql.Blob; -import java.sql.Clob; -import java.sql.Date; -import java.sql.NClob; -import java.sql.Ref; -import java.sql.ResultSet; -import java.sql.ResultSetMetaData; -import java.sql.RowId; -import java.sql.SQLException; -import java.sql.SQLWarning; -import java.sql.SQLXML; -import java.sql.Statement; -import java.sql.Time; -import java.sql.Timestamp; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Calendar; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Properties; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.logging.Logger; -import software.amazon.jdbc.AwsWrapperProperty; -import software.amazon.jdbc.JdbcCallable; -import software.amazon.jdbc.JdbcMethod; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.util.Messages; -import software.amazon.jdbc.util.StringUtils; -import software.amazon.jdbc.util.telemetry.TelemetryCounter; -import software.amazon.jdbc.util.telemetry.TelemetryFactory; -import software.amazon.jdbc.util.telemetry.TelemetryGauge; - -public class DataCacheConnectionPlugin extends AbstractConnectionPlugin { - - private static final Logger LOGGER = Logger.getLogger(DataCacheConnectionPlugin.class.getName()); - - private static final Set<String> subscribedMethods = Collections.unmodifiableSet(new HashSet<>( - Arrays.asList( - JdbcMethod.STATEMENT_EXECUTEQUERY.methodName, - JdbcMethod.STATEMENT_EXECUTE.methodName, - JdbcMethod.PREPAREDSTATEMENT_EXECUTE.methodName, - JdbcMethod.PREPAREDSTATEMENT_EXECUTEQUERY.methodName, - JdbcMethod.CALLABLESTATEMENT_EXECUTE.methodName, - JdbcMethod.CALLABLESTATEMENT_EXECUTEQUERY.methodName - ))); - - public static final AwsWrapperProperty DATA_CACHE_TRIGGER_CONDITION = new AwsWrapperProperty( - "dataCacheTriggerCondition", "false", - "A regular expression that, if it's matched, allows the plugin to cache SQL results."); - - protected static final Map<String, ResultSet> dataCache = new ConcurrentHashMap<>(); - - protected final String dataCacheTriggerCondition; - - static { - PropertyDefinition.registerPluginProperties(DataCacheConnectionPlugin.class); - } - - private final TelemetryFactory telemetryFactory; - private final TelemetryCounter hitCounter; - private final TelemetryCounter missCounter; - private final TelemetryCounter totalCallsCounter; - private final TelemetryGauge cacheSizeGauge; - - public DataCacheConnectionPlugin(final PluginService pluginService, final Properties props) { - this.telemetryFactory = pluginService.getTelemetryFactory(); - this.dataCacheTriggerCondition = DATA_CACHE_TRIGGER_CONDITION.getString(props); - - this.hitCounter = telemetryFactory.createCounter("dataCache.cache.hit"); - this.missCounter = telemetryFactory.createCounter("dataCache.cache.miss"); - this.totalCallsCounter = telemetryFactory.createCounter("dataCache.cache.totalCalls"); - this.cacheSizeGauge = telemetryFactory.createGauge("dataCache.cache.size", () -> (long) dataCache.size()); - } - - public static void clearCache() { - dataCache.clear(); - } - - @Override - public Set<String> getSubscribedMethods() { - return subscribedMethods; - } - - @Override - public <T, E extends Exception> T execute( - final Class<T> resultClass, - final Class<E> exceptionClass, - final Object methodInvokeOn, - final String 
methodName, - final JdbcCallable<T, E> jdbcMethodFunc, - final Object[] jdbcMethodArgs) - throws E { - - if (StringUtils.isNullOrEmpty(this.dataCacheTriggerCondition) || resultClass != ResultSet.class) { - return jdbcMethodFunc.call(); - } - - if (this.totalCallsCounter != null) { - this.totalCallsCounter.inc(); - } - - ResultSet result; - boolean needToCache = false; - final String sql = getQuery(jdbcMethodArgs); - - if (!StringUtils.isNullOrEmpty(sql) && sql.matches(this.dataCacheTriggerCondition)) { - result = dataCache.get(sql); - if (result == null) { - needToCache = true; - if (this.missCounter != null) { - this.missCounter.inc(); - } - LOGGER.finest( - () -> Messages.get( - "DataCacheConnectionPlugin.queryResultsCached", - new Object[]{methodName, sql})); - } else { - if (this.hitCounter != null) { - this.hitCounter.inc(); - } - try { - result.beforeFirst(); - } catch (final SQLException ex) { - if (exceptionClass.isAssignableFrom(ex.getClass())) { - throw exceptionClass.cast(ex); - } - throw new RuntimeException(ex); - } - return resultClass.cast(result); - } - } - - result = (ResultSet) jdbcMethodFunc.call(); - - if (needToCache) { - final ResultSet cachedResultSet; - try { - cachedResultSet = new CachedResultSet(result); - dataCache.put(sql, cachedResultSet); - cachedResultSet.beforeFirst(); - return resultClass.cast(cachedResultSet); - } catch (final SQLException ex) { - // ignore exception - } - } - - return resultClass.cast(result); - } - - protected String getQuery(final Object[] jdbcMethodArgs) { - - // Get query from method argument - if (jdbcMethodArgs != null && jdbcMethodArgs.length > 0 && jdbcMethodArgs[0] != null) { - return jdbcMethodArgs[0].toString(); - } - return null; - } - - public static class CachedRow { - protected final HashMap<Integer, Object> columnByIndex = new HashMap<>(); - protected final HashMap<String, Object> columnByName = new HashMap<>(); - - public void put(final int columnIndex, final String columnName, final Object columnValue) { - columnByIndex.put(columnIndex, columnValue); - columnByName.put(columnName, columnValue); - } - - @SuppressWarnings("unused") - public Object get(final int columnIndex) { - return columnByIndex.get(columnIndex); - } - - @SuppressWarnings("unused") - public Object get(final String columnName) { - return columnByName.get(columnName); - } - } - - @SuppressWarnings({"RedundantThrows", "checkstyle:OverloadMethodsDeclarationOrder"}) - public static class CachedResultSet implements ResultSet { - - protected ArrayList<CachedRow> rows; - protected int currentRow; - - public CachedResultSet(final ResultSet resultSet) throws SQLException { - - final ResultSetMetaData md = resultSet.getMetaData(); - final int columns = md.getColumnCount(); - rows = new ArrayList<>(); - - while (resultSet.next()) { - final CachedRow row = new CachedRow(); - for (int i = 1; i <= columns; ++i) { - row.put(i, md.getColumnName(i), resultSet.getObject(i)); - } - rows.add(row); - } - currentRow = -1; - } - - @Override - public boolean next() throws SQLException { - if (rows.size() == 0 || isLast()) { - return false; - } - currentRow++; - return true; - } - - @Override - public void close() throws SQLException { - currentRow = rows.size() - 1; - } - - @Override - public boolean wasNull() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getString(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean getBoolean(final int columnIndex) throws SQLException { - throw new 
UnsupportedOperationException(); - } - - @Override - public byte getByte(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public short getShort(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getInt(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public long getLong(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public float getFloat(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public double getDouble(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @Deprecated - public BigDecimal getBigDecimal(final int columnIndex, final int scale) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public byte[] getBytes(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Time getTime(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream getAsciiStream(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @Deprecated - public InputStream getUnicodeStream(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream getBinaryStream(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getString(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean getBoolean(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public byte getByte(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public short getShort(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getInt(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public long getLong(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public float getFloat(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public double getDouble(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @Deprecated - public BigDecimal getBigDecimal(final String columnLabel, final int scale) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public byte[] getBytes(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Time getTime(final String 
columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream getAsciiStream(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @Deprecated - public InputStream getUnicodeStream(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream getBinaryStream(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public SQLWarning getWarnings() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void clearWarnings() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getCursorName() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public ResultSetMetaData getMetaData() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Object getObject(final int columnIndex) throws SQLException { - if (this.currentRow < 0 || this.currentRow >= this.rows.size()) { - return null; // out of boundaries - } - final CachedRow row = this.rows.get(this.currentRow); - if (!row.columnByIndex.containsKey(columnIndex)) { - return null; // column index out of boundaries - } - return row.columnByIndex.get(columnIndex); - } - - @Override - public Object getObject(final String columnLabel) throws SQLException { - if (this.currentRow < 0 || this.currentRow >= this.rows.size()) { - return null; // out of boundaries - } - final CachedRow row = this.rows.get(this.currentRow); - if (!row.columnByName.containsKey(columnLabel)) { - return null; // column name not found - } - return row.columnByName.get(columnLabel); - } - - @Override - public int findColumn(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Reader getCharacterStream(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Reader getCharacterStream(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public BigDecimal getBigDecimal(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public BigDecimal getBigDecimal(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isBeforeFirst() throws SQLException { - return this.currentRow < 0; - } - - @Override - public boolean isAfterLast() throws SQLException { - return this.currentRow >= this.rows.size(); - } - - @Override - public boolean isFirst() throws SQLException { - return this.currentRow == 0 && this.rows.size() > 0; - } - - @Override - public boolean isLast() throws SQLException { - return this.currentRow == (this.rows.size() - 1) && this.rows.size() > 0; - } - - @Override - public void beforeFirst() throws SQLException { - this.currentRow = -1; - } - - @Override - public void afterLast() throws SQLException { - this.currentRow = this.rows.size(); - } - - @Override - public boolean first() throws SQLException { - this.currentRow = 0; - return this.currentRow < this.rows.size(); - } - - @Override - public boolean last() throws SQLException { - 
this.currentRow = this.rows.size() - 1; - return this.currentRow >= 0; - } - - @Override - public int getRow() throws SQLException { - return this.currentRow + 1; - } - - @Override - public boolean absolute(final int row) throws SQLException { - if (row > 0) { - this.currentRow = row - 1; - } else { - this.currentRow = this.rows.size() + row; - } - return this.currentRow >= 0 && this.currentRow < this.rows.size(); - } - - @Override - public boolean relative(final int rows) throws SQLException { - this.currentRow += rows; - return this.currentRow >= 0 && this.currentRow < this.rows.size(); - } - - @Override - public boolean previous() throws SQLException { - this.currentRow--; - return this.currentRow >= 0 && this.currentRow < this.rows.size(); - } - - @Override - public void setFetchDirection(final int direction) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getFetchDirection() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void setFetchSize(final int rows) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getFetchSize() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getType() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getConcurrency() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean rowUpdated() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean rowInserted() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean rowDeleted() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNull(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBoolean(final int columnIndex, final boolean x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateByte(final int columnIndex, final byte x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateShort(final int columnIndex, final short x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateInt(final int columnIndex, final int x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateLong(final int columnIndex, final long x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateFloat(final int columnIndex, final float x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateDouble(final int columnIndex, final double x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBigDecimal(final int columnIndex, final BigDecimal x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateString(final int columnIndex, final String x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBytes(final int columnIndex, final byte[] x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateDate(final int columnIndex, final Date x) throws SQLException 
{ - throw new UnsupportedOperationException(); - } - - @Override - public void updateTime(final int columnIndex, final Time x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateTimestamp(final int columnIndex, final Timestamp x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(final int columnIndex, final InputStream x, final int length) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(final int columnIndex, final InputStream x, final int length) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(final int columnIndex, final Reader x, final int length) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateObject(final int columnIndex, final Object x, final int scaleOrLength) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateObject(final int columnIndex, final Object x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNull(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBoolean(final String columnLabel, final boolean x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateByte(final String columnLabel, final byte x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateShort(final String columnLabel, final short x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateInt(final String columnLabel, final int x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateLong(final String columnLabel, final long x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateFloat(final String columnLabel, final float x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateDouble(final String columnLabel, final double x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBigDecimal(final String columnLabel, final BigDecimal x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateString(final String columnLabel, final String x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBytes(final String columnLabel, final byte[] x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateDate(final String columnLabel, final Date x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateTime(final String columnLabel, final Time x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateTimestamp(final String columnLabel, final Timestamp x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(final String columnLabel, final InputStream x, final int length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public 
void updateBinaryStream(final String columnLabel, final InputStream x, final int length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(final String columnLabel, final Reader reader, final int length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateObject(final String columnLabel, final Object x, final int scaleOrLength) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateObject(final String columnLabel, final Object x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void insertRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void deleteRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void refreshRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void cancelRowUpdates() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void moveToInsertRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void moveToCurrentRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Statement getStatement() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Object getObject(final int columnIndex, final Map> map) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Ref getRef(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Blob getBlob(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Clob getClob(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Array getArray(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Object getObject(final String columnLabel, final Map> map) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Ref getRef(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Blob getBlob(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Clob getClob(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Array getArray(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(final int columnIndex, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(final String columnLabel, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Time getTime(final int columnIndex, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Time getTime(final String columnLabel, final Calendar cal) throws SQLException { - throw new 
UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(final int columnIndex, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(final String columnLabel, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public URL getURL(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public URL getURL(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRef(final int columnIndex, final Ref x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRef(final String columnLabel, final Ref x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(final int columnIndex, final Blob x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(final String columnLabel, final Blob x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(final int columnIndex, final Clob x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(final String columnLabel, final Clob x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateArray(final int columnIndex, final Array x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateArray(final String columnLabel, final Array x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public RowId getRowId(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public RowId getRowId(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRowId(final int columnIndex, final RowId x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRowId(final String columnLabel, final RowId x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getHoldability() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isClosed() throws SQLException { - return false; - } - - @Override - @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) - public void updateNString(final int columnIndex, final String nString) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) - public void updateNString(final String columnLabel, final String nString) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) - public void updateNClob(final int columnIndex, final NClob nClob) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) - public void updateNClob(final String columnLabel, final NClob nClob) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - 
@SuppressWarnings("checkstyle:MethodName") - public NClob getNClob(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public NClob getNClob(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public SQLXML getSQLXML(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public SQLXML getSQLXML(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateSQLXML(final int columnIndex, final SQLXML xmlObject) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateSQLXML(final String columnLabel, final SQLXML xmlObject) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getNString(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getNString(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Reader getNCharacterStream(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Reader getNCharacterStream(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNCharacterStream(final int columnIndex, final Reader x, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNCharacterStream(final String columnLabel, final Reader reader, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(final int columnIndex, final InputStream x, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(final int columnIndex, final InputStream x, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(final int columnIndex, final Reader x, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(final String columnLabel, final InputStream x, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(final String columnLabel, final InputStream x, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(final String columnLabel, final Reader reader, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(final int columnIndex, final InputStream inputStream, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(final String columnLabel, final InputStream inputStream, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(final int columnIndex, final Reader reader, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(final String columnLabel, final 
Reader reader, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNClob(final int columnIndex, final Reader reader, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNClob(final String columnLabel, final Reader reader, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNCharacterStream(final int columnIndex, final Reader x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNCharacterStream(final String columnLabel, final Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(final int columnIndex, final InputStream x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(final int columnIndex, final InputStream x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(final int columnIndex, final Reader x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(final String columnLabel, final InputStream x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(final String columnLabel, final InputStream x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(final String columnLabel, final Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(final int columnIndex, final InputStream inputStream) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(final String columnLabel, final InputStream inputStream) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(final int columnIndex, final Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(final String columnLabel, final Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNClob(final int columnIndex, final Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNClob(final String columnLabel, final Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public T getObject(final int columnIndex, final Class type) throws SQLException { - return type.cast(getObject(columnIndex)); - } - - @Override - public T getObject(final String columnLabel, final Class type) throws SQLException { - return type.cast(getObject(columnLabel)); - } - - @Override - public T unwrap(final Class iface) throws SQLException { - return iface == ResultSet.class ? 
iface.cast(this) : null; - } - - @Override - public boolean isWrapperFor(final Class iface) throws SQLException { - return iface != null && iface.isAssignableFrom(this.getClass()); - } - } -} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java new file mode 100644 index 000000000..d162b6229 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java @@ -0,0 +1,741 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.cache; + +import io.lettuce.core.RedisClient; +import io.lettuce.core.RedisCredentials; +import io.lettuce.core.RedisCredentialsProvider; +import io.lettuce.core.RedisURI; +import io.lettuce.core.RedisCommandExecutionException; +import io.lettuce.core.SetArgs; +import io.lettuce.core.api.StatefulConnection; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.api.async.RedisAsyncCommands; +import io.lettuce.core.codec.ByteArrayCodec; +import io.lettuce.core.resource.ClientResources; +import io.lettuce.core.cluster.RedisClusterClient; +import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; +import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; +import reactor.core.publisher.Mono; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import java.nio.charset.StandardCharsets; +import java.sql.SQLException; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.time.Duration; +import java.util.Properties; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Supplier; +import java.util.logging.Logger; +import org.apache.commons.pool2.BasePooledObjectFactory; +import org.apache.commons.pool2.impl.GenericObjectPool; +import org.apache.commons.pool2.impl.GenericObjectPoolConfig; +import org.apache.commons.pool2.impl.DefaultPooledObject; +import org.apache.commons.pool2.PooledObject; +import software.amazon.jdbc.AwsWrapperProperty; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.authentication.AwsCredentialsManager; +import software.amazon.jdbc.plugin.iam.ElastiCacheIamTokenUtility; +import software.amazon.jdbc.util.StringUtils; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; + +/** + * Abstraction for a cache connection that can be pinged. + * Hides cache-client implementation details (Lettuce/Glide) from CacheMonitor. + */ +interface CachePingConnection { + /** + * Pings the cache server to check health. + * @return true if ping successful (PONG received), false otherwise + */ + boolean ping(); + + /** + * Checks if the connection is open. 
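Because CachePingConnection reduces the cache client to three operations, monitor logic can be exercised without Lettuce or a live cache server at all. A minimal scripted stub — hypothetical, not part of this change — might look like:

```java
// Hypothetical test double for CachePingConnection: replays a fixed
// sequence of health-check outcomes, then reports failure.
final class ScriptedPingConnection implements CachePingConnection {
  private final java.util.Iterator<Boolean> outcomes;
  private volatile boolean open = true;

  ScriptedPingConnection(Boolean... scriptedOutcomes) {
    this.outcomes = java.util.Arrays.asList(scriptedOutcomes).iterator();
  }

  @Override
  public boolean ping() {
    // Once the script is exhausted, report failure rather than blocking.
    return open && outcomes.hasNext() && outcomes.next();
  }

  @Override
  public boolean isOpen() {
    return open;
  }

  @Override
  public void close() {
    open = false;
  }
}
```

Paired with `CacheMonitor.setPingConnections(...)` further down in this change, such a stub lets a unit test drive the HEALTHY/SUSPECT/DEGRADED transitions deterministically.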
+   * @return true if connection is open, false otherwise
+   */
+  boolean isOpen();
+
+  /**
+   * Closes the connection.
+   */
+  void close();
+}
+
+// Abstraction layer on top of a connection to a remote cache server
+public class CacheConnection {
+  private static final Logger LOGGER = Logger.getLogger(CacheConnection.class.getName());
+
+  private static final int DEFAULT_POOL_MIN_IDLE = 0;
+  private static final int DEFAULT_MAX_POOL_SIZE = 200;
+  private static final long DEFAULT_MAX_BORROW_WAIT_MS = 100;
+  // Seconds: refresh IAM tokens 30 seconds before their 15-minute validity window ends.
+  private static final long TOKEN_CACHE_DURATION = 15 * 60 - 30;
+
+  private static final ReentrantLock connectionInitializationLock = new ReentrantLock();
+
+  private final String cacheRwServerAddr; // read-write cache server
+  private final String cacheRoServerAddr; // read-only cache server
+  private final String[] defaultCacheServerHostAndPort;
+  private MessageDigest msgHashDigest = null;
+
+  protected static final AwsWrapperProperty CACHE_RW_ENDPOINT_ADDR =
+      new AwsWrapperProperty(
+          "cacheEndpointAddrRw",
+          null,
+          "The cache read-write server endpoint address.");
+
+  protected static final AwsWrapperProperty CACHE_RO_ENDPOINT_ADDR =
+      new AwsWrapperProperty(
+          "cacheEndpointAddrRo",
+          null,
+          "The cache read-only server endpoint address.");
+
+  protected static final AwsWrapperProperty CACHE_USE_SSL =
+      new AwsWrapperProperty(
+          "cacheUseSSL",
+          "true",
+          "Whether to use SSL for cache connections.");
+
+  protected static final AwsWrapperProperty CACHE_IAM_REGION =
+      new AwsWrapperProperty(
+          "cacheIamRegion",
+          null,
+          "AWS region for ElastiCache IAM authentication.");
+
+  protected static final AwsWrapperProperty CACHE_USERNAME =
+      new AwsWrapperProperty(
+          "cacheUsername",
+          null,
+          "Username for ElastiCache regular authentication.");
+
+  protected static final AwsWrapperProperty CACHE_PASSWORD =
+      new AwsWrapperProperty(
+          "cachePassword",
+          null,
+          "Password for ElastiCache regular authentication.");
+
+  protected static final AwsWrapperProperty CACHE_NAME =
+      new AwsWrapperProperty(
+          "cacheName",
+          null,
+          "Explicit cache name for ElastiCache IAM authentication.");
+
+  protected static final AwsWrapperProperty CACHE_CONNECTION_TIMEOUT =
+      new AwsWrapperProperty(
+          "cacheConnectionTimeout",
+          "2000",
+          "Cache connection request timeout duration in milliseconds.");
+
+  protected static final AwsWrapperProperty CACHE_CONNECTION_POOL_SIZE =
+      new AwsWrapperProperty(
+          "cacheConnectionPoolSize",
+          "20",
+          "Cache connection pool size.");
+
+  protected static final AwsWrapperProperty FAIL_WHEN_CACHE_DOWN =
+      new AwsWrapperProperty(
+          "failWhenCacheDown",
+          "false",
+          "Whether to throw SQLException on cache failures under Degraded mode.");
+
+  // Read and write connection pools to the remote cache server
+  private volatile GenericObjectPool<StatefulConnection<byte[], byte[]>> readConnectionPool;
+  private volatile GenericObjectPool<StatefulConnection<byte[], byte[]>> writeConnectionPool;
+  // Cache endpoint registry holding connection pools for multiple endpoints
+  private static final ConcurrentHashMap<String, GenericObjectPool<StatefulConnection<byte[], byte[]>>> endpointToPoolRegistry = new ConcurrentHashMap<>();
+
+  private final boolean useSSL;
+  private final boolean iamAuthEnabled;
+  private final String cacheIamRegion;
+  private final String cacheUsername;
+  private final String cacheName;
+  private final String cachePassword;
+  private final Duration cacheConnectionTimeout;
+  private final int cacheConnectionPoolSize;
+  private final Properties awsProfileProperties;
+  private final AwsCredentialsProvider credentialsProvider;
+  private final boolean failWhenCacheDown;
+  private final TelemetryFactory telemetryFactory;
+  private final long inFlightWriteSizeLimitBytes;
+  private final boolean healthCheckInHealthyState;
+  private volatile boolean cacheMonitorRegistered = false;
+  private volatile Boolean isClusterMode = null; // null = not yet detected, true = cluster mode enabled, false = disabled
+
+  static {
+    PropertyDefinition.registerPluginProperties(CacheConnection.class);
+  }
+
+  /**
+   * Wraps a StatefulConnection (either StatefulRedisConnection or StatefulRedisClusterConnection)
+   * and exposes only ping functionality.
+   */
+  private static class PingConnection implements CachePingConnection {
+    private final StatefulConnection<byte[], byte[]> connection;
+
+    PingConnection(StatefulConnection<byte[], byte[]> connection) {
+      this.connection = connection;
+    }
+
+    @Override
+    public boolean ping() {
+      try {
+        if (!connection.isOpen()) {
+          return false;
+        }
+
+        // Cast to the appropriate type to access the sync() method
+        String result;
+        if (connection instanceof StatefulRedisClusterConnection) {
+          result = ((StatefulRedisClusterConnection<byte[], byte[]>) connection).sync().ping();
+        } else {
+          result = ((StatefulRedisConnection<byte[], byte[]>) connection).sync().ping();
+        }
+        return "PONG".equalsIgnoreCase(result);
+      } catch (Exception e) {
+        return false;
+      }
+    }
+
+    @Override
+    public boolean isOpen() {
+      return connection.isOpen();
+    }
+
+    @Override
+    public void close() {
+      try {
+        connection.close();
+      } catch (Exception e) {
+        // Ignore close errors
+      }
+    }
+  }
+
+  public CacheConnection(final Properties properties, TelemetryFactory telemetryFactory) {
+    this.telemetryFactory = telemetryFactory;
+    this.inFlightWriteSizeLimitBytes = CacheMonitor.CACHE_IN_FLIGHT_WRITE_SIZE_LIMIT.getLong(properties);
+    this.healthCheckInHealthyState = CacheMonitor.CACHE_HEALTH_CHECK_IN_HEALTHY_STATE.getBoolean(properties);
+
+    this.cacheRwServerAddr = CACHE_RW_ENDPOINT_ADDR.getString(properties);
+    this.cacheRoServerAddr = CACHE_RO_ENDPOINT_ADDR.getString(properties);
+    this.useSSL = Boolean.parseBoolean(CACHE_USE_SSL.getString(properties));
+    this.cacheName = CACHE_NAME.getString(properties);
+    this.cacheIamRegion = CACHE_IAM_REGION.getString(properties);
+    this.cacheUsername = CACHE_USERNAME.getString(properties);
+    this.cachePassword = CACHE_PASSWORD.getString(properties);
+    this.cacheConnectionTimeout = Duration.ofMillis(CACHE_CONNECTION_TIMEOUT.getInteger(properties));
+    this.cacheConnectionPoolSize = CACHE_CONNECTION_POOL_SIZE.getInteger(properties);
+    if (this.cacheConnectionPoolSize <= 0 || this.cacheConnectionPoolSize > DEFAULT_MAX_POOL_SIZE) {
+      throw new IllegalArgumentException(
+          "Cache connection pool size must be within valid range: 1-" + DEFAULT_MAX_POOL_SIZE + ", but was: " + this.cacheConnectionPoolSize);
+    }
+    this.failWhenCacheDown = FAIL_WHEN_CACHE_DOWN.getBoolean(properties);
+    this.iamAuthEnabled = !StringUtils.isNullOrEmpty(this.cacheIamRegion);
+    boolean hasTraditionalAuth = !StringUtils.isNullOrEmpty(this.cachePassword);
+    // Validate authentication configuration
+    if (this.iamAuthEnabled && hasTraditionalAuth) {
+      throw new IllegalArgumentException(
+          "Cannot specify both IAM authentication (cacheIamRegion) and traditional authentication (cachePassword). Choose one authentication method.");
+    }
+    if (this.cacheRwServerAddr == null) {
+      throw new IllegalArgumentException("Cache endpoint address is required");
+    }
+    this.defaultCacheServerHostAndPort = getHostnameAndPort(this.cacheRwServerAddr);
+    if (this.iamAuthEnabled) {
+      if (this.cacheUsername == null || this.defaultCacheServerHostAndPort[0] == null || this.cacheName == null) {
+        throw new IllegalArgumentException("IAM authentication requires cache name, username, region, and hostname");
+      }
+    }
+    if (PropertyDefinition.AWS_PROFILE.getString(properties) != null) {
+      this.awsProfileProperties = new Properties();
+      this.awsProfileProperties.setProperty(
+          PropertyDefinition.AWS_PROFILE.name,
+          PropertyDefinition.AWS_PROFILE.getString(properties)
+      );
+    } else {
+      this.awsProfileProperties = null;
+    }
+    if (this.iamAuthEnabled) {
+      // AwsCredentialsManager expects a non-null Properties instance
+      Properties propsToPass = (this.awsProfileProperties != null)
+          ? this.awsProfileProperties
+          : new Properties();
+      this.credentialsProvider = AwsCredentialsManager.getProvider(null, propsToPass);
+    } else {
+      this.credentialsProvider = null;
+    }
+  }
+
+  // for unit testing only
+  public CacheConnection(final Properties properties) {
+    this(properties, null);
+  }
+
+  /**
+   * Detects whether the Redis endpoint is running in cluster mode by executing the INFO command.
+   * Caches the result to avoid repeated detection. The caller must hold a lock for thread safety.
+   */
+  private void detectClusterMode() {
+    if (this.isClusterMode != null) {
+      return;
+    }
+
+    String[] hostnameAndPort = getHostnameAndPort(this.cacheRwServerAddr);
+    RedisURI redisUri = buildRedisURI(hostnameAndPort[0], Integer.parseInt(hostnameAndPort[1]));
+
+    ClientResources resources = ClientResources.builder().build();
+
+    try (RedisClient client = RedisClient.create(resources, redisUri);
+         StatefulRedisConnection<byte[], byte[]> conn = client.connect(new ByteArrayCodec())) {
+
+      String infoOutput = conn.sync().info("cluster");
+      boolean clusterEnabled = false;
+      if (infoOutput != null) {
+        // Parse the INFO output line by line
+        for (String line : infoOutput.split("\r?\n")) {
+          if (line.startsWith("cluster_enabled:")) {
+            String value = line.substring("cluster_enabled:".length()).trim();
+            clusterEnabled = "1".equals(value);
+            break;
+          }
+        }
+      }
+      this.isClusterMode = clusterEnabled;
+      // TODO: remove this log in final version
+      LOGGER.info("Detected cache mode: " + (this.isClusterMode ? "CLUSTER" : "STANDALONE")
+          + " for endpoint " + hostnameAndPort[0] + ":" + hostnameAndPort[1]);
+
+    } catch (Exception e) {
+      LOGGER.warning("Failed to detect cluster mode, defaulting to single-shard: " + e.getMessage());
+      this.isClusterMode = false;
+    } finally {
+      try {
+        resources.shutdown().get(5, TimeUnit.SECONDS);
+      } catch (Exception ignored) {}
+    }
+  }
+
+  /* Lazily initializes the cache connection pool for the requested direction:
+     the read pool when isRead is true, the write pool when isRead is false.
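The pools created here follow the standard commons-pool2 borrow/return discipline, with broken objects invalidated rather than returned (see `returnConnectionBackToPool` further down). A self-contained illustration with a toy pooled type — nothing in it is part of this PR:

```java
import org.apache.commons.pool2.BasePooledObjectFactory;
import org.apache.commons.pool2.PooledObject;
import org.apache.commons.pool2.impl.DefaultPooledObject;
import org.apache.commons.pool2.impl.GenericObjectPool;

public class PoolRoundTrip {
  public static void main(String[] args) throws Exception {
    GenericObjectPool<StringBuilder> pool = new GenericObjectPool<>(
        new BasePooledObjectFactory<StringBuilder>() {
          @Override public StringBuilder create() { return new StringBuilder(); }
          @Override public PooledObject<StringBuilder> wrap(StringBuilder obj) {
            return new DefaultPooledObject<>(obj);
          }
        });
    StringBuilder buf = pool.borrowObject();
    boolean broken = false;
    try {
      buf.append("work");
    } catch (RuntimeException e) {
      broken = true;
    } finally {
      if (broken) pool.invalidateObject(buf);  // destroyed, never reused
      else pool.returnObject(buf);             // back to the idle set
    }
    pool.close();
  }
}
```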
+ */ + private void initializeCacheConnectionIfNeeded(boolean isRead) { + if (StringUtils.isNullOrEmpty(this.cacheRwServerAddr)) { + return; + } + + // Initialize the message digest + if (this.msgHashDigest == null) { + try { + this.msgHashDigest = MessageDigest.getInstance("SHA-384"); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException("SHA-384 not supported", e); + } + } + + // return early if connection pool is already initialized before acquiring the lock + if ((isRead && this.readConnectionPool != null) || (!isRead && this.writeConnectionPool != null)) { + return; + } + + connectionInitializationLock.lock(); + try { + // Double check after lock is acquired + if ((isRead && this.readConnectionPool != null) || (!isRead && this.writeConnectionPool != null)) { + return; + } + // Detect cluster mode first (cached for reuse) + detectClusterMode(); + + // Register cluster with CacheMonitor on first cache operation (skip if telemetryFactory is null = test mode) + if (telemetryFactory != null && !this.cacheMonitorRegistered) { + CacheMonitor.registerCluster( + inFlightWriteSizeLimitBytes, healthCheckInHealthyState, telemetryFactory, + this.cacheRwServerAddr, this.cacheRoServerAddr, + this.useSSL, this.cacheConnectionTimeout, this.iamAuthEnabled, this.credentialsProvider, + this.cacheIamRegion, this.cacheName, this.cacheUsername, this.cachePassword + ); + this.cacheMonitorRegistered = true; + } + + if ((isRead && this.readConnectionPool == null) || (!isRead && this.writeConnectionPool == null)) { + createConnectionPool(isRead); + } + } finally { + connectionInitializationLock.unlock(); + } + } + + private void createConnectionPool(boolean isRead) { + try { + // cache server addr string is in the format ":" + String serverAddr = this.cacheRwServerAddr; + // If read-only server is specified, use it for the read-only connections + if (isRead && !StringUtils.isNullOrEmpty(this.cacheRoServerAddr)) { + serverAddr = this.cacheRoServerAddr; + } + String[] hostnameAndPort = getHostnameAndPort(serverAddr); + RedisURI redisUri = buildRedisURI(hostnameAndPort[0], Integer.parseInt(hostnameAndPort[1])); + + // Appending RW and RO tag to the server address to make it unique in case RO and RW has same endpoint + String poolKey = (isRead ? "RO:" : "RW:") + serverAddr; + GenericObjectPool> pool = endpointToPoolRegistry.get(poolKey); + + if (pool == null) { + GenericObjectPoolConfig> poolConfig = createPoolConfig(); + poolConfig.setMaxTotal(this.cacheConnectionPoolSize); + poolConfig.setMaxIdle(this.cacheConnectionPoolSize); + + pool = endpointToPoolRegistry.computeIfAbsent(poolKey, k -> + new GenericObjectPool<>( + new BasePooledObjectFactory>() { + public StatefulConnection create() { + return createRedisConnection(isRead, redisUri, isClusterMode); + } + public PooledObject> wrap(StatefulConnection connection) { + return new DefaultPooledObject<>(connection); + } + }, poolConfig) + ); + } + + if (isRead) { + this.readConnectionPool = pool; + } else { + this.writeConnectionPool = pool; + } + } catch (Exception e) { + String poolType = isRead ? 
"read" : "write"; + String errorMsg = String.format("Failed to create Cache %s connection pool", poolType); + LOGGER.warning(errorMsg + ": " + e.getMessage()); + throw new RuntimeException(errorMsg, e); + } + } + + private static GenericObjectPoolConfig> createPoolConfig() { + GenericObjectPoolConfig> poolConfig = new GenericObjectPoolConfig<>(); + poolConfig.setMinIdle(DEFAULT_POOL_MIN_IDLE); + poolConfig.setMaxWait(Duration.ofMillis(DEFAULT_MAX_BORROW_WAIT_MS)); + return poolConfig; + } + + /** + * Creates a Redis connection for either standalone or cluster mode. + * Returns StatefulConnection which works for both RedisClient and RedisClusterClient. + */ + private static StatefulConnection createRedisConnection( + boolean isReadOnly, RedisURI redisUri, boolean isClusterMode) { + + ClientResources resources = ClientResources.builder().build(); + StatefulConnection conn; + + if (isClusterMode) { + // Multi-shard cluster mode: use RedisClusterClient + RedisClusterClient client = RedisClusterClient.create(resources, redisUri); + conn = client.connect(new ByteArrayCodec()); + } else { + // Single-shard standalone mode: use RedisClient + RedisClient client = RedisClient.create(resources, redisUri); + conn = client.connect(new ByteArrayCodec()); + } + + // Set READONLY mode for RO endpoint + if (isReadOnly) { + try { + if (conn instanceof StatefulRedisClusterConnection) { + ((StatefulRedisClusterConnection) conn).sync().readOnly(); + } else { + ((StatefulRedisConnection) conn).sync().readOnly(); + } + } catch (RedisCommandExecutionException e) { + if (e.getMessage().contains("ERR This instance has cluster support disabled")) { + LOGGER.fine("Note: this cache cluster has cluster support disabled"); + } else { + LOGGER.fine("Note: READONLY command not supported or failed: " + e.getMessage()); + } + } + } + + return conn; + } + + /** + * Static helper to build RedisURI with authentication configuration. + * Used by both createPingConnection (static) and buildRedisURI (instance). + */ + private static RedisURI buildRedisURIStatic(String hostname, int port, boolean useSSL, Duration connectionTimeout, + boolean iamAuthEnabled, AwsCredentialsProvider credentialsProvider, String cacheIamRegion, + String cacheName, String cacheUsername, String cachePassword) { + + RedisURI.Builder uriBuilder = RedisURI.Builder.redis(hostname) + .withPort(port) + .withSsl(useSSL) + .withVerifyPeer(false) + .withLibraryName("aws-sql-jdbc-lettuce") + .withTimeout(connectionTimeout); + + if (iamAuthEnabled) { + // Create a credentials provider that Lettuce will call whenever authentication is needed + RedisCredentialsProvider redisCredentialsProvider = () -> { + // Create a cached token supplier that automatically refreshes tokens every 14.5 minutes + Supplier tokenSupplier = CachedSupplier.memoizeWithExpiration( + () -> { + ElastiCacheIamTokenUtility tokenUtility = new ElastiCacheIamTokenUtility(cacheName); + return tokenUtility.generateAuthenticationToken( + credentialsProvider, + Region.of(cacheIamRegion), + hostname, + port, + cacheUsername + ); + }, + TOKEN_CACHE_DURATION, + TimeUnit.SECONDS + ); + return Mono.just(RedisCredentials.just(cacheUsername, tokenSupplier.get())); + }; + uriBuilder.withAuthentication(redisCredentialsProvider); + } else if (!StringUtils.isNullOrEmpty(cachePassword)) { + uriBuilder.withAuthentication(cacheUsername, cachePassword); + } + + return uriBuilder.build(); + } + + /** + * Creates a cache ping connection with the specified configuration. 
+   * This is a static helper that abstracts Lettuce-specific logic for CacheMonitor.
+   * Returns an interface to hide implementation details.
+   */
+  static CachePingConnection createPingConnection(String hostname, int port, boolean isReadOnly, boolean useSSL,
+      Duration connectionTimeout, boolean iamAuthEnabled, AwsCredentialsProvider credentialsProvider, String cacheIamRegion,
+      String cacheName, String cacheUsername, String cachePassword) {
+
+    try {
+      // Use the static helper to build RedisURI
+      RedisURI redisUri = buildRedisURIStatic(hostname, port, useSSL, connectionTimeout, iamAuthEnabled,
+          credentialsProvider, cacheIamRegion, cacheName, cacheUsername, cachePassword
+      );
+      // Create Lettuce connection (use standalone mode for ping - works for both)
+      StatefulConnection<byte[], byte[]> conn = createRedisConnection(isReadOnly, redisUri, false);
+
+      // Wrap in abstraction interface
+      return new PingConnection(conn);
+
+    } catch (Exception e) {
+      LOGGER.fine(String.format("Failed to create ping connection for %s:%d: %s",
+          hostname, port, e.getMessage()));
+      return null;
+    }
+  }
+
+  // Get the hash digest of the given key.
+  private byte[] computeHashDigest(byte[] key) {
+    this.msgHashDigest.update(key);
+    return this.msgHashDigest.digest();
+  }
+
+  public byte[] readFromCache(String key) throws SQLException {
+    // Check cluster state before attempting read
+    CacheMonitor.HealthState state = getClusterHealthStateFromCacheMonitor();
+    if (!shouldProceedWithOperation(state)) {
+      if (failWhenCacheDown) {
+        throw new SQLException("Cache cluster is in DEGRADED state and failWhenCacheDown is enabled");
+      }
+      return null; // Treat as cache miss
+    }
+
+    boolean isBroken = false;
+    StatefulConnection<byte[], byte[]> conn = null;
+    // Get a connection from the read connection pool
+    try {
+      initializeCacheConnectionIfNeeded(true);
+      conn = this.readConnectionPool.borrowObject();
+
+      // Cast to the appropriate type and execute the GET command
+      byte[] keyHash = computeHashDigest(key.getBytes(StandardCharsets.UTF_8));
+      if (conn instanceof StatefulRedisClusterConnection) {
+        return ((StatefulRedisClusterConnection<byte[], byte[]>) conn).sync().get(keyHash);
+      } else {
+        return ((StatefulRedisConnection<byte[], byte[]>) conn).sync().get(keyHash);
+      }
+    } catch (Exception e) {
+      if (conn != null) {
+        isBroken = true;
+      }
+      // Report error to CacheMonitor for the read endpoint
+      reportErrorToCacheMonitor(false, e, "READ");
+      LOGGER.warning("Failed to read result from cache. Treating it as a cache miss: " + e.getMessage());
+      return null;
+    } finally {
+      if (conn != null && this.readConnectionPool != null) {
+        try {
+          this.returnConnectionBackToPool(conn, isBroken, true);
+        } catch (Exception ex) {
+          LOGGER.warning("Error closing read connection: " + ex.getMessage());
+        }
+      }
+    }
+  }
+
+  protected void handleCompletedCacheWrite(StatefulConnection<byte[], byte[]> conn, long writeSize, Throwable ex) {
+    // Note: this completion callback runs on a different thread.
+    // Always decrement the in-flight size: the write completed, whether it succeeded or failed.
+    decrementInFlightSize(writeSize);
+
+    if (ex != null) {
+      // Report error to CacheMonitor for the RW endpoint
+      reportErrorToCacheMonitor(true, ex, "WRITE");
+      if (this.writeConnectionPool != null) {
+        try {
+          returnConnectionBackToPool(conn, true, false);
+        } catch (Exception e) {
+          LOGGER.warning("Error returning broken write connection back to pool: " + e.getMessage());
+        }
+      }
+    } else {
+      if (this.writeConnectionPool != null) {
+        try {
+          returnConnectionBackToPool(conn, false, false);
+        } catch (Exception e) {
+          LOGGER.warning("Error returning write connection back to pool: " + e.getMessage());
+        }
+      }
+    }
+  }
+
+  public void writeToCache(String key, byte[] value, int expiry) {
+    // Check cluster state before attempting write
+    CacheMonitor.HealthState state = getClusterHealthStateFromCacheMonitor();
+    if (!shouldProceedWithOperation(state)) {
+      LOGGER.finest("Skipping cache write - cluster is DEGRADED");
+      return; // Exit without writing
+    }
+
+    StatefulConnection<byte[], byte[]> conn = null;
+    try {
+      initializeCacheConnectionIfNeeded(false);
+
+      // Calculate write size and increment before borrowing a connection
+      byte[] keyHash = computeHashDigest(key.getBytes(StandardCharsets.UTF_8));
+      long writeSize = keyHash.length + value.length;
+      incrementInFlightSize(writeSize);
+
+      try {
+        conn = this.writeConnectionPool.borrowObject();
+      } catch (Exception borrowException) {
+        // Connection borrow failed (timeout/pool exhaustion) - decrement immediately
+        decrementInFlightSize(writeSize);
+        reportErrorToCacheMonitor(true, borrowException, "WRITE");
+        return;
+      }
+
+      // Get async commands and execute the SET operation based on connection type
+      StatefulConnection<byte[], byte[]> finalConn = conn;
+
+      if (conn instanceof StatefulRedisClusterConnection) {
+        RedisAdvancedClusterAsyncCommands<byte[], byte[]> clusterAsyncCommands =
+            ((StatefulRedisClusterConnection<byte[], byte[]>) conn).async();
+        clusterAsyncCommands.set(keyHash, value, SetArgs.Builder.ex(expiry))
+            .whenComplete((result, exception) -> handleCompletedCacheWrite(finalConn, writeSize, exception));
+      } else {
+        RedisAsyncCommands<byte[], byte[]> asyncCommands =
+            ((StatefulRedisConnection<byte[], byte[]>) conn).async();
+        asyncCommands.set(keyHash, value, SetArgs.Builder.ex(expiry))
+            .whenComplete((result, exception) -> handleCompletedCacheWrite(finalConn, writeSize, exception));
+      }
+
+    } catch (Exception e) {
+      // Setup for the cache write failed; report the error so shard-level failures can be detected.
+      reportErrorToCacheMonitor(true, e, "WRITE");
+
+      if (conn != null && this.writeConnectionPool != null) {
+        try {
+          returnConnectionBackToPool(conn, true, false);
+        } catch (Exception ex) {
+          LOGGER.warning("Error closing write connection: " + ex.getMessage());
+        }
+      }
+    }
+  }
+
+  private void returnConnectionBackToPool(StatefulConnection<byte[], byte[]> connection, boolean isConnectionBroken, boolean isRead) {
+    GenericObjectPool<StatefulConnection<byte[], byte[]>> pool = isRead ?
this.readConnectionPool : this.writeConnectionPool; + if (isConnectionBroken) { + try { + pool.invalidateObject(connection); + } catch (Exception e) { + throw new RuntimeException("Could not invalidate connection for the pool", e); + } + } else { + pool.returnObject(connection); + } + } + + + protected RedisURI buildRedisURI(String hostname, int port) { + // Delegate to the static helper + return buildRedisURIStatic(hostname, port, this.useSSL, this.cacheConnectionTimeout, this.iamAuthEnabled, + this.credentialsProvider, this.cacheIamRegion, this.cacheName, this.cacheUsername, this.cachePassword + ); + } + + private String[] getHostnameAndPort(String serverAddr) { + return serverAddr.split(":"); + } + + protected CacheMonitor.HealthState getClusterHealthStateFromCacheMonitor() { + return CacheMonitor.getClusterState(this.cacheRwServerAddr, this.cacheRoServerAddr); + } + + protected void reportErrorToCacheMonitor(boolean isWrite, Throwable error, String operation) { + CacheMonitor.reportError(this.cacheRwServerAddr, this.cacheRoServerAddr, isWrite, error, operation); + } + + protected void incrementInFlightSize(long writeSize) { + CacheMonitor.incrementInFlightSizeStatic(this.cacheRwServerAddr, this.cacheRoServerAddr, writeSize); + } + + protected void decrementInFlightSize(long writeSize) { + CacheMonitor.decrementInFlightSizeStatic(this.cacheRwServerAddr, this.cacheRoServerAddr, writeSize); + } + + protected boolean shouldProceedWithOperation(CacheMonitor.HealthState state) { + return state != CacheMonitor.HealthState.DEGRADED; + } + + // Used for unit testing only + protected void setConnectionPools(GenericObjectPool> readPool, + GenericObjectPool> writePool) { + this.readConnectionPool = readPool; + this.writeConnectionPool = writePool; + } + + // Used for unit testing only + protected void triggerPoolInit(boolean isRead) { + initializeCacheConnectionIfNeeded(isRead); + } + + // Used for unit testing only - allows tests to bypass cluster detection + protected void setClusterMode(boolean clusterMode) { + this.isClusterMode = clusterMode; + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheMonitor.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheMonitor.java new file mode 100644 index 000000000..87e56e697 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheMonitor.java @@ -0,0 +1,527 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package software.amazon.jdbc.plugin.cache; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Logger; +import io.lettuce.core.RedisCommandExecutionException; +import io.lettuce.core.RedisConnectionException; +import io.lettuce.core.RedisException; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import java.time.Duration; +import software.amazon.jdbc.AwsWrapperProperty; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.util.telemetry.TelemetryGauge; + +/** + * This class uses a background thread to monitor cache cluster health and manage state transitions. + * + * Implements a three-state machine (HEALTHY → SUSPECT → DEGRADED) with proactive health checks + * that only run when clusters are in SUSPECT or DEGRADED states. + */ +public class CacheMonitor implements Runnable { + + private static final Logger LOGGER = Logger.getLogger(CacheMonitor.class.getName()); + + // Singleton instance + private static volatile CacheMonitor instance; + private static final Object INSTANCE_LOCK = new Object(); + + private static final long THREAD_SLEEP_WHEN_INACTIVE_MILLIS = 100; + private static final int CACHE_HEALTH_CHECK_INTERVAL = 5; + private static final int CACHE_CONSECUTIVE_SUCCESS_THRESHOLD = 3; + private static final int CACHE_CONSECUTIVE_FAILURE_THRESHOLD = 3; + + // Track if monitor thread has been started + private static volatile boolean monitorThreadStarted = false; + private static final Object THREAD_START_LOCK = new Object(); + + // Configuration properties + public static final AwsWrapperProperty CACHE_IN_FLIGHT_WRITE_SIZE_LIMIT = + new AwsWrapperProperty( + "cacheInFlightWriteSizeLimitBytes", + String.valueOf(50 * 1024 * 1024), // 50 MB + "Maximum in-flight write size in bytes before triggering degraded mode."); + + public static final AwsWrapperProperty CACHE_HEALTH_CHECK_IN_HEALTHY_STATE = + new AwsWrapperProperty( + "cacheHealthCheckInHealthyState", + "false", + "Whether to run health checks (pings) in healthy state."); + + private static final Map clusterStates = new ConcurrentHashMap<>(); + private final long inFlightWriteSizeLimitBytes; + private final boolean healthCheckInHealthyState; + private volatile boolean stopped = false; + + // Telemetry + private final TelemetryFactory telemetryFactory; + private static TelemetryCounter stateTransitionCounter; + private static TelemetryCounter healthCheckSuccessCounter; + private static TelemetryCounter healthCheckFailureCounter; + private static TelemetryCounter errorCounter; + private static TelemetryGauge consecutiveSuccessGauge; + private static TelemetryGauge consecutiveFailureGauge; + + /** + * Enum representing the health state of a cache cluster or endpoint. + */ + protected enum HealthState { + HEALTHY, + SUSPECT, + DEGRADED + } + + /* Categories of cache errors for classification and handling */ + protected enum ErrorCategory { + CONNECTION, // Network issues, timeouts, connection refused + COMMAND, // Invalid syntax, wrong arguments + DATA, // Serialization failures, data corruption + RESOURCE // Memory limits, cluster issues + } + + /** + * Represents a cache cluster with RW and optional RO endpoints. + * This is the single source of truth for cluster health state. 
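Before the per-cluster bookkeeping below, it may help to see the HEALTHY → SUSPECT → DEGRADED rules from the class comment in miniature. This self-contained replay uses the thresholds defined above (CACHE_CONSECUTIVE_SUCCESS_THRESHOLD and CACHE_CONSECUTIVE_FAILURE_THRESHOLD are both 3) and is a sketch only; the real monitor also requires in-flight memory to recover before leaving DEGRADED, and tracks RW and RO endpoints separately:

```java
public class HealthStateReplay {
  enum State { HEALTHY, SUSPECT, DEGRADED }

  // Applies the documented rules to a sequence of ping outcomes. Counters reset
  // on every state transition, mirroring transitionToState() below.
  static State replay(State s, boolean... pings) {
    int ok = 0, bad = 0;
    for (boolean ping : pings) {
      if (ping) {
        ok++; bad = 0;
        if (ok >= 3 && s != State.HEALTHY) { s = State.HEALTHY; ok = 0; }
      } else {
        bad++; ok = 0;
        if (s == State.HEALTHY) { s = State.SUSPECT; bad = 0; }                // first failure
        else if (s == State.SUSPECT && bad >= 3) { s = State.DEGRADED; bad = 0; }
      }
    }
    return s;
  }

  public static void main(String[] args) {
    System.out.println(replay(State.HEALTHY, false, false, false, false)); // DEGRADED
    System.out.println(replay(State.DEGRADED, true, true, true));          // HEALTHY
  }
}
```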
+ */ + static class ClusterHealthState { + final String rwEndpoint; + final String roEndpoint; // Null if no separate RO endpoint + + // Separate health states for RW and RO endpoints + volatile HealthState rwHealthState; + volatile HealthState roHealthState; + + // Consecutive result tracking for RW endpoint (only accessed by CacheMonitor thread) + int consecutiveRwSuccesses; + int consecutiveRwFailures; + + // Consecutive result tracking for RO endpoint (only accessed by CacheMonitor thread) + int consecutiveRoSuccesses; + int consecutiveRoFailures; + + // Memory pressure tracking (only for RW endpoint, accessed by multiple threads) + final AtomicLong inFlightWriteSizeBytes; + + // Ping connections - one for RW, one for RO if different + volatile CachePingConnection rwPingConnection; + volatile CachePingConnection roPingConnection; // Null if no separate RO endpoint + + // Configuration for creating ping connections + final boolean useSSL; + final Duration cacheConnectionTimeout; + final boolean iamAuthEnabled; + final AwsCredentialsProvider credentialsProvider; + final String cacheIamRegion; + final String cacheName; + final String cacheUsername; + final String cachePassword; + + ClusterHealthState(String rwEndpoint, String roEndpoint, boolean useSSL, Duration cacheConnectionTimeout, + boolean iamAuthEnabled, AwsCredentialsProvider credentialsProvider, String cacheIamRegion, + String cacheName, String cacheUsername, String cachePassword) { + this.rwEndpoint = rwEndpoint; + // If roEndpoint equals rwEndpoint, treat it as null + this.roEndpoint = (roEndpoint != null && roEndpoint.equals(rwEndpoint)) ? null : roEndpoint; + this.useSSL = useSSL; + this.cacheConnectionTimeout = cacheConnectionTimeout; + this.iamAuthEnabled = iamAuthEnabled; + this.credentialsProvider = credentialsProvider; + this.cacheIamRegion = cacheIamRegion; + this.cacheName = cacheName; + this.cacheUsername = cacheUsername; + this.cachePassword = cachePassword; + + // Initialize health states + this.rwHealthState = HealthState.HEALTHY; + this.roHealthState = HealthState.HEALTHY; + + // Initialize counters + this.consecutiveRwSuccesses = 0; + this.consecutiveRwFailures = 0; + this.consecutiveRoSuccesses = 0; + this.consecutiveRoFailures = 0; + this.inFlightWriteSizeBytes = new AtomicLong(0); + } + + String getClusterKey() { + return generateClusterKey(rwEndpoint, roEndpoint); + } + + static String generateClusterKey(String rwEndpoint, String roEndpoint) { + String normalizedRo = (roEndpoint != null && roEndpoint.equals(rwEndpoint)) ? null : roEndpoint; + return normalizedRo == null ? rwEndpoint + "|" : rwEndpoint + "|" + normalizedRo; + } + + /** + * Transition health state for a specific endpoint. + */ + void transitionToState(HealthState newState, boolean isRw, String triggerReason, TelemetryCounter counter) { + String endpoint = isRw ? rwEndpoint : roEndpoint; + HealthState oldState = isRw ? rwHealthState : roHealthState; + + // Update state + if (isRw) { + rwHealthState = newState; + consecutiveRwSuccesses = 0; + consecutiveRwFailures = 0; + } else { + roHealthState = newState; + consecutiveRoSuccesses = 0; + consecutiveRoFailures = 0; + } + + // Emit telemetry + if (counter != null) { + counter.inc(); + } + + // Log state transition + if (!triggerReason.startsWith("recoverable_error_")) { + String logMessage = String.format( + "[%s to %s] Cache endpoint %s (%s) health state transitioned (trigger: %s)", + oldState, newState, endpoint, isRw ? 
"RW" : "RO", triggerReason); + + LOGGER.fine(logMessage); + } + } + + /** + * Get the overall cluster health state based on both endpoints. + */ + HealthState getClusterHealthState() { + // If no separate RO endpoint, return RW state + if (roEndpoint == null) { + return rwHealthState; + } + + // Compute cluster state based on both endpoints + if (rwHealthState == HealthState.HEALTHY && roHealthState == HealthState.HEALTHY) { + return HealthState.HEALTHY; + } else if (rwHealthState == HealthState.DEGRADED || roHealthState == HealthState.DEGRADED) { + return HealthState.DEGRADED; + } else { + return HealthState.SUSPECT; // Mixed states + } + } + } + + protected static void registerCluster(long inFlightWriteSizeLimitBytes, boolean healthCheckInHealthyState, + TelemetryFactory telemetryFactory, + String rwEndpoint, String roEndpoint, boolean useSSL, Duration cacheConnectionTimeout, + boolean iamAuthEnabled, AwsCredentialsProvider credentialsProvider, String cacheIamRegion, + String cacheName, String cacheUsername, String cachePassword, boolean createPingConnection, + boolean startMonitorThread) { + if (getCluster(rwEndpoint, roEndpoint) != null) { + return; // if cluster has already been registered + } + if (instance == null) { + synchronized (INSTANCE_LOCK) { + if (instance == null) { + instance = new CacheMonitor(inFlightWriteSizeLimitBytes, healthCheckInHealthyState, telemetryFactory); + LOGGER.info("Created CacheMonitor instance " + (telemetryFactory != null ? "with" : "without") + " telemetry"); + } + } + } + ClusterHealthState clusterState = new ClusterHealthState(rwEndpoint, roEndpoint, useSSL, cacheConnectionTimeout, + iamAuthEnabled, credentialsProvider, cacheIamRegion, cacheName, cacheUsername, cachePassword); + ClusterHealthState existingCluster = clusterStates.putIfAbsent(clusterState.getClusterKey(), clusterState); + if (existingCluster == null) { + LOGGER.info(() -> "Registered cluster: " + clusterState.getClusterKey()); + if (createPingConnection) { + instance.createInitialPingConnections(clusterState); + } + } + if (startMonitorThread) { + instance.startMonitoring(); + } + } + + protected static void registerCluster(long inFlightWriteSizeLimitBytes, boolean healthCheckInHealthyState, + TelemetryFactory telemetryFactory, + String rwEndpoint, String roEndpoint, boolean useSSL, + Duration cacheConnectionTimeout, boolean iamAuthEnabled, + AwsCredentialsProvider credentialsProvider, String cacheIamRegion, + String cacheName, String cacheUsername, String cachePassword) { + registerCluster(inFlightWriteSizeLimitBytes, healthCheckInHealthyState, telemetryFactory, + rwEndpoint, roEndpoint, useSSL, + cacheConnectionTimeout, iamAuthEnabled, credentialsProvider, + cacheIamRegion, cacheName, cacheUsername, cachePassword, true, true); + } + + private CacheMonitor(long inFlightWriteSizeLimitBytes, boolean healthCheckInHealthyState, TelemetryFactory telemetryFactory) { + this.telemetryFactory = telemetryFactory; + this.inFlightWriteSizeLimitBytes = inFlightWriteSizeLimitBytes; + this.healthCheckInHealthyState = healthCheckInHealthyState; + + if (telemetryFactory != null && stateTransitionCounter == null) { + stateTransitionCounter = telemetryFactory.createCounter("JdbcCacheStateTransitionCount"); + healthCheckSuccessCounter = telemetryFactory.createCounter("JdbcCacheHealthCheckSuccessCount"); + healthCheckFailureCounter = telemetryFactory.createCounter("JdbcCacheHealthCheckFailureCount"); + errorCounter = telemetryFactory.createCounter("JdbcCacheErrorCount"); + + consecutiveSuccessGauge = 
telemetryFactory.createGauge("JdbcCacheConsecutiveSuccessCount", + () -> clusterStates.values().stream() + .mapToLong(c -> Math.max(c.consecutiveRwSuccesses, c.consecutiveRoSuccesses)) + .max().orElse(0L)); + consecutiveFailureGauge = telemetryFactory.createGauge("JdbcCacheConsecutiveFailureCount", + () -> clusterStates.values().stream() + .mapToLong(c -> Math.max(c.consecutiveRwFailures, c.consecutiveRoFailures)) + .max().orElse(0L)); + } + } + + private void createInitialPingConnections(ClusterHealthState cluster) { + cluster.rwPingConnection = createPingConnection(cluster, false); + if (cluster.roEndpoint != null) { + cluster.roPingConnection = createPingConnection(cluster, true); + } + } + + // Used for unit testing only + protected void setPingConnections(ClusterHealthState cluster, + CachePingConnection rwConnection, + CachePingConnection roConnection) { + cluster.rwPingConnection = rwConnection; + cluster.roPingConnection = roConnection; + } + + private CachePingConnection createPingConnection(ClusterHealthState cluster, boolean isReadOnly) { + String[] hostPort = isReadOnly ? cluster.roEndpoint.split(":") : cluster.rwEndpoint.split(":"); + return CacheConnection.createPingConnection(hostPort[0], Integer.parseInt(hostPort[1]), + isReadOnly, cluster.useSSL, cluster.cacheConnectionTimeout, cluster.iamAuthEnabled, cluster.credentialsProvider, + cluster.cacheIamRegion, cluster.cacheName, cluster.cacheUsername, cluster.cachePassword); + } + + private static ClusterHealthState getCluster(String rwEndpoint, String roEndpoint) { + return clusterStates.get(ClusterHealthState.generateClusterKey(rwEndpoint, roEndpoint)); + } + + protected static HealthState getClusterState(String rwEndpoint, String roEndpoint) { + if (instance == null) { + return HealthState.HEALTHY; + } + ClusterHealthState cluster = getCluster(rwEndpoint, roEndpoint); + return cluster != null ? 
cluster.getClusterHealthState() : HealthState.HEALTHY; + } + + private void startMonitoring() { + if (!monitorThreadStarted) { + synchronized (THREAD_START_LOCK) { + if (!monitorThreadStarted) { + Thread thread = new Thread(this, "CacheMonitorThread"); + thread.setDaemon(true); + thread.start(); + monitorThreadStarted = true; + LOGGER.info("Started CacheMonitor thread"); + } + } + } + } + + private static ErrorCategory classifyError(Throwable error) { + if (error instanceof RedisConnectionException) { + return ErrorCategory.CONNECTION; + } + if (error instanceof RedisCommandExecutionException) { + String msg = error.getMessage(); + if (msg == null) return ErrorCategory.RESOURCE; + if (msg.contains("READONLY") || msg.contains("WRONGTYPE") || msg.contains("MOVED") || msg.contains("ASK")) { + return ErrorCategory.COMMAND; + } + if (msg.contains("OOM") || msg.contains("CLUSTERDOWN") || msg.contains("LOADING") || + msg.contains("NOAUTH") || msg.contains("WRONGPASS")) { + return ErrorCategory.RESOURCE; + } + return ErrorCategory.COMMAND; + } + if (error instanceof RedisException) { + return ErrorCategory.CONNECTION; + } + return ErrorCategory.DATA; + } + + private static boolean isRecoverableError(ErrorCategory category) { + return category != ErrorCategory.DATA; + } + + protected static void reportError(String rwEndpoint, String roEndpoint, boolean isRw, + Throwable error, String operation) { + ClusterHealthState cluster = getCluster(rwEndpoint, roEndpoint); + if (cluster == null) { + LOGGER.warning("Error report for unregistered cluster: RW=" + rwEndpoint + ", RO=" + roEndpoint); + return; + } + ErrorCategory category = classifyError(error); + if (errorCounter != null) { + errorCounter.inc(); + } + if (!isRecoverableError(category)) { + LOGGER.info(() -> "Non-recoverable error (" + category + ") for " + + (isRw ? rwEndpoint : roEndpoint) + ": " + error.getMessage()); + return; + } + synchronized (cluster) { + HealthState currentState = isRw ? cluster.rwHealthState : cluster.roHealthState; + if (currentState == HealthState.HEALTHY) { + LOGGER.warning(String.format("[HEALTHY→SUSPECT] %s %s failed: %s - %s", + isRw ? 
rwEndpoint : roEndpoint, operation, category, error.getMessage()));
+        cluster.transitionToState(HealthState.SUSPECT, isRw,
+            "recoverable_error_" + category, stateTransitionCounter);
+      }
+    }
+  }
+
+  private void incrementInFlightSize(String rwEndpoint, String roEndpoint, long bytes) {
+    ClusterHealthState cluster = getCluster(rwEndpoint, roEndpoint);
+    if (cluster != null) {
+      long newSize = cluster.inFlightWriteSizeBytes.addAndGet(bytes);
+      synchronized (cluster) {
+        if (newSize > inFlightWriteSizeLimitBytes && cluster.rwHealthState != HealthState.DEGRADED) {
+          LOGGER.warning("In-flight write size limit exceeded: " + newSize);
+          cluster.transitionToState(HealthState.DEGRADED, true, "memory_pressure", stateTransitionCounter);
+        }
+      }
+    }
+  }
+
+  private void decrementInFlightSize(String rwEndpoint, String roEndpoint, long bytes) {
+    ClusterHealthState cluster = getCluster(rwEndpoint, roEndpoint);
+    if (cluster != null) {
+      cluster.inFlightWriteSizeBytes.updateAndGet(x -> Math.max(0, x - bytes));
+    }
+  }
+
+  protected static void incrementInFlightSizeStatic(String rwEndpoint, String roEndpoint, long bytes) {
+    if (instance != null) {
+      instance.incrementInFlightSize(rwEndpoint, roEndpoint, bytes);
+    }
+  }
+
+  protected static void decrementInFlightSizeStatic(String rwEndpoint, String roEndpoint, long bytes) {
+    if (instance != null) {
+      instance.decrementInFlightSize(rwEndpoint, roEndpoint, bytes);
+    }
+  }
+
+  @Override
+  public void run() {
+    LOGGER.info("Cache monitor thread started");
+    try {
+      this.stopped = false;
+      while (!stopped) {
+        try {
+          long start = System.currentTimeMillis();
+          boolean hasActiveMonitoring = false;
+          for (ClusterHealthState cluster : clusterStates.values()) {
+            if (cluster.getClusterHealthState() == HealthState.HEALTHY && !healthCheckInHealthyState) {
+              continue;
+            }
+            hasActiveMonitoring = true;
+            executePing(cluster, true);
+            if (cluster.roEndpoint != null) {
+              executePing(cluster, false);
+            }
+          }
+          long duration = System.currentTimeMillis() - start;
+          long target = hasActiveMonitoring ? TimeUnit.SECONDS.toMillis(CACHE_HEALTH_CHECK_INTERVAL)
+              : THREAD_SLEEP_WHEN_INACTIVE_MILLIS;
+          sleep(Math.max(0, target - duration));
+        } catch (InterruptedException e) {
+          throw e;
+        } catch (Exception e) {
+          LOGGER.warning("Cache monitoring exception: " + e.getMessage());
+        }
+      }
+    } catch (InterruptedException e) {
+      LOGGER.warning("Cache monitor interrupted");
+    } finally {
+      this.stopped = true;
+      LOGGER.info("Cache monitor stopped");
+    }
+  }
+
+  private void executePing(ClusterHealthState cluster, boolean isRw) {
+    String endpoint = isRw ? cluster.rwEndpoint : cluster.roEndpoint;
+    boolean success = ping(cluster, isRw);
+
+    HealthState currentState = isRw ? cluster.rwHealthState : cluster.roHealthState;
+    if (success) {
+      if (healthCheckSuccessCounter != null) {
+        healthCheckSuccessCounter.inc();
+      }
+      if (isRw) {
+        cluster.consecutiveRwSuccesses++;
+        cluster.consecutiveRwFailures = 0;
+      } else {
+        cluster.consecutiveRoSuccesses++;
+        cluster.consecutiveRoFailures = 0;
+      }
+      int consecutiveSuccesses = isRw ? cluster.consecutiveRwSuccesses : cluster.consecutiveRoSuccesses;
+      if (consecutiveSuccesses >= CACHE_CONSECUTIVE_SUCCESS_THRESHOLD) {
+        synchronized (cluster) {
+          if (currentState == HealthState.SUSPECT) {
+            cluster.transitionToState(HealthState.HEALTHY, isRw, "consecutive_successes", stateTransitionCounter);
+          } else if (currentState == HealthState.DEGRADED &&
+              cluster.inFlightWriteSizeBytes.get() < inFlightWriteSizeLimitBytes) {
+            cluster.transitionToState(HealthState.HEALTHY, isRw,
+                "consecutive_successes_and_memory_recovered", stateTransitionCounter);
+          }
+        }
+      }
+    } else {
+      LOGGER.warning(() -> "Ping failed for " + endpoint + " (" + (isRw ? "RW" : "RO") + ")");
+      if (healthCheckFailureCounter != null) {
+        healthCheckFailureCounter.inc();
+      }
+      if (isRw) {
+        cluster.consecutiveRwFailures++;
+        cluster.consecutiveRwSuccesses = 0;
+      } else {
+        cluster.consecutiveRoFailures++;
+        cluster.consecutiveRoSuccesses = 0;
+      }
+      int consecutiveFailures = isRw ? cluster.consecutiveRwFailures : cluster.consecutiveRoFailures;
+      synchronized (cluster) {
+        if (currentState == HealthState.HEALTHY && consecutiveFailures >= 1) {
+          cluster.transitionToState(HealthState.SUSPECT, isRw, "first_failure", stateTransitionCounter);
+        } else if (currentState == HealthState.SUSPECT && consecutiveFailures >= CACHE_CONSECUTIVE_FAILURE_THRESHOLD) {
+          cluster.transitionToState(HealthState.DEGRADED, isRw, "consecutive_failures", stateTransitionCounter);
+        }
+      }
+    }
+  }
+
+  protected void sleep(long duration) throws InterruptedException {
+    TimeUnit.MILLISECONDS.sleep(duration);
+  }
+
+  private boolean ping(ClusterHealthState cluster, boolean isRw) {
+    CachePingConnection conn = isRw ? cluster.rwPingConnection : cluster.roPingConnection;
+    if (conn == null || !conn.isOpen()) {
+      return false;
+    }
+    try {
+      return conn.ping();
+    } catch (Exception e) {
+      LOGGER.warning("Ping failed for " + (isRw ? cluster.rwEndpoint : cluster.roEndpoint) + " (" + (isRw ?
"RW" : "RO") + "): " + e.getMessage()); + return false; + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java new file mode 100644 index 000000000..58d466ea2 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java @@ -0,0 +1,1441 @@ +package software.amazon.jdbc.plugin.cache; + +import org.checkerframework.checker.nullness.qual.Nullable; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.InputStream; +import java.io.IOException; +import java.io.Reader; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.net.MalformedURLException; +import java.net.URL; +import java.sql.Array; +import java.sql.Blob; +import java.sql.Clob; +import java.sql.Date; +import java.sql.NClob; +import java.sql.Ref; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.RowId; +import java.sql.SQLException; +import java.sql.SQLWarning; +import java.sql.SQLXML; +import java.sql.Statement; +import java.sql.Time; +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalTime; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.OffsetTime; +import java.time.ZonedDateTime; +import java.time.ZoneId; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import java.util.Calendar; +import java.util.TimeZone; + +public class CachedResultSet implements ResultSet { + + public static class CachedRow { + private final Object[] rowData; + final byte[] @Nullable [] rawData; + + public CachedRow(int numColumns) { + rowData = new Object[numColumns]; + rawData = new byte[numColumns][]; + } + + private void checkColumnIndex(final int columnIndex) throws SQLException { + if (columnIndex < 1 || columnIndex > rowData.length) { + throw new SQLException("Invalid Column Index when operating CachedRow: " + columnIndex); + } + } + + public void put(final int columnIndex, final Object columnValue) throws SQLException { + checkColumnIndex(columnIndex); + rowData[columnIndex-1] = columnValue; + } + + public void putRaw(final int columnIndex, final byte[] rawColumnValue) throws SQLException { + checkColumnIndex(columnIndex); + rawData[columnIndex-1] = rawColumnValue; + } + + public Object get(final int columnIndex) throws SQLException { + checkColumnIndex(columnIndex); + // De-serialize the data object from raw bytes if needed. 
+      if (rowData[columnIndex-1] == null && rawData[columnIndex-1] != null) {
+        try (ByteArrayInputStream bis = new ByteArrayInputStream(rawData[columnIndex - 1]);
+             ObjectInputStream ois = new ObjectInputStream(bis)) {
+          rowData[columnIndex - 1] = ois.readObject();
+          rawData[columnIndex - 1] = null;
+        } catch (ClassNotFoundException e) {
+          throw new SQLException("ClassNotFoundException while de-serializing cached resultSet for column: " + columnIndex, e);
+        } catch (IOException e) {
+          throw new SQLException("IOException while de-serializing cached resultSet for column: " + columnIndex, e);
+        }
+      }
+      return rowData[columnIndex - 1];
+    }
+  }
+
+  protected ArrayList<CachedRow> rows;
+  protected int currentRow;
+  protected boolean wasNullFlag;
+  private final CachedResultSetMetaData metadata;
+  protected static final ZoneId defaultTimeZoneId = ZoneId.systemDefault();
+  protected static final TimeZone defaultTimeZone = TimeZone.getDefault();
+  private final HashMap<String, Integer> columnNames;
+  private volatile boolean closed;
+
+  /**
+   * Create a CachedResultSet out of the original ResultSet queried from the database,
+   * capturing its metadata and all of its rows.
+   * @param resultSet The ResultSet queried from the underlying database (not a CachedResultSet).
+   * @throws SQLException if the source ResultSet or its metadata cannot be read.
+   */
+  public CachedResultSet(final ResultSet resultSet) throws SQLException {
+    ResultSetMetaData srcMetadata = resultSet.getMetaData();
+    final int numColumns = srcMetadata.getColumnCount();
+    CachedResultSetMetaData.Field[] fields = new CachedResultSetMetaData.Field[numColumns];
+    for (int i = 0; i < numColumns; i++) {
+      fields[i] = new CachedResultSetMetaData.Field(srcMetadata, i+1);
+    }
+    metadata = new CachedResultSetMetaData(fields);
+    rows = new ArrayList<>();
+    this.columnNames = new HashMap<>();
+    for (int i = 1; i <= numColumns; i++) {
+      this.columnNames.put(srcMetadata.getColumnLabel(i), i);
+    }
+    while (resultSet.next()) {
+      final CachedRow row = new CachedRow(numColumns);
+      for (int i = 1; i <= numColumns; ++i) {
+        Object rowObj = resultSet.getObject(i);
+        // For SQLXML object, convert into CachedSQLXML object that is serializable
+        if (rowObj instanceof SQLXML) {
+          rowObj = new CachedSQLXML(((SQLXML)rowObj).getString());
+        }
+        row.put(i, rowObj);
+      }
+      rows.add(row);
+    }
+    currentRow = -1;
+    closed = false;
+    wasNullFlag = false;
+  }
+
+  private CachedResultSet(final CachedResultSetMetaData md, final ArrayList<CachedRow> resultRows) throws SQLException {
+    int numColumns = md.getColumnCount();
+    this.columnNames = new HashMap<>();
+    for (int i = 1; i <= numColumns; i++) {
+      this.columnNames.put(md.getColumnLabel(i), i);
+    }
+    currentRow = -1;
+    rows = resultRows;
+    metadata = md;
+    closed = false;
+    wasNullFlag = false;
+  }
+
+  // Serialize the content of metadata and data rows for the current CachedResultSet into a byte array
+  public byte[] serializeIntoByteArray() throws SQLException {
+    // Serialize the metadata and then the rows
+    try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
+         ObjectOutputStream output = new ObjectOutputStream(baos)) {
+      output.writeObject(metadata);
+      output.writeInt(rows.size());
+      int numColumns = metadata.getColumnCount();
+      while (this.next()) {
+        // serialize individual column fields in each row
+        CachedRow row = rows.get(currentRow);
+        for (int i = 0; i < numColumns; i++) {
+          try (ByteArrayOutputStream objBytes = new ByteArrayOutputStream();
+               ObjectOutputStream objStream = new ObjectOutputStream(objBytes)) {
+            objStream.writeObject(row.get(i +
1));
+            objStream.flush();
+            byte[] dataByteArray = objBytes.toByteArray();
+            int serializedLength = dataByteArray.length;
+            output.writeInt(serializedLength);
+            output.write(dataByteArray, 0, serializedLength);
+          }
+        }
+      }
+      output.flush();
+      return baos.toByteArray();
+    } catch (IOException e) {
+      throw new SQLException("Error while serializing the ResultSet for caching: ", e);
+    }
+  }
+
+  /**
+   * Form a ResultSet from the raw data from the cache server. Each of the column objects are stored as
+   * raw bytes and the actual de-serialization into Java objects will happen lazily upon access later on.
+   */
+  public static ResultSet deserializeFromByteArray(byte[] data) throws SQLException {
+    try (ByteArrayInputStream bis = new ByteArrayInputStream(data);
+         ObjectInputStream ois = new ObjectInputStream(bis)) {
+      CachedResultSetMetaData metadata = (CachedResultSetMetaData) ois.readObject();
+      int numRows = ois.readInt();
+      int numColumns = metadata.getColumnCount();
+      ArrayList<CachedRow> resultRows = new ArrayList<>(numRows);
+      for (int i = 0; i < numRows; i++) {
+        // Store the raw bytes for each column object in CachedRow
+        final CachedRow row = new CachedRow(numColumns);
+        for (int j = 0; j < numColumns; j++) {
+          int nextObjSize = ois.readInt(); // The size of the next serialized object in its raw bytes form
+          byte[] objData = new byte[nextObjSize];
+          int lengthRead = 0;
+          while (lengthRead < nextObjSize) {
+            int bytesRead = ois.read(objData, lengthRead, nextObjSize-lengthRead);
+            if (bytesRead == -1) {
+              throw new SQLException("End of stream reached when reading the data for CachedResultSet");
+            }
+            lengthRead += bytesRead;
+          }
+          row.putRaw(j+1, objData);
+        }
+        resultRows.add(row);
+      }
+      return new CachedResultSet(metadata, resultRows);
+    } catch (ClassNotFoundException e) {
+      throw new SQLException("ClassNotFoundException while de-serializing resultSet for caching", e);
+    } catch (IOException e) {
+      throw new SQLException("IOException while de-serializing resultSet for caching", e);
+    }
+  }
+
+  @Override
+  public boolean next() throws SQLException {
+    if (rows.isEmpty()) return false;
+    if (this.currentRow >= rows.size() - 1) {
+      afterLast();
+      return false;
+    }
+    currentRow++;
+    return true;
+  }
+
+  @Override
+  public void close() throws SQLException {
+    currentRow = rows.size() - 1;
+    closed = true;
+  }
+
+  @Override
+  public boolean wasNull() throws SQLException {
+    if (isClosed()) {
+      throw new SQLException("This result set is closed");
+    }
+    return this.wasNullFlag;
+  }
+
+  @Override
+  public String getString(final int columnIndex) throws SQLException {
+    Object value = checkAndGetColumnValue(columnIndex);
+    if (value == null) return null;
+    return value.toString();
+  }
+
+  @Override
+  public boolean getBoolean(final int columnIndex) throws SQLException {
+    final Object val = checkAndGetColumnValue(columnIndex);
+    if (val == null) return false;
+    if (val instanceof Boolean) return (Boolean) val;
+    if (val instanceof Number) return ((Number) val).intValue() != 0;
+    return Boolean.parseBoolean(val.toString());
+  }
+
+  @Override
+  public byte getByte(final int columnIndex) throws SQLException {
+    final Object val = checkAndGetColumnValue(columnIndex);
+    if (val == null) return 0;
+    if (val instanceof Byte) return (Byte) val;
+    if (val instanceof Number) return ((Number) val).byteValue();
+    return Byte.parseByte(val.toString());
+  }
+
+  @Override
+  public short getShort(final int columnIndex) throws SQLException {
+    final Object val = checkAndGetColumnValue(columnIndex);
+    if (val == null) return 0;
+ if (val instanceof Short) return (Short) val; + if (val instanceof Number) return ((Number) val).shortValue(); + return Short.parseShort(val.toString()); + } + + @Override + public int getInt(final int columnIndex) throws SQLException { + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return 0; + if (val instanceof Integer) return (Integer) val; + if (val instanceof Number) return ((Number) val).intValue(); + return Integer.parseInt(val.toString()); + } + + @Override + public long getLong(final int columnIndex) throws SQLException { + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return 0; + if (val instanceof Long) return (Long) val; + if (val instanceof Number) return ((Number) val).longValue(); + return Long.parseLong(val.toString()); + } + + @Override + public float getFloat(final int columnIndex) throws SQLException { + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return 0; + if (val instanceof Float) return (Float) val; + if (val instanceof Number) return ((Number) val).floatValue(); + return Float.parseFloat(val.toString()); + } + + @Override + public double getDouble(final int columnIndex) throws SQLException { + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return 0; + if (val instanceof Double) return (Double) val; + if (val instanceof Number) return ((Number) val).doubleValue(); + return Double.parseDouble(val.toString()); + } + + @Override + @Deprecated + public BigDecimal getBigDecimal(final int columnIndex, final int scale) throws SQLException { + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return null; + if (val instanceof BigDecimal) return (BigDecimal) val; + if (val instanceof Number) return new BigDecimal(((Number)val).doubleValue()).setScale(scale, RoundingMode.HALF_UP); + return new BigDecimal(Double.parseDouble(val.toString())).setScale(scale, RoundingMode.HALF_UP); + } + + @Override + public byte[] getBytes(final int columnIndex) throws SQLException { + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return null; + if (val instanceof byte[]) return (byte[]) val; + // Convert non-byte data to string, then to bytes (standard JDBC behavior) + return val.toString().getBytes(); + } + + private Date convertToDate(Object dateObj, Calendar cal) throws SQLException { + if (dateObj == null) return null; + if (dateObj instanceof Date) return (Date)dateObj; + if (dateObj instanceof Number) return new Date(((Number)dateObj).longValue()); + if (dateObj instanceof LocalDate) { + // Convert the LocalDate for the specified time zone into Date representing + // the same instant of time for the default time zone. 
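+      // Worked example: with cal in UTC and a default zone of America/New_York (UTC-5),
+      // LocalDate 2024-03-10 becomes 2024-03-10T00:00 UTC, i.e. 2024-03-09T19:00 local,
+      // so the returned Date is 2024-03-09.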
+      LocalDate localDate = (LocalDate)dateObj;
+      if (cal == null) return Date.valueOf(localDate);
+      LocalDateTime localDateTime = localDate.atStartOfDay();
+      ZonedDateTime originalZonedDateTime = localDateTime.atZone(cal.getTimeZone().toZoneId());
+      ZonedDateTime targetZonedDateTime = originalZonedDateTime.withZoneSameInstant(defaultTimeZoneId);
+      return Date.valueOf(targetZonedDateTime.toLocalDate());
+    }
+    if (dateObj instanceof Timestamp) {
+      Timestamp timestamp = (Timestamp) dateObj;
+      long millis = timestamp.getTime();
+      if (cal == null) return new Date(millis);
+      long adjustedMillis = millis - cal.getTimeZone().getOffset(millis)
+          + defaultTimeZone.getOffset(millis);
+      return new Date(adjustedMillis);
+    }
+
+    // Note: normally the user should properly store the Date object in the DB column and
+    // the underlying PG/MySQL/MariaDB driver would convert it into Date already in getObject()
+    // prior to reaching this point in our caching logic. This is mainly to handle the case when the user
+    // stores a generic string in the DB column and wants to convert this into Date. We try to do a
+    // best-effort string parsing into Date with standard format "YYYY-MM-DD". The user is then
+    // expected to handle parsing failure and implement custom logic to fetch this as String.
+    return Date.valueOf(dateObj.toString());
+  }
+
+  @Override
+  public Date getDate(final int columnIndex) throws SQLException {
+    // Delegate to convertToDate, which handles Date, Number (epoch millis), LocalDate,
+    // Timestamp and string-typed column values.
+    return convertToDate(checkAndGetColumnValue(columnIndex), null);
+  }
+
+  private Time convertToTime(Object timeObj, Calendar cal) throws SQLException {
+    if (timeObj == null) return null;
+    if (timeObj instanceof Time) return (Time) timeObj;
+    if (timeObj instanceof Number) return new Time(((Number)timeObj).longValue()); // TODO: test
+    if (timeObj instanceof LocalTime) {
+      // Convert the LocalTime for the specified time zone into Time representing
+      // the same instant of time for the default time zone.
+      LocalTime localTime = (LocalTime)timeObj;
+      if (cal == null) return Time.valueOf(localTime);
+      LocalDateTime localDateTime = LocalDateTime.of(LocalDate.now(), localTime);
+      ZonedDateTime originalZonedDateTime = localDateTime.atZone(cal.getTimeZone().toZoneId());
+      ZonedDateTime targetZonedDateTime = originalZonedDateTime.withZoneSameInstant(defaultTimeZoneId);
+      return Time.valueOf(targetZonedDateTime.toLocalTime());
+    }
+    if (timeObj instanceof OffsetTime) {
+      OffsetTime offsetTime = (OffsetTime) timeObj;
+      if (cal == null) {
+        // Convert to default timezone using ZonedDateTime conversion
+        ZonedDateTime zonedDateTime = offsetTime.atDate(LocalDate.now())
+            .atZoneSameInstant(defaultTimeZoneId);
+        return Time.valueOf(zonedDateTime.toLocalTime());
+      } else {
+        // Convert to specified calendar timezone
+        ZonedDateTime zonedDateTime = offsetTime.atDate(LocalDate.now())
+            .atZoneSameInstant(cal.getTimeZone().toZoneId());
+        return Time.valueOf(zonedDateTime.toLocalTime());
+      }
+    }
+    if (timeObj instanceof Timestamp) {
+      Timestamp timestamp = (Timestamp) timeObj;
+      long millis = timestamp.getTime();
+      if (cal == null) return new Time(millis);
+      long adjustedMillis = millis - cal.getTimeZone().getOffset(millis)
+          + defaultTimeZone.getOffset(millis);
+      return new Time(adjustedMillis);
+    }
+
+    // Note: normally the user should properly store the Time object in the DB column and
+    // the underlying PG/MySQL/MariaDB driver would convert it into Time already in getObject()
+    // prior to reaching this point in our caching logic.
This is mainly to handle the case when the user + // stores a generic string in the DB column and wants to convert this into Time. We try to do a + // best-effort string parsing into Time with standard format "HH:MM:SS". The user is then + // expected to handle parsing failure and implement custom logic to fetch this as String. + return Time.valueOf(timeObj.toString()); + } + + @Override + public Time getTime(final int columnIndex) throws SQLException { + return convertToTime(checkAndGetColumnValue(columnIndex), null); + } + + private Timestamp convertToTimestamp(Object timestampObj, Calendar calendar) { + if (timestampObj == null) return null; + if (timestampObj instanceof Timestamp) return (Timestamp) timestampObj; + if (timestampObj instanceof Number) return new Timestamp(((Number)timestampObj).longValue()); + if (timestampObj instanceof LocalDateTime) { + // Convert LocalDateTime based on the specified calendar time zone info into a + // Timestamp based on the JVM's default time zone representing the same instant + long epochTimeInMillis; + LocalDateTime localTime = (LocalDateTime)timestampObj; + if (calendar != null) { + epochTimeInMillis = localTime.atZone(calendar.getTimeZone().toZoneId()).toInstant().toEpochMilli(); + } else { + epochTimeInMillis = localTime.atZone(defaultTimeZoneId).toInstant().toEpochMilli(); + } + return new Timestamp(epochTimeInMillis); + } + if (timestampObj instanceof OffsetDateTime) { + return Timestamp.from(((OffsetDateTime)timestampObj).toInstant()); + } + if (timestampObj instanceof ZonedDateTime) { + return Timestamp.from(((ZonedDateTime)timestampObj).toInstant()); + } + + // Note: normally the user should properly store the Timestamp/DateTime object in the DB column and + // the underlying PG/MySQL/MariaDB driver would convert it into Timestamp already in getObject() + // prior to reaching this point in our caching logic. This is mainly to handle the case when the user + // stores a generic string in the DB column and wants to convert this into Timestamp. We try to do a + // best-effort string parsing into Timestamp with standard format "YYYY-MM-DD HH:MM:SS". The user is + // then expected to handle parsing failure and implement custom logic to fetch this as String. 
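+    // e.g. a TEXT column holding "2024-01-15 10:30:00" reaches this fallback and is parsed
+    // by Timestamp.valueOf; any other format throws IllegalArgumentException.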
+ return Timestamp.valueOf(timestampObj.toString()); + } + + @Override + public Timestamp getTimestamp(final int columnIndex) throws SQLException { + return convertToTimestamp(checkAndGetColumnValue(columnIndex), null); + } + + @Override + public InputStream getAsciiStream(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + @Deprecated + public InputStream getUnicodeStream(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public InputStream getBinaryStream(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String getString(final String columnLabel) throws SQLException { + return getString(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public boolean getBoolean(final String columnLabel) throws SQLException { + return getBoolean(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public byte getByte(final String columnLabel) throws SQLException { + return getByte(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public short getShort(final String columnLabel) throws SQLException { + return getShort(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public int getInt(final String columnLabel) throws SQLException { + return getInt(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public long getLong(final String columnLabel) throws SQLException { + return getLong(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public float getFloat(final String columnLabel) throws SQLException { + return getFloat(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public double getDouble(final String columnLabel) throws SQLException { + return getDouble(checkAndGetColumnIndex(columnLabel)); + } + + @Override + @Deprecated + public BigDecimal getBigDecimal(final String columnLabel, final int scale) throws SQLException { + return getBigDecimal(checkAndGetColumnIndex(columnLabel), scale); + } + + @Override + public byte[] getBytes(final String columnLabel) throws SQLException { + return getBytes(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public Date getDate(final String columnLabel) throws SQLException { + return getDate(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public Time getTime(final String columnLabel) throws SQLException { + return getTime(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public Timestamp getTimestamp(final String columnLabel) throws SQLException { + return getTimestamp(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public InputStream getAsciiStream(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + @Deprecated + public InputStream getUnicodeStream(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public InputStream getBinaryStream(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public SQLWarning getWarnings() throws SQLException { + return null; + } + + @Override + public void clearWarnings() throws SQLException { + // no-op + } + + @Override + public String getCursorName() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public ResultSetMetaData getMetaData() throws SQLException { + return metadata; + } + + private void checkCurrentRow() throws SQLException { + if (this.currentRow 
< 0 || this.currentRow >= this.rows.size()) {
+      throw new SQLException("The current row index " + this.currentRow + " is out of range.");
+    }
+  }
+
+  @Override
+  public Object getObject(final int columnIndex) throws SQLException {
+    checkCurrentRow();
+    return checkAndGetColumnValue(columnIndex);
+  }
+
+  @Override
+  public Object getObject(final String columnLabel) throws SQLException {
+    checkCurrentRow();
+    return checkAndGetColumnValue(checkAndGetColumnIndex(columnLabel));
+  }
+
+  // Check the cursor position and column index are proper, and return the value of the column
+  // from the current row. Validating the cursor here turns an out-of-range access into a
+  // SQLException rather than an unchecked IndexOutOfBoundsException.
+  private Object checkAndGetColumnValue(final int columnIndex) throws SQLException {
+    checkCurrentRow();
+    if (columnIndex < 1 || columnIndex > this.columnNames.size()) throw new SQLException("Column out of bounds");
+    final CachedRow row = this.rows.get(this.currentRow);
+    final Object val = row.get(columnIndex);
+    this.wasNullFlag = (val == null);
+    return val;
+  }
+
+  // Check column label exists and returns the column index corresponding to the column name
+  private int checkAndGetColumnIndex(final String columnLabel) throws SQLException {
+    final Integer colIndex = columnNames.get(columnLabel);
+    if (colIndex == null) throw new SQLException("Column not found: " + columnLabel);
+    return colIndex;
+  }
+
+  @Override
+  public int findColumn(final String columnLabel) throws SQLException {
+    final Integer colIndex = columnNames.get(columnLabel);
+    if (colIndex == null) {
+      throw new SQLException("The column " + columnLabel + " is not found in this ResultSet.");
+    }
+    return colIndex;
+  }
+
+  @Override
+  public Reader getCharacterStream(final int columnIndex) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Reader getCharacterStream(final String columnLabel) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public BigDecimal getBigDecimal(final int columnIndex) throws SQLException {
+    final Object val = checkAndGetColumnValue(columnIndex);
+    if (val == null) return null;
+    if (val instanceof BigDecimal) return (BigDecimal) val;
+    if (val instanceof Number) return BigDecimal.valueOf(((Number) val).doubleValue());
+    return new BigDecimal(val.toString());
+  }
+
+  @Override
+  public BigDecimal getBigDecimal(final String columnLabel) throws SQLException {
+    return getBigDecimal(checkAndGetColumnIndex(columnLabel));
+  }
+
+  @Override
+  public boolean isBeforeFirst() throws SQLException {
+    return this.currentRow < 0;
+  }
+
+  @Override
+  public boolean isAfterLast() throws SQLException {
+    return this.currentRow >= this.rows.size();
+  }
+
+  @Override
+  public boolean isFirst() throws SQLException {
+    return this.currentRow == 0 && !this.rows.isEmpty();
+  }
+
+  @Override
+  public boolean isLast() throws SQLException {
+    return this.currentRow == (this.rows.size() - 1) && !this.rows.isEmpty();
+  }
+
+  @Override
+  public void beforeFirst() throws SQLException {
+    this.currentRow = -1;
+  }
+
+  @Override
+  public void afterLast() throws SQLException {
+    this.currentRow = this.rows.size();
+  }
+
+  @Override
+  public boolean first() throws SQLException {
+    this.currentRow = 0;
+    return this.currentRow < this.rows.size();
+  }
+
+  @Override
+  public boolean last() throws SQLException {
+    this.currentRow = this.rows.size() - 1;
+    return this.currentRow >= 0;
+  }
+
+  @Override
+  public int getRow() throws SQLException {
+    if (this.currentRow >= 0 && this.currentRow < this.rows.size()) {
+      return this.currentRow + 1;
+    }
+    return 0;
+  }
+
+  @Override
+  public boolean
absolute(final int row) throws SQLException { + if (row == 0) { + this.beforeFirst(); + return false; + } else { + int rowsSize = this.rows.size(); + if (row < 0) { + if (row < -rowsSize) { + this.beforeFirst(); + return false; + } + this.currentRow = rowsSize + row; + } else { // row > 0 + if (row > rowsSize) { + this.afterLast(); + return false; + } + this.currentRow = row - 1; + } + } + return true; + } + + @Override + public boolean relative(final int rows) throws SQLException { + this.currentRow += rows; + if (this.currentRow < 0) { + this.beforeFirst(); + return false; + } else if (this.currentRow >= this.rows.size()) { + this.afterLast(); + return false; + } + return true; + } + + @Override + public boolean previous() throws SQLException { + if (this.currentRow < 1) { + this.beforeFirst(); + return false; + } + this.currentRow--; + return true; + } + + @Override + public void setFetchDirection(final int direction) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getFetchDirection() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void setFetchSize(final int rows) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getFetchSize() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getType() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getConcurrency() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean rowUpdated() throws SQLException { + return false; + } + + @Override + public boolean rowInserted() throws SQLException { + return false; + } + + @Override + public boolean rowDeleted() throws SQLException { + return false; + } + + @Override + public void updateNull(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBoolean(final int columnIndex, final boolean x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateByte(final int columnIndex, final byte x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateShort(final int columnIndex, final short x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateInt(final int columnIndex, final int x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateLong(final int columnIndex, final long x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateFloat(final int columnIndex, final float x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateDouble(final int columnIndex, final double x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBigDecimal(final int columnIndex, final BigDecimal x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateString(final int columnIndex, final String x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBytes(final int columnIndex, final byte[] x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateDate(final int columnIndex, final Date 
x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateTime(final int columnIndex, final Time x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateTimestamp(final int columnIndex, final Timestamp x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(final int columnIndex, final InputStream x, final int length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(final int columnIndex, final InputStream x, final int length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(final int columnIndex, final Reader x, final int length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateObject(final int columnIndex, final Object x, final int scaleOrLength) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateObject(final int columnIndex, final Object x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNull(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBoolean(final String columnLabel, final boolean x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateByte(final String columnLabel, final byte x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateShort(final String columnLabel, final short x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateInt(final String columnLabel, final int x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateLong(final String columnLabel, final long x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateFloat(final String columnLabel, final float x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateDouble(final String columnLabel, final double x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBigDecimal(final String columnLabel, final BigDecimal x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateString(final String columnLabel, final String x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBytes(final String columnLabel, final byte[] x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateDate(final String columnLabel, final Date x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateTime(final String columnLabel, final Time x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateTimestamp(final String columnLabel, final Timestamp x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(final String columnLabel, final InputStream x, final int length) + throws SQLException { + throw new UnsupportedOperationException(); + } + 
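+  // Illustrative sketch (not part of this class's API): the intended round trip through the
+  // cache. A result set is snapshotted, serialized to bytes for the cache server, and later
+  // replayed from those bytes; serializeIntoByteArray() consumes the cursor, so the
+  // deserialized copy is the one handed back to callers. process() is a hypothetical consumer.
+  //
+  //   byte[] payload = new CachedResultSet(rs).serializeIntoByteArray();
+  //   ResultSet replay = CachedResultSet.deserializeFromByteArray(payload);
+  //   while (replay.next()) { process(replay); }
+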
+  @Override
+  public void updateBinaryStream(final String columnLabel, final InputStream x, final int length)
+      throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateCharacterStream(final String columnLabel, final Reader reader, final int length)
+      throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateObject(final String columnLabel, final Object x, final int scaleOrLength)
+      throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateObject(final String columnLabel, final Object x) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void insertRow() throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateRow() throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void deleteRow() throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void refreshRow() throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void cancelRowUpdates() throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void moveToInsertRow() throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void moveToCurrentRow() throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Statement getStatement() throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Object getObject(final int columnIndex, final Map<String, Class<?>> map)
+      throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Ref getRef(final int columnIndex) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Blob getBlob(final int columnIndex) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Clob getClob(final int columnIndex) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Array getArray(final int columnIndex) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Object getObject(final String columnLabel, final Map<String, Class<?>> map)
+      throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Ref getRef(final String columnLabel) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Blob getBlob(final String columnLabel) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Clob getClob(final String columnLabel) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Array getArray(final String columnLabel) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Date getDate(final int columnIndex, final Calendar cal) throws SQLException {
+    return convertToDate(checkAndGetColumnValue(columnIndex), cal);
+  }
+
+  @Override
+  public Date getDate(final String columnLabel, final Calendar cal) throws SQLException {
+    return getDate(checkAndGetColumnIndex(columnLabel), cal);
+  }
+
+  @Override
+  public Time getTime(final int columnIndex, final Calendar cal) throws SQLException {
+    return convertToTime(checkAndGetColumnValue(columnIndex), cal);
+  }
+
+  @Override
+  public Time getTime(final String columnLabel,
final Calendar cal) throws SQLException { + return getTime(checkAndGetColumnIndex(columnLabel), cal); + } + + @Override + public Timestamp getTimestamp(final int columnIndex, final Calendar cal) throws SQLException { + return convertToTimestamp(checkAndGetColumnValue(columnIndex), cal); + } + + @Override + public Timestamp getTimestamp(final String columnLabel, final Calendar cal) throws SQLException { + return getTimestamp(checkAndGetColumnIndex(columnLabel), cal); + } + + @Override + public URL getURL(final int columnIndex) throws SQLException { + Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return null; + if (val instanceof URL) return (URL) val; + try { + return new URL(val.toString()); + } catch (MalformedURLException e) { + throw new SQLException("Cannot extract url: " + val, e); + } + } + + @Override + public URL getURL(final String columnLabel) throws SQLException { + return getURL(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public void updateRef(final int columnIndex, final Ref x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateRef(final String columnLabel, final Ref x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(final int columnIndex, final Blob x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(final String columnLabel, final Blob x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateClob(final int columnIndex, final Clob x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateClob(final String columnLabel, final Clob x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateArray(final int columnIndex, final Array x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateArray(final String columnLabel, final Array x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public RowId getRowId(final int columnIndex) throws SQLException { + Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return null; + if (val instanceof RowId) return (RowId) val; + throw new SQLException("Cannot extract rowId: " + val); + } + + @Override + public RowId getRowId(final String columnLabel) throws SQLException { + return getRowId(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public void updateRowId(final int columnIndex, final RowId x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateRowId(final String columnLabel, final RowId x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getHoldability() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isClosed() throws SQLException { + return closed; + } + + @Override + @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) + public void updateNString(final int columnIndex, final String nString) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) + public void updateNString(final String columnLabel, final String nString) throws SQLException { + throw new UnsupportedOperationException(); + } + + 
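+  // Like the updateXxx overloads above, the NClob/NString mutators below are unsupported by
+  // design: a CachedResultSet is an immutable snapshot, so no write path exists.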
@Override + @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) + public void updateNClob(final int columnIndex, final NClob nClob) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) + public void updateNClob(final String columnLabel, final NClob nClob) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + @SuppressWarnings("checkstyle:MethodName") + public NClob getNClob(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public NClob getNClob(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public SQLXML getSQLXML(final int columnIndex) throws SQLException { + Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return null; + if (val instanceof SQLXML) return (SQLXML) val; + return new CachedSQLXML(val.toString()); + } + + @Override + public SQLXML getSQLXML(final String columnLabel) throws SQLException { + return getSQLXML(checkAndGetColumnIndex(columnLabel)); + } + + @Override + public void updateSQLXML(final int columnIndex, final SQLXML xmlObject) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateSQLXML(final String columnLabel, final SQLXML xmlObject) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String getNString(final int columnIndex) throws SQLException { + return getString(columnIndex); + } + + @Override + public String getNString(final String columnLabel) throws SQLException { + return getString(columnLabel); + } + + @Override + public Reader getNCharacterStream(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Reader getNCharacterStream(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNCharacterStream(final int columnIndex, final Reader x, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNCharacterStream(final String columnLabel, final Reader reader, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(final int columnIndex, final InputStream x, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(final int columnIndex, final InputStream x, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(final int columnIndex, final Reader x, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(final String columnLabel, final InputStream x, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(final String columnLabel, final InputStream x, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(final String columnLabel, final Reader reader, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void 
updateBlob(final int columnIndex, final InputStream inputStream, final long length)
+      throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateBlob(final String columnLabel, final InputStream inputStream, final long length)
+      throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateClob(final int columnIndex, final Reader reader, final long length)
+      throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateClob(final String columnLabel, final Reader reader, final long length)
+      throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateNClob(final int columnIndex, final Reader reader, final long length)
+      throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateNClob(final String columnLabel, final Reader reader, final long length)
+      throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateNCharacterStream(final int columnIndex, final Reader x) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateNCharacterStream(final String columnLabel, final Reader reader) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateAsciiStream(final int columnIndex, final InputStream x) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateBinaryStream(final int columnIndex, final InputStream x) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateCharacterStream(final int columnIndex, final Reader x) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateAsciiStream(final String columnLabel, final InputStream x) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateBinaryStream(final String columnLabel, final InputStream x) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateCharacterStream(final String columnLabel, final Reader reader) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateBlob(final int columnIndex, final InputStream inputStream) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateBlob(final String columnLabel, final InputStream inputStream) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateClob(final int columnIndex, final Reader reader) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateClob(final String columnLabel, final Reader reader) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateNClob(final int columnIndex, final Reader reader) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void updateNClob(final String columnLabel, final Reader reader) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public <T> T getObject(final int columnIndex, final Class<T> type) throws SQLException {
+    return type.cast(getObject(columnIndex));
+  }
+
+  @Override
+  public <T> T getObject(final String columnLabel, final Class<T> type) throws SQLException {
+    return type.cast(getObject(columnLabel));
+  }
+
+  @Override
+  public <T> T unwrap(final Class<T> iface) throws SQLException {
+    if (iface.isAssignableFrom(this.getClass())) {
+      return iface.cast(this);
+    } else {
+      throw new SQLException("Cannot unwrap to " + iface.getName());
+    }
+  }
+
+  @Override
+  public boolean isWrapperFor(final Class<?> iface) throws SQLException {
+    return iface != null && iface.isAssignableFrom(this.getClass());
+  }
+}
diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSetMetaData.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSetMetaData.java
new file mode 100644
index 000000000..bf295cb6b
--- /dev/null
+++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSetMetaData.java
@@ -0,0 +1,180 @@
+package software.amazon.jdbc.plugin.cache;
+
+import java.io.Serializable;
+import java.sql.ResultSetMetaData;
+import java.sql.SQLException;
+
+class CachedResultSetMetaData implements ResultSetMetaData, Serializable {
+  protected final Field[] columns;
+
+  protected static class Field implements Serializable {
+    String catalog;
+    String className;
+    String label;
+    String name;
+    String typeName;
+    int type;
+    int displaySize;
+    int precision;
+    String tableName;
+    int scale;
+    String schemaName;
+    boolean isAutoIncrement;
+    boolean isCaseSensitive;
+    boolean isCurrency;
+    boolean isDefinitelyWritable;
+    int isNullable;
+    boolean isReadOnly;
+    boolean isSearchable;
+    boolean isSigned;
+    boolean isWritable;
+
+    protected Field(final ResultSetMetaData srcMetadata, int column) throws SQLException {
+      catalog = srcMetadata.getCatalogName(column);
+      className = srcMetadata.getColumnClassName(column);
+      label = srcMetadata.getColumnLabel(column);
+      name = srcMetadata.getColumnName(column);
+      typeName = srcMetadata.getColumnTypeName(column);
+      type = srcMetadata.getColumnType(column);
+      displaySize = srcMetadata.getColumnDisplaySize(column);
+      precision = srcMetadata.getPrecision(column);
+      tableName = srcMetadata.getTableName(column);
+      scale = srcMetadata.getScale(column);
+      schemaName = srcMetadata.getSchemaName(column);
+      isAutoIncrement = srcMetadata.isAutoIncrement(column);
+      isCaseSensitive = srcMetadata.isCaseSensitive(column);
+      isCurrency = srcMetadata.isCurrency(column);
+      isDefinitelyWritable = srcMetadata.isDefinitelyWritable(column);
+      isNullable = srcMetadata.isNullable(column);
+      isReadOnly = srcMetadata.isReadOnly(column);
+      isSearchable = srcMetadata.isSearchable(column);
+      isSigned = srcMetadata.isSigned(column);
+      isWritable = srcMetadata.isWritable(column);
+    }
+  }
+
+  CachedResultSetMetaData(Field[] columns) {
+    this.columns = columns;
+  }
+
+  @Override
+  public int getColumnCount() throws SQLException {
+    return columns.length;
+  }
+
+  private Field getColumns(final int column) throws SQLException {
+    if (column < 1 || column > columns.length)
+      throw new SQLException("Wrong column number: " + column);
+    return columns[column - 1];
+  }
+
+  @Override
+  public boolean isAutoIncrement(int column) throws SQLException {
+    return getColumns(column).isAutoIncrement;
+  }
+
+  @Override
+  public boolean isCaseSensitive(int column) throws SQLException {
+    return getColumns(column).isCaseSensitive;
+  }
+
+  @Override
+  public boolean isSearchable(int column) throws SQLException {
+    return getColumns(column).isSearchable;
+  }
+
+  @Override
+  public boolean isCurrency(int column) throws SQLException {
+    return getColumns(column).isCurrency;
+  }
+
+  @Override
+  public int isNullable(int
column) throws SQLException {
+    return getColumns(column).isNullable;
+  }
+
+  @Override
+  public boolean isSigned(int column) throws SQLException {
+    return getColumns(column).isSigned;
+  }
+
+  @Override
+  public int getColumnDisplaySize(int column) throws SQLException {
+    return getColumns(column).displaySize;
+  }
+
+  @Override
+  public String getColumnLabel(int column) throws SQLException {
+    return getColumns(column).label;
+  }
+
+  @Override
+  public String getColumnName(int column) throws SQLException {
+    return getColumns(column).name;
+  }
+
+  @Override
+  public String getSchemaName(int column) throws SQLException {
+    return getColumns(column).schemaName;
+  }
+
+  @Override
+  public int getPrecision(int column) throws SQLException {
+    return getColumns(column).precision;
+  }
+
+  @Override
+  public int getScale(int column) throws SQLException {
+    return getColumns(column).scale;
+  }
+
+  @Override
+  public String getTableName(int column) throws SQLException {
+    return getColumns(column).tableName;
+  }
+
+  @Override
+  public String getCatalogName(int column) throws SQLException {
+    return getColumns(column).catalog;
+  }
+
+  @Override
+  public int getColumnType(int column) throws SQLException {
+    return getColumns(column).type;
+  }
+
+  @Override
+  public String getColumnTypeName(int column) throws SQLException {
+    return getColumns(column).typeName;
+  }
+
+  @Override
+  public boolean isReadOnly(int column) throws SQLException {
+    return getColumns(column).isReadOnly;
+  }
+
+  @Override
+  public boolean isWritable(int column) throws SQLException {
+    return getColumns(column).isWritable;
+  }
+
+  @Override
+  public boolean isDefinitelyWritable(int column) throws SQLException {
+    return getColumns(column).isDefinitelyWritable;
+  }
+
+  @Override
+  public String getColumnClassName(int column) throws SQLException {
+    return getColumns(column).className;
+  }
+
+  @Override
+  public <T> T unwrap(Class<T> iface) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public boolean isWrapperFor(Class<?> iface) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+}
diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSQLXML.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSQLXML.java
new file mode 100644
index 000000000..a49240172
--- /dev/null
+++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSQLXML.java
@@ -0,0 +1,118 @@
+package software.amazon.jdbc.plugin.cache;
+
+import org.xml.sax.InputSource;
+import org.xml.sax.XMLReader;
+import org.xml.sax.helpers.XMLReaderFactory;
+import javax.xml.parsers.DocumentBuilder;
+import javax.xml.parsers.DocumentBuilderFactory;
+import javax.xml.stream.XMLInputFactory;
+import javax.xml.stream.XMLStreamReader;
+import javax.xml.transform.Result;
+import javax.xml.transform.Source;
+import javax.xml.transform.dom.DOMSource;
+import javax.xml.transform.sax.SAXSource;
+import javax.xml.transform.stax.StAXSource;
+import javax.xml.transform.stream.StreamSource;
+import java.io.ByteArrayInputStream;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.Reader;
+import java.io.Serializable;
+import java.io.StringReader;
+import java.io.Writer;
+import java.nio.charset.StandardCharsets;
+import java.sql.SQLException;
+import java.sql.SQLXML;
+
+public class CachedSQLXML implements SQLXML, Serializable {
+  private boolean freed;
+  private String data;
+
+  public CachedSQLXML(String data) {
+    this.data = data;
+    this.freed = false;
+  }
+
+  @Override
public void free() throws SQLException {
+    if (this.freed) return;
+    this.data = null;
+    this.freed = true;
+  }
+
+  private void checkFreed() throws SQLException {
+    if (this.freed) {
+      throw new SQLException("This SQLXML object has already been freed.");
+    }
+  }
+
+  @Override
+  public InputStream getBinaryStream() throws SQLException {
+    checkFreed();
+    if (this.data == null) return null;
+    return new ByteArrayInputStream(this.data.getBytes(StandardCharsets.UTF_8));
+  }
+
+  @Override
+  public OutputStream setBinaryStream() throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Reader getCharacterStream() throws SQLException {
+    checkFreed();
+    if (this.data == null) return null;
+    return new StringReader(this.data);
+  }
+
+  @Override
+  public Writer setCharacterStream() throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public String getString() throws SQLException {
+    checkFreed();
+    return this.data;
+  }
+
+  @Override
+  public void setString(String value) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  @SuppressWarnings("unchecked")
+  public <T extends Source> T getSource(Class<T> sourceClass) throws SQLException {
+    checkFreed();
+    if (this.data == null) return null;
+
+    try {
+      if (sourceClass == null || DOMSource.class.equals(sourceClass)) {
+        DocumentBuilder builder = DocumentBuilderFactory.newInstance().newDocumentBuilder();
+        return (T) new DOMSource(builder.parse(new InputSource(new StringReader(data))));
+      }
+
+      if (SAXSource.class.equals(sourceClass)) {
+        XMLReader reader = XMLReaderFactory.createXMLReader();
+        return sourceClass.cast(new SAXSource(reader, new InputSource(new StringReader(data))));
+      }
+
+      if (StreamSource.class.equals(sourceClass)) {
+        return sourceClass.cast(new StreamSource(new StringReader(data)));
+      }
+
+      if (StAXSource.class.equals(sourceClass)) {
+        XMLStreamReader xsr = XMLInputFactory.newFactory().createXMLStreamReader(new StringReader(data));
+        return sourceClass.cast(new StAXSource(xsr));
+      }
+    } catch (Exception e) {
+      throw new SQLException("Unable to decode XML data.", e);
+    }
+    // Reached only when sourceClass is none of the supported Source types; thrown outside the
+    // try block so it is not re-wrapped as a decoding failure.
+    throw new SQLException("Unsupported source class for XML data: " + sourceClass.getName());
+  }
+
+  @Override
+  public <T extends Result> T setResult(Class<T> resultClass) throws SQLException {
+    throw new UnsupportedOperationException();
+  }
+}
diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSupplier.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSupplier.java
new file mode 100644
index 000000000..ac3d505e1
--- /dev/null
+++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSupplier.java
@@ -0,0 +1,76 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package software.amazon.jdbc.plugin.cache;
+
+import java.util.Objects;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.ReentrantLock;
+import java.util.function.Supplier;
+
+public final class CachedSupplier {
+
+  private CachedSupplier() {
+    throw new UnsupportedOperationException("Utility class should not be instantiated");
+  }
+
+  /**
+   * Wrap the given supplier so its result is recomputed at most once per expiry window,
+   * e.g. {@code CachedSupplier.memoizeWithExpiration(() -> loadTopology(), 30, TimeUnit.SECONDS)}.
+   */
+  public static <T> Supplier<T> memoizeWithExpiration(
+      Supplier<T> delegate, long duration, TimeUnit unit) {
+
+    Objects.requireNonNull(delegate, "delegate Supplier must not be null");
+    Objects.requireNonNull(unit, "TimeUnit must not be null");
+    if (duration <= 0) {
+      throw new IllegalArgumentException("duration must be > 0");
+    }
+
+    return new ExpiringMemoizingSupplier<>(delegate, duration, unit);
+  }
+
+  private static final class ExpiringMemoizingSupplier<T> implements Supplier<T> {
+
+    private final Supplier<T> delegate;
+    private final long durationNanos;
+    private final ReentrantLock lock = new ReentrantLock();
+
+    private volatile T value;
+    private volatile long expirationNanos; // 0 means not yet initialized
+
+    ExpiringMemoizingSupplier(Supplier<T> delegate, long duration, TimeUnit unit) {
+      this.delegate = delegate;
+      this.durationNanos = unit.toNanos(duration);
+    }
+
+    @Override
+    public T get() {
+      long now = System.nanoTime();
+
+      // Double-checked: only the lock holder recomputes an expired or uninitialized value;
+      // the volatile fields publish the refreshed value to other threads.
+      if (expirationNanos == 0 || now - expirationNanos >= 0) {
+        lock.lock();
+        try {
+          if (expirationNanos == 0 || now - expirationNanos >= 0) {
+            value = delegate.get();
+            long next = now + durationNanos;
+            expirationNanos = (next == 0) ? 1 : next; // avoid 0 sentinel
+          }
+        } finally {
+          lock.unlock();
+        }
+      }
+      return value;
+    }
+  }
+}
diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPlugin.java
new file mode 100644
index 000000000..3d4743fea
--- /dev/null
+++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPlugin.java
@@ -0,0 +1,167 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPlugin.java
new file mode 100644
index 000000000..3d4743fea
--- /dev/null
+++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPlugin.java
@@ -0,0 +1,167 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package software.amazon.jdbc.plugin.cache;
+
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Properties;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.logging.Logger;
+import software.amazon.jdbc.AwsWrapperProperty;
+import software.amazon.jdbc.JdbcCallable;
+import software.amazon.jdbc.JdbcMethod;
+import software.amazon.jdbc.PluginService;
+import software.amazon.jdbc.PropertyDefinition;
+import software.amazon.jdbc.plugin.AbstractConnectionPlugin;
+import software.amazon.jdbc.util.Messages;
+import software.amazon.jdbc.util.StringUtils;
+import software.amazon.jdbc.util.telemetry.TelemetryCounter;
+import software.amazon.jdbc.util.telemetry.TelemetryFactory;
+import software.amazon.jdbc.util.telemetry.TelemetryGauge;
+
+public class DataLocalCacheConnectionPlugin extends AbstractConnectionPlugin {
+
+  private static final Logger LOGGER = Logger.getLogger(DataLocalCacheConnectionPlugin.class.getName());
+
+  private static final Set<String> subscribedMethods = Collections.unmodifiableSet(new HashSet<>(
+      Arrays.asList(
+          JdbcMethod.STATEMENT_EXECUTEQUERY.methodName,
+          JdbcMethod.STATEMENT_EXECUTE.methodName,
+          JdbcMethod.PREPAREDSTATEMENT_EXECUTE.methodName,
+          JdbcMethod.PREPAREDSTATEMENT_EXECUTEQUERY.methodName,
+          JdbcMethod.CALLABLESTATEMENT_EXECUTE.methodName,
+          JdbcMethod.CALLABLESTATEMENT_EXECUTEQUERY.methodName
+      )));
+
+  public static final AwsWrapperProperty DATA_CACHE_TRIGGER_CONDITION = new AwsWrapperProperty(
+      "dataCacheTriggerCondition", "false",
+      "A regular expression that, if it's matched, allows the plugin to cache SQL results.");
+
+  protected static final Map<String, ResultSet> dataCache = new ConcurrentHashMap<>();
+
+  protected final String dataCacheTriggerCondition;
+
+  static {
+    PropertyDefinition.registerPluginProperties(DataLocalCacheConnectionPlugin.class);
+  }
+
+  private final TelemetryFactory telemetryFactory;
+  private final TelemetryCounter hitCounter;
+  private final TelemetryCounter missCounter;
+  private final TelemetryCounter totalCallsCounter;
+  private final TelemetryGauge cacheSizeGauge;
+
+  public DataLocalCacheConnectionPlugin(final PluginService pluginService, final Properties props) {
+    this.telemetryFactory = pluginService.getTelemetryFactory();
+    this.dataCacheTriggerCondition = DATA_CACHE_TRIGGER_CONDITION.getString(props);
+
+    this.hitCounter = telemetryFactory.createCounter("dataCache.cache.hit");
+    this.missCounter = telemetryFactory.createCounter("dataCache.cache.miss");
+    this.totalCallsCounter = telemetryFactory.createCounter("dataCache.cache.totalCalls");
+    this.cacheSizeGauge = telemetryFactory.createGauge("dataCache.cache.size", () -> (long) dataCache.size());
+  }
+
+  public static void clearCache() {
+    dataCache.clear();
+  }
+
+  @Override
+  public Set<String> getSubscribedMethods() {
+    return subscribedMethods;
+  }
+
+  @Override
+  public <T, E extends Exception> T execute(
+      final Class<T> resultClass,
+      final Class<E> exceptionClass,
+      final Object methodInvokeOn,
+      final String methodName,
+      final JdbcCallable<T, E> jdbcMethodFunc,
+      final Object[] jdbcMethodArgs)
+      throws E {
+
+    if (StringUtils.isNullOrEmpty(this.dataCacheTriggerCondition) || resultClass != ResultSet.class) {
+      return jdbcMethodFunc.call();
+    }
+
+    if (this.totalCallsCounter != null) {
+      this.totalCallsCounter.inc();
+    }
+
+    ResultSet result;
+    boolean needToCache = false;
+    final String sql = getQuery(jdbcMethodArgs);
+
+    if
(!StringUtils.isNullOrEmpty(sql) && sql.matches(this.dataCacheTriggerCondition)) { + result = dataCache.get(sql); + if (result == null) { + needToCache = true; + if (this.missCounter != null) { + this.missCounter.inc(); + } + LOGGER.finest( + () -> Messages.get( + "DataLocalCacheConnectionPlugin.queryResultsCached", + new Object[]{methodName, sql})); + } else { + if (this.hitCounter != null) { + this.hitCounter.inc(); + } + try { + result.beforeFirst(); + } catch (final SQLException ex) { + if (exceptionClass.isAssignableFrom(ex.getClass())) { + throw exceptionClass.cast(ex); + } + throw new RuntimeException(ex); + } + return resultClass.cast(result); + } + } + + result = (ResultSet) jdbcMethodFunc.call(); + + if (needToCache) { + final ResultSet cachedResultSet; + try { + cachedResultSet = new CachedResultSet(result); + dataCache.put(sql, cachedResultSet); + cachedResultSet.beforeFirst(); + return resultClass.cast(cachedResultSet); + } catch (final SQLException ex) { + // ignore exception + } + } + + return resultClass.cast(result); + } + + protected String getQuery(final Object[] jdbcMethodArgs) { + + // Get query from method argument + if (jdbcMethodArgs != null && jdbcMethodArgs.length > 0 && jdbcMethodArgs[0] != null) { + return jdbcMethodArgs[0].toString(); + } + return null; + } + +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPluginFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPluginFactory.java new file mode 100644 index 000000000..c28e03e89 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPluginFactory.java @@ -0,0 +1,30 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.cache; + +import java.util.Properties; +import software.amazon.jdbc.ConnectionPlugin; +import software.amazon.jdbc.ConnectionPluginFactory; +import software.amazon.jdbc.PluginService; + +public class DataLocalCacheConnectionPluginFactory implements ConnectionPluginFactory { + + @Override + public ConnectionPlugin getInstance(final PluginService pluginService, final Properties props) { + return new DataLocalCacheConnectionPlugin(pluginService, props); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java new file mode 100644 index 000000000..e79d8de84 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java @@ -0,0 +1,375 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package software.amazon.jdbc.plugin.cache;
+
+import java.sql.Connection;
+import java.sql.DatabaseMetaData;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Properties;
+import java.util.Set;
+import java.util.logging.Logger;
+import software.amazon.jdbc.AwsWrapperProperty;
+import software.amazon.jdbc.JdbcCallable;
+import software.amazon.jdbc.JdbcMethod;
+import software.amazon.jdbc.PluginService;
+import software.amazon.jdbc.PropertyDefinition;
+import software.amazon.jdbc.plugin.AbstractConnectionPlugin;
+import software.amazon.jdbc.states.SessionStateService;
+import software.amazon.jdbc.util.Messages;
+import software.amazon.jdbc.util.StringUtils;
+import software.amazon.jdbc.util.WrapperUtils;
+import software.amazon.jdbc.util.telemetry.TelemetryCounter;
+import software.amazon.jdbc.util.telemetry.TelemetryFactory;
+import software.amazon.jdbc.util.telemetry.TelemetryContext;
+import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel;
+
+public class DataRemoteCachePlugin extends AbstractConnectionPlugin {
+  private static final Logger LOGGER = Logger.getLogger(DataRemoteCachePlugin.class.getName());
+  private static final String QUERY_HINT_START_PATTERN = "/*";
+  private static final String QUERY_HINT_END_PATTERN = "*/";
+  private static final String CACHE_PARAM_PATTERN = "CACHE_PARAM(";
+  private static final String TELEMETRY_CACHE_LOOKUP = "jdbc-cache-lookup";
+  private static final String TELEMETRY_DATABASE_QUERY = "jdbc-database-query";
+  private static final Set<String> subscribedMethods = Collections.unmodifiableSet(new HashSet<>(
+      Arrays.asList(JdbcMethod.STATEMENT_EXECUTEQUERY.methodName,
+          JdbcMethod.STATEMENT_EXECUTE.methodName,
+          JdbcMethod.PREPAREDSTATEMENT_EXECUTE.methodName,
+          JdbcMethod.PREPAREDSTATEMENT_EXECUTEQUERY.methodName,
+          JdbcMethod.CALLABLESTATEMENT_EXECUTE.methodName,
+          JdbcMethod.CALLABLESTATEMENT_EXECUTEQUERY.methodName)));
+
+  private int maxCacheableQuerySize;
+  private PluginService pluginService;
+  private TelemetryFactory telemetryFactory;
+  private TelemetryCounter cacheHitCounter;
+  private TelemetryCounter cacheMissCounter;
+  private TelemetryCounter totalQueryCounter;
+  private TelemetryCounter malformedHintCounter;
+  private TelemetryCounter cacheBypassCounter;
+  private CacheConnection cacheConnection;
+  private String dbUserName;
+
+  private static final AwsWrapperProperty CACHE_MAX_QUERY_SIZE =
+      new AwsWrapperProperty(
+          "cacheMaxQuerySize",
+          "16384",
+          "The maximum query length, in characters, eligible for remote caching");
+
+  static {
+    PropertyDefinition.registerPluginProperties(DataRemoteCachePlugin.class);
+  }
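For context, a minimal sketch of how an application enables this plugin. The endpoint values are placeholders; the property names are the ones consumed by CacheConnection and exercised in the tests later in this change:

    import java.sql.Connection;
    import java.sql.DriverManager;
    import java.util.Properties;

    // Illustrative wiring only; endpoints are placeholders.
    Properties props = new Properties();
    props.setProperty("wrapperPlugins", "dataRemoteCache");
    props.setProperty("cacheEndpointAddrRw", "localhost:6379"); // required read-write endpoint
    props.setProperty("cacheEndpointAddrRo", "localhost:6380"); // optional read-only endpoint
    props.setProperty("cacheMaxQuerySize", "16384");            // default value, shown for clarity
    Connection conn = DriverManager.getConnection(
        "jdbc:aws-wrapper:postgresql://localhost:5432/postgres", props);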
+  public DataRemoteCachePlugin(final PluginService pluginService, final Properties properties) {
+    try {
+      Class.forName("io.lettuce.core.RedisClient"); // Lettuce dependency
+      Class.forName("org.apache.commons.pool2.impl.GenericObjectPool"); // Object pool dependency
+    } catch (final ClassNotFoundException e) {
+      throw new RuntimeException(Messages.get("DataRemoteCachePlugin.notInClassPath", new Object[] {e.getMessage()}));
+    }
+    this.pluginService = pluginService;
+    this.telemetryFactory = pluginService.getTelemetryFactory();
+    this.cacheHitCounter = telemetryFactory.createCounter("JdbcCachedQueryCount");
+    this.cacheMissCounter = telemetryFactory.createCounter("JdbcCacheMissCount");
+    this.totalQueryCounter = telemetryFactory.createCounter("JdbcCacheTotalQueryCount");
+    this.malformedHintCounter = telemetryFactory.createCounter("JdbcCacheMalformedQueryHint");
+    this.cacheBypassCounter = telemetryFactory.createCounter("JdbcCacheBypassCount");
+    this.maxCacheableQuerySize = CACHE_MAX_QUERY_SIZE.getInteger(properties);
+    this.cacheConnection = new CacheConnection(properties, this.telemetryFactory);
+    this.dbUserName = PropertyDefinition.USER.getString(properties);
+  }
+
+  // Used for unit testing purposes only
+  protected void setCacheConnection(CacheConnection conn) {
+    this.cacheConnection = conn;
+  }
+
+  @Override
+  public Set<String> getSubscribedMethods() {
+    return subscribedMethods;
+  }
+
+  private String getCacheQueryKey(String query) {
+    // Check some basic session state. The important pieces for caching include (but are not
+    // limited to) the schema name and the username, which can affect the query result returned
+    // by the database in addition to the query string itself.
+    try {
+      Connection currentConn = pluginService.getCurrentConnection();
+      DatabaseMetaData metadata = currentConn.getMetaData();
+      // Fetch and record the schema name if the session state doesn't currently have it
+      SessionStateService sessionStateService = pluginService.getSessionStateService();
+      String catalog = sessionStateService.getCatalog().orElse(null);
+      String schema = sessionStateService.getSchema().orElse(null);
+      if (catalog == null && schema == null) {
+        // Fetch the current schema name and store it in sessionStateService
+        catalog = currentConn.getCatalog();
+        schema = currentConn.getSchema();
+        if (catalog != null) sessionStateService.setCatalog(catalog);
+        if (schema != null) sessionStateService.setSchema(schema);
+      }
+
+      if (dbUserName == null) {
+        // For MySQL, the metadata username is in the form "user@host"; we only need the part before '@'.
+        dbUserName = metadata.getUserName();
+        int nameIndexEnd = dbUserName.indexOf('@');
+        if (nameIndexEnd > 0) {
+          dbUserName = dbUserName.substring(0, nameIndexEnd);
+        }
+      }
+      LOGGER.finest("DB driver protocol " + pluginService.getDriverProtocol()
+          + ", database product: " + metadata.getDatabaseProductName() + " " + metadata.getDatabaseProductVersion()
+          + ", catalog: " + catalog + ", schema: " + schema + ", user: " + dbUserName
+          + ", driver: " + metadata.getDriverName() + " " + metadata.getDriverVersion());
+      // The cache key contains the catalog, the schema name, the user name, and the query string
+      String[] words = {catalog, schema, dbUserName, query};
+      return String.join("_", words);
+    } catch (SQLException e) {
+      LOGGER.warning("Error getting session state: " + e.getMessage());
+      return null;
+    }
+  }
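To make the key layout concrete, a worked example with hypothetical session values:

    // Illustrative only: a session with catalog "postgres", schema "public", and
    // user "app" caching "SELECT * FROM test" produces this key:
    String[] words = {"postgres", "public", "app", "SELECT * FROM test"};
    String key = String.join("_", words);  // "postgres_public_app_SELECT * FROM test"
    // Note: String.join renders a null catalog or schema as the literal "null".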
+  private ResultSet fetchResultSetFromCache(String queryStr) throws SQLException {
+    if (cacheConnection == null) return null;
+
+    String cacheQueryKey = getCacheQueryKey(queryStr);
+    if (cacheQueryKey == null) return null; // Treat this as a cache miss
+    byte[] cachedResult = cacheConnection.readFromCache(cacheQueryKey);
+    if (cachedResult == null) return null;
+    // Convert the cached bytes back into a ResultSet
+    try {
+      return CachedResultSet.deserializeFromByteArray(cachedResult);
+    } catch (Exception e) {
+      LOGGER.warning("Error de-serializing cached result: " + e.getMessage());
+      return null; // Treat this as a cache miss
+    }
+  }
+
+  /**
+   * Cache the given ResultSet object.
+   * The ResultSet passed in is consumed to build a CachedResultSet, which is returned so the
+   * caller can iterate over it in place of the original.
+   */
+  private ResultSet cacheResultSet(String queryStr, ResultSet rs, int expiry) throws SQLException {
+    // Write the resultSet into the cache as a single key
+    String cacheQueryKey = getCacheQueryKey(queryStr);
+    if (cacheQueryKey == null) return rs; // Treat this condition as un-cacheable
+    CachedResultSet crs = new CachedResultSet(rs);
+    byte[] serializedResult = crs.serializeIntoByteArray();
+    cacheConnection.writeToCache(cacheQueryKey, serializedResult, expiry);
+    crs.beforeFirst();
+    return crs;
+  }
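To make the hint grammar concrete before the parser itself, a few illustrative inputs and what getTtlForQuery returns for them:

    // A cacheable statement carries the hint in a leading comment:
    //   /* CACHE_PARAM(ttl=300s) */ SELECT id, int_col FROM test WHERE id = 1
    // execute() passes the text between "/*" and "*/" to getTtlForQuery:
    getTtlForQuery("CACHE_PARAM(ttl=300s)");            // returns 300
    getTtlForQuery("CACHE_PARAM(ttl=60s, key=custom)"); // returns 60; non-TTL keys are ignored
    getTtlForQuery("CACHE_PARAM(ttl=300)");             // returns null; missing trailing 's'
    getTtlForQuery("CACHE_PARAM(ttl=0s)");              // returns null; zero/negative TTLs bypass the cache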
+  /**
+   * Determine the TTL for a query based on its query hint.
+   *
+   * @param queryHint the hint string, e.g. "CACHE_PARAM(ttl=100s, key=custom)"
+   * @return the TTL in seconds for which to cache the query result, or
+   *         null if the query is not cacheable.
+   */
+  protected Integer getTtlForQuery(String queryHint) {
+    // An empty query hint is not cacheable
+    if (StringUtils.isNullOrEmpty(queryHint)) return null;
+    // Find CACHE_PARAM anywhere in the hint string (case-insensitive)
+    String upperHint = queryHint.toUpperCase();
+    int cacheParamStart = upperHint.indexOf(CACHE_PARAM_PATTERN);
+    if (cacheParamStart == -1) return null;
+
+    // Find the matching closing parenthesis
+    int paramsStart = cacheParamStart + CACHE_PARAM_PATTERN.length();
+    int paramsEnd = upperHint.indexOf(")", paramsStart);
+    if (paramsEnd == -1) return null;
+
+    // Extract the parameters between the parentheses
+    String cacheParams = upperHint.substring(paramsStart, paramsEnd).trim();
+    // Empty parameter list
+    if (StringUtils.isNullOrEmpty(cacheParams)) {
+      LOGGER.warning("Empty CACHE_PARAM parameters");
+      incrCounter(malformedHintCounter);
+      return null;
+    }
+
+    // Parse the comma-separated parameters
+    String[] params = cacheParams.split(",");
+    Integer ttlValue = null;
+
+    for (String param : params) {
+      String[] keyValue = param.trim().split("=");
+      if (keyValue.length != 2) {
+        LOGGER.warning("Invalid caching parameter format: " + param);
+        incrCounter(malformedHintCounter);
+        return null;
+      }
+      String key = keyValue[0].trim();
+      String value = keyValue[1].trim();
+
+      if ("TTL".equals(key)) {
+        if (!value.endsWith("S")) {
+          LOGGER.warning("TTL must end with 's': " + value);
+          incrCounter(malformedHintCounter);
+          return null;
+        } else {
+          // Parse the TTL value (e.g., "300s")
+          try {
+            ttlValue = Integer.parseInt(value.substring(0, value.length() - 1));
+            // Treat zero and negative TTLs as not cacheable
+            if (ttlValue <= 0) {
+              return null;
+            }
+          } catch (NumberFormatException e) {
+            LOGGER.warning(String.format("Invalid TTL format of %s for query %s", value, queryHint));
+            incrCounter(malformedHintCounter);
+            return null;
+          }
+        }
+      }
+    }
+    return ttlValue;
+  }
+
+  @Override
+  public <T, E extends Exception> T execute(
+      final Class<T> resultClass,
+      final Class<E> exceptionClass,
+      final Object methodInvokeOn,
+      final String methodName,
+      final JdbcCallable<T, E> jdbcMethodFunc,
+      final Object[] jdbcMethodArgs)
+      throws E {
+
+    if (resultClass != ResultSet.class) {
+      return jdbcMethodFunc.call();
+    }
+
+    incrCounter(totalQueryCounter);
+
+    ResultSet result;
+    boolean needToCache = false;
+    final String sql = getQuery(methodInvokeOn, jdbcMethodArgs);
+
+    TelemetryContext cacheContext = null;
+    TelemetryContext dbContext = null;
+    // If the query is cacheable, we try to fetch the query result from the cache.
+    boolean isInTransaction = pluginService.isInTransaction();
+    // Get the query hint in front of the query itself
+    String mainQuery = sql; // The main part of the query, with the query hint prefix trimmed
+    int endOfQueryHint = 0;
+    Integer configuredQueryTtl = null;
+    // Queries longer than cacheMaxQuerySize (16 KB by default) are not cacheable
+    if (!StringUtils.isNullOrEmpty(sql) && (sql.length() < maxCacheableQuerySize) && sql.contains(QUERY_HINT_START_PATTERN)) {
+      endOfQueryHint = sql.indexOf(QUERY_HINT_END_PATTERN);
+      if (endOfQueryHint > 0) {
+        configuredQueryTtl = getTtlForQuery(sql.substring(QUERY_HINT_START_PATTERN.length(), endOfQueryHint).trim());
+        mainQuery = sql.substring(endOfQueryHint + QUERY_HINT_END_PATTERN.length()).trim();
+      }
+    }
+
+    // The query result can be served from the cache if the query has a configured TTL value and it
+    // is not executed in a transaction, as a transaction typically needs to return consistent results.
+ if (!isInTransaction && (configuredQueryTtl != null)) { + cacheContext = telemetryFactory.openTelemetryContext( + TELEMETRY_CACHE_LOOKUP, TelemetryTraceLevel.TOP_LEVEL); + Exception cacheException = null; + try { + result = fetchResultSetFromCache(mainQuery); + if (result == null) { + // Cache miss. Need to fetch result from the database + needToCache = true; + incrCounter(cacheMissCounter); + LOGGER.finest("Got a cache miss for SQL: " + sql); + } else { + LOGGER.finest("Got a cache hit for SQL: " + sql); + // Cache hit. Return the cached result + incrCounter(cacheHitCounter); + result.beforeFirst(); + return resultClass.cast(result); + } + } catch (final SQLException ex) { + // SQLException from readFromCache (failWhenCacheDown=true) or result.beforeFirst() + cacheException = ex; + throw WrapperUtils.wrapExceptionIfNeeded(exceptionClass, ex); + } finally { + if (cacheContext != null) { + if (cacheException != null) { + cacheContext.setSuccess(false); + cacheContext.setException(cacheException); + cacheContext.closeContext(); + } else if (!needToCache) { // Cache hit + cacheContext.setSuccess(true); + cacheContext.closeContext(); + } else { // Cache miss - leave context open + cacheContext.setSuccess(false); + } + } + } + } else { + incrCounter(cacheBypassCounter); + } + + dbContext = telemetryFactory.openTelemetryContext( + TELEMETRY_DATABASE_QUERY, TelemetryTraceLevel.TOP_LEVEL); + + try { + result = (ResultSet) jdbcMethodFunc.call(); + } finally { + if (dbContext != null) dbContext.closeContext(); + if (cacheContext != null) cacheContext.closeContext(); + } + + // We need to cache the query result if we got a cache miss for the query result, + // or the query is cacheable and executed inside a transaction. + if (isInTransaction && (configuredQueryTtl != null)) { + needToCache = true; + } + if (needToCache) { + try { + result = cacheResultSet(mainQuery, result, configuredQueryTtl); + } catch (final SQLException ex) { + // Log and re-throw exception + LOGGER.warning("Encountered SQLException when caching query results: " + ex.getMessage()); + throw WrapperUtils.wrapExceptionIfNeeded(exceptionClass, ex); + } + } + + return resultClass.cast(result); + } + + private void incrCounter(TelemetryCounter counter) { + if (counter == null) return; + counter.inc(); + } + + protected String getQuery(final Object methodInvokeOn, final Object[] jdbcMethodArgs) { + // Get query from method argument + if (jdbcMethodArgs != null && jdbcMethodArgs.length > 0 && jdbcMethodArgs[0] != null) { + return jdbcMethodArgs[0].toString().trim(); + } + + // If the query is not in the method arguments, check for prepared statement query. Get the query + // string from the prepared statement. The exact query string is dependent on the underlying driver. + if (methodInvokeOn instanceof PreparedStatement) { + // For postgres, this gives the raw query itself. i.e. "select * from T where A = 1". 
+ // For MySQL, this gives "com.mysql.cj.jdbc.ClientPreparedStatement: select * from T where A = 1" + // For MariaDB, this gives "ClientPreparedStatement{sql:'select * from T where A=1', parameters:[]}" + return methodInvokeOn.toString(); + } + return null; + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPluginFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginFactory.java similarity index 83% rename from wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPluginFactory.java rename to wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginFactory.java index 555ed55cf..fb15d69c5 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPluginFactory.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginFactory.java @@ -14,17 +14,17 @@ * limitations under the License. */ -package software.amazon.jdbc.plugin; +package software.amazon.jdbc.plugin.cache; import java.util.Properties; import software.amazon.jdbc.ConnectionPlugin; import software.amazon.jdbc.ConnectionPluginFactory; import software.amazon.jdbc.PluginService; -public class DataCacheConnectionPluginFactory implements ConnectionPluginFactory { +public class DataRemoteCachePluginFactory implements ConnectionPluginFactory { @Override public ConnectionPlugin getInstance(final PluginService pluginService, final Properties props) { - return new DataCacheConnectionPlugin(pluginService, props); + return new DataRemoteCachePlugin(pluginService, props); } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/ElastiCacheIamTokenUtility.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/ElastiCacheIamTokenUtility.java new file mode 100644 index 000000000..12d96825a --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/ElastiCacheIamTokenUtility.java @@ -0,0 +1,125 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package software.amazon.jdbc.plugin.iam; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneId; +import java.util.Objects; +import java.util.logging.Logger; +import org.checkerframework.checker.nullness.qual.NonNull; +import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.CredentialUtils; +import software.amazon.awssdk.auth.signer.Aws4Signer; +import software.amazon.awssdk.auth.signer.params.Aws4PresignerParams; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.StringUtils; + +public class ElastiCacheIamTokenUtility implements IamTokenUtility { + + private static final Logger LOGGER = Logger.getLogger(ElastiCacheIamTokenUtility.class.getName()); + private static final String PARAM_ACTION = "Action"; + private static final String PARAM_USER = "User"; + private static final String ACTION_NAME = "connect"; + private static final String PARAM_RESOURCE_TYPE = "ResourceType"; + private static final String RESOURCE_TYPE_SERVERLESS_CACHE = "ServerlessCache"; + private static final String SERVICE_NAME = "elasticache"; + private static final String PROTOCOL = "http"; + private static final Duration EXPIRATION_DURATION = Duration.ofSeconds(15 * 60 - 30); + public static final String SERVERLESS_CACHE_IDENTIFIER = ".serverless."; + + private final Clock clock; + private String cacheName = null; + private final Aws4Signer signer; + + public ElastiCacheIamTokenUtility(String cacheName) { + this.cacheName = Objects.requireNonNull(cacheName, "cacheName cannot be null"); + this.clock = Clock.systemUTC(); + this.signer = Aws4Signer.create(); + } + + // For testing only + public ElastiCacheIamTokenUtility(String cacheName, Instant fixedInstant, Aws4Signer signer) { + this.cacheName = Objects.requireNonNull(cacheName, "cacheName cannot be null"); + this.clock = Clock.fixed(fixedInstant, ZoneId.of("UTC")); + this.signer = signer; + } + + @Override + public String generateAuthenticationToken( + final @NonNull AwsCredentialsProvider credentialsProvider, + final @NonNull Region region, + final @NonNull String hostname, + final int port, + final @NonNull String username) { + + boolean isServerless = isServerlessCache(hostname); + if (this.cacheName == null) { + throw new IllegalArgumentException("Cache name cannot be null for cache with IAM authentication"); + } + + SdkHttpFullRequest.Builder requestBuilder = SdkHttpFullRequest.builder() + .method(SdkHttpMethod.GET) + .protocol(PROTOCOL) // ElastiCache uses http, not https + .host(this.cacheName) + .encodedPath("/") + .putRawQueryParameter(PARAM_ACTION, ACTION_NAME) + .putRawQueryParameter(PARAM_USER, username); + + if (isServerless) { + requestBuilder.putRawQueryParameter(PARAM_RESOURCE_TYPE, RESOURCE_TYPE_SERVERLESS_CACHE); + } + + final SdkHttpFullRequest httpRequest = requestBuilder.build(); + + final Instant expirationTime = Instant.now(this.clock).plus(EXPIRATION_DURATION); + + final AwsCredentials credentials = CredentialUtils.toCredentials( + CompletableFutureUtils.joinLikeSync(credentialsProvider.resolveIdentity())); + + final Aws4PresignerParams presignRequest = Aws4PresignerParams.builder() + .signingClockOverride(this.clock) + .expirationTime(expirationTime) + 
.awsCredentials(credentials)
+        .signingName(SERVICE_NAME)
+        .signingRegion(region)
+        .build();
+
+    final SdkHttpFullRequest fullRequest = this.signer.presign(httpRequest, presignRequest);
+    final String signedUrl = fullRequest.getUri().toString();
+
+    // The format should be:
+    // Regular: <cache-name>/?Action=connect&User=<user>&X-Amz-Security-Token=...&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=...&X-Amz-SignedHeaders=host&X-Amz-Expires=870&X-Amz-Credential=...&X-Amz-Signature=...
+    // Serverless: <cache-name>/?Action=connect&User=<user>&ResourceType=ServerlessCache&X-Amz-Security-Token=...&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=...&X-Amz-SignedHeaders=host&X-Amz-Expires=870&X-Amz-Credential=...&X-Amz-Signature=...
+    // Note: the signed host must be the real ElastiCache cache name, not a proxy or a tunnel.
+    final String result = StringUtils.replacePrefixIgnoreCase(signedUrl, "http://", "");
+    LOGGER.finest(() -> "Generated ElastiCache authentication token with expiration of " + expirationTime);
+    return result;
+  }
+
+  private boolean isServerlessCache(String hostname) {
+    if (hostname == null) {
+      throw new IllegalArgumentException("Hostname cannot be null");
+    }
+    return hostname.contains(SERVERLESS_CACHE_IDENTIFIER);
+  }
+}
diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java b/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java
index 0037bdbd9..837a90d5c 100644
--- a/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java
+++ b/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java
@@ -573,7 +573,7 @@ public static Connection getConnectionFromSqlObject(final Object obj) {
       }
     } catch (final SQLException | UnsupportedOperationException e) {
       // Do nothing. The UnsupportedOperationException comes from ResultSets returned by
-      // DataCacheConnectionPlugin and will be triggered when getStatement is called.
+      // DataLocalCacheConnectionPlugin and will be triggered when getStatement is called.
     }
 
     return null;
diff --git a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties
index c5ea7c275..94afe5a1f 100644
--- a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties
+++ b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties
@@ -126,8 +126,12 @@ CustomEndpointPlugin.waitingForCustomEndpointInfo=Custom endpoint info for ''{0}
 
 CustomEndpointPluginFactory.awsSdkNotInClasspath=Required dependency 'AWS Java SDK RDS v2.x' is not on the classpath.
 
-DataCacheConnectionPlugin.queryResultsCached=[{0}] Query results will be cached: {1}
+DataLocalCacheConnectionPlugin.queryResultsCached=[{0}] Query results will be cached: {1}
 
+# Data Remote Cache Plugin
+DataRemoteCachePlugin.notInClassPath=Required dependency for DataRemoteCachePlugin is not on the classpath: ''{0}''
+
+# Default Connection Plugin
 DefaultConnectionPlugin.executingMethod=Executing method: ''{0}''
 DefaultConnectionPlugin.noHostsAvailable=The default connection plugin received an empty host list from the plugin service.
 DefaultConnectionPlugin.unknownRoleRequested=A HostSpec with a role of HostRole.UNKNOWN was requested via getHostSpecByStrategy.
The requested role must be either HostRole.WRITER or HostRole.READER diff --git a/wrapper/src/test/build.gradle.kts b/wrapper/src/test/build.gradle.kts index 3d26591e6..3e2f17251 100644 --- a/wrapper/src/test/build.gradle.kts +++ b/wrapper/src/test/build.gradle.kts @@ -51,6 +51,7 @@ dependencies { testImplementation("org.testcontainers:mariadb:1.20.4") testImplementation("org.testcontainers:junit-jupiter:1.20.4") testImplementation("org.testcontainers:toxiproxy:1.20.4") + testImplementation("org.apache.commons:commons-pool2:2.11.1") testImplementation("org.apache.poi:poi-ooxml:5.3.0") testImplementation("org.slf4j:slf4j-simple:2.0.13") testImplementation("com.fasterxml.jackson.core:jackson-databind:2.17.1") @@ -58,6 +59,7 @@ dependencies { testImplementation("io.opentelemetry:opentelemetry-sdk:1.42.1") testImplementation("io.opentelemetry:opentelemetry-sdk-metrics:1.43.0") testImplementation("io.opentelemetry:opentelemetry-exporter-otlp:1.44.1") + testImplementation("io.lettuce:lettuce-core:6.6.0.RELEASE") testImplementation("de.vandermeer:asciitable:0.3.2") testImplementation("org.hibernate:hibernate-core:5.6.15.Final") // the latest version compatible with Java 8 testImplementation("jakarta.persistence:jakarta.persistence-api:2.2.3") diff --git a/wrapper/src/test/java/integration/container/tests/DataCachePluginTests.java b/wrapper/src/test/java/integration/container/tests/DataCachePluginTests.java index dc8f6afa3..d6b77fee5 100644 --- a/wrapper/src/test/java/integration/container/tests/DataCachePluginTests.java +++ b/wrapper/src/test/java/integration/container/tests/DataCachePluginTests.java @@ -39,8 +39,8 @@ import org.junit.jupiter.api.TestTemplate; import org.junit.jupiter.api.extension.ExtendWith; import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.plugin.DataCacheConnectionPlugin; -import software.amazon.jdbc.plugin.DataCacheConnectionPlugin.CachedResultSet; +import software.amazon.jdbc.plugin.cache.CachedResultSet; +import software.amazon.jdbc.plugin.cache.DataLocalCacheConnectionPlugin; @TestMethodOrder(MethodOrderer.MethodName.class) @ExtendWith(TestDriverProvider.class) @@ -58,20 +58,20 @@ public class DataCachePluginTests { @BeforeEach public void beforeEach() { - DataCacheConnectionPlugin.clearCache(); + DataLocalCacheConnectionPlugin.clearCache(); } @TestTemplate public void testQueryCacheable() throws SQLException { - DataCacheConnectionPlugin.clearCache(); + DataLocalCacheConnectionPlugin.clearCache(); final Properties props = ConnectionStringHelper.getDefaultProperties(); PropertyDefinition.CONNECT_TIMEOUT.set(props, "30000"); PropertyDefinition.SOCKET_TIMEOUT.set(props, "30000"); props.setProperty(PropertyDefinition.PLUGINS.name, "dataCache"); - props.setProperty(DataCacheConnectionPlugin.DATA_CACHE_TRIGGER_CONDITION.name, ".*testTable.*"); + props.setProperty(DataLocalCacheConnectionPlugin.DATA_CACHE_TRIGGER_CONDITION.name, ".*testTable.*"); Connection conn = DriverManager.getConnection(ConnectionStringHelper.getWrapperUrl(), props); @@ -174,14 +174,14 @@ private void printTable() { @TestTemplate public void testQueryNotCacheable() throws SQLException { - DataCacheConnectionPlugin.clearCache(); + DataLocalCacheConnectionPlugin.clearCache(); final Properties props = ConnectionStringHelper.getDefaultProperties(); PropertyDefinition.CONNECT_TIMEOUT.set(props, "30000"); PropertyDefinition.SOCKET_TIMEOUT.set(props, "30000"); props.setProperty(PropertyDefinition.PLUGINS.name, "dataCache"); props.setProperty( - 
DataCacheConnectionPlugin.DATA_CACHE_TRIGGER_CONDITION.name, ".*WRONG_EXPRESSION.*");
+        DataLocalCacheConnectionPlugin.DATA_CACHE_TRIGGER_CONDITION.name, ".*WRONG_EXPRESSION.*");
 
     Connection conn = DriverManager.getConnection(ConnectionStringHelper.getWrapperUrl(), props);
diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheConnectionTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheConnectionTest.java
new file mode 100644
index 000000000..5423ddad2
--- /dev/null
+++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheConnectionTest.java
@@ -0,0 +1,785 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package software.amazon.jdbc.plugin.cache;
+
+import io.lettuce.core.RedisFuture;
+import io.lettuce.core.RedisURI;
+import io.lettuce.core.api.StatefulConnection;
+import io.lettuce.core.api.StatefulRedisConnection;
+import io.lettuce.core.api.async.RedisAsyncCommands;
+import io.lettuce.core.api.sync.RedisCommands;
+import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
+import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
+import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
+import org.apache.commons.pool2.impl.GenericObjectPool;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import org.mockito.Mock;
+import org.mockito.MockedConstruction;
+import org.mockito.MockitoAnnotations;
+import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.jdbc.plugin.iam.ElastiCacheIamTokenUtility;
+import java.lang.reflect.Field;
+import java.nio.charset.StandardCharsets;
+import java.sql.SQLException;
+import java.time.Duration;
+import java.util.function.BiConsumer;
+import java.util.Properties;
+
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.Mockito.*;
+
+public class CacheConnectionTest {
+  @Mock GenericObjectPool<StatefulConnection<String, byte[]>> mockReadConnPool;
+  @Mock GenericObjectPool<StatefulConnection<String, byte[]>> mockWriteConnPool;
+  @Mock StatefulRedisConnection<String, byte[]> mockConnection;
+  @Mock RedisCommands<String, byte[]> mockSyncCommands;
+  @Mock RedisAsyncCommands<String, byte[]> mockAsyncCommands;
+  @Mock StatefulRedisClusterConnection<String, byte[]> mockClusterConnection;
+  @Mock RedisAdvancedClusterCommands<String, byte[]> mockClusterSyncCommands;
+  @Mock RedisAdvancedClusterAsyncCommands<String, byte[]> mockClusterAsyncCommands;
+  @Mock RedisFuture<String> mockCacheResult;
+  private AutoCloseable closeable;
+  private CacheConnection cacheConnection;
+
+  @BeforeEach
+  void setUp() {
+    closeable = MockitoAnnotations.openMocks(this);
+    Properties props = new Properties();
+    props.setProperty("wrapperPlugins", "dataRemoteCache");
+    props.setProperty("cacheEndpointAddrRw", "localhost:6379");
+    props.setProperty("cacheEndpointAddrRo", "localhost:6380");
cacheConnection = new CacheConnection(props); + cacheConnection.setConnectionPools(mockReadConnPool, mockWriteConnPool); + // Bypass cluster detection for tests to avoid real Redis connections + cacheConnection.setClusterMode(false); + } + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + } + + @Test + void testIamAuth_PropertyExtraction() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheIamRegion", "us-west-2"); + props.setProperty("cacheUsername", "myuser"); + props.setProperty("cacheName", "my-cache"); + + CacheConnection connection = new CacheConnection(props); + + // Verify all IAM fields are set correctly + assertEquals("us-west-2", getField(connection, "cacheIamRegion")); + assertEquals("myuser", getField(connection, "cacheUsername")); + assertEquals("my-cache", getField(connection, "cacheName")); + } + + @Test + void testIamAuth_PropertyExtractionTraditional() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheUsername", "myuser"); + props.setProperty("cachePassword", "password"); + props.setProperty("cacheName", "my-cache"); + + CacheConnection connection = new CacheConnection(props); + + // Verify all IAM fields are set correctly + assertEquals("myuser", getField(connection, "cacheUsername")); + assertEquals("my-cache", getField(connection, "cacheName")); + assertEquals("password", getField(connection, "cachePassword")); + } + + @Test + void testIamAuthEnabled_WhenRegionProvided() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheIamRegion", "us-east-1"); + props.setProperty("cacheUsername", "testuser"); + props.setProperty("cacheName", "my-cache"); + + CacheConnection connection = new CacheConnection(props); + + // Use reflection to verify iamAuthEnabled is true + Field field = CacheConnection.class.getDeclaredField("iamAuthEnabled"); + field.setAccessible(true); + assertTrue((boolean) field.get(connection)); + // Verify all IAM fields are set correctly + assertEquals("us-east-1", getField(connection, "cacheIamRegion")); + assertEquals("testuser", getField(connection, "cacheUsername")); + assertEquals("my-cache", getField(connection, "cacheName")); + } + + @Test + void testConstructor_IamAuthEnabled_MissingCacheName() { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheIamRegion", "us-west-2"); + props.setProperty("cacheUsername", "myuser"); + // Missing cacheName property + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> new CacheConnection(props) + ); + + assertTrue(exception.getMessage().contains("IAM authentication requires cache name, username, region, and hostname")); + } + + @Test + void testTraditionalAuth_WhenNoIamRegion() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheUsername", "user"); + props.setProperty("cachePassword", "pass"); + + CacheConnection connection = new CacheConnection(props); + + assertFalse((boolean) getField(connection, "iamAuthEnabled")); + assertNull(getField(connection, "credentialsProvider")); + assertEquals("user", getField(connection, "cacheUsername")); + assertEquals("pass", getField(connection, "cachePassword")); + } + 
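The field assertions in these tests go through a reflection helper, getField, in the same style as the getInstancePool and clearStaticRegistry helpers defined further down this class. A minimal sketch of such a helper, for reference:

    // Sketch only: reads a private field off a CacheConnection instance by name.
    private static Object getField(CacheConnection connection, String fieldName) throws Exception {
      Field f = CacheConnection.class.getDeclaredField(fieldName);
      f.setAccessible(true);
      return f.get(connection);
    }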
+ @Test + void testConstructor_NoRwAddress() { + Properties props = new Properties(); + props.setProperty("wrapperPlugins", "dataRemoteCache"); + props.setProperty("cacheEndpointAddrRo", "localhost:6379"); + + assertThrows(IllegalArgumentException.class, () -> new CacheConnection(props)); + } + + @Test + void testConstructor_IamAuthEnabled_MissingCacheUsername() { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheIamRegion", "us-east-1"); + + assertThrows(IllegalArgumentException.class, () -> new CacheConnection(props)); + } + + @Test + void testConstructor_ConflictingAuthenticationMethods() { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheIamRegion", "us-west-2"); // IAM auth + props.setProperty("cacheUsername", "myuser"); + props.setProperty("cachePassword", "mypassword"); // Traditional auth + props.setProperty("cacheName", "my-cache"); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> new CacheConnection(props) + ); + + assertTrue(exception.getMessage().contains("Cannot specify both IAM authentication")); + } + + @Test + void testAwsCredentialsProvider_WithProfile() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheIamRegion", "us-east-1"); + props.setProperty("cacheUsername", "testuser"); + props.setProperty("cacheName", "my-cache"); + props.setProperty("awsProfile", "test-profile"); + + CacheConnection connection = new CacheConnection(props); + + // Verify the awsProfileProperties field contains the correct profile + Properties awsProfileProps = (Properties) getField(connection, "awsProfileProperties"); + assertEquals("test-profile", awsProfileProps.getProperty("awsProfile")); + + assertEquals("my-cache", getField(connection, "cacheName")); + assertEquals("testuser", getField(connection, "cacheUsername")); + assertEquals("us-east-1", getField(connection, "cacheIamRegion")); + assertEquals("localhost:6379", getField(connection, "cacheRwServerAddr")); + } + + @Test + void testAwsCredentialsProvider_WithoutProfile() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheIamRegion", "us-east-1"); + props.setProperty("cacheUsername", "testuser"); + props.setProperty("cacheName", "my-cache"); + // No awsProfile property + + CacheConnection connection = new CacheConnection(props); + + // Verify awsProfileProperties is not empty when no profile specified + Properties awsProfileProps = (Properties) getField(connection, "awsProfileProperties"); + assertNull(awsProfileProps); + + assertEquals("my-cache", getField(connection, "cacheName")); + assertEquals("testuser", getField(connection, "cacheUsername")); + assertEquals("us-east-1", getField(connection, "cacheIamRegion")); + assertEquals("localhost:6379", getField(connection, "cacheRwServerAddr")); + } + + @Test + void testBuildRedisURI_IamAuth() { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheIamRegion", "us-east-1"); + props.setProperty("cacheUsername", "testuser"); + props.setProperty("cacheName", "test-cache"); + + try (MockedConstruction mockedTokenUtility = mockConstruction(ElastiCacheIamTokenUtility.class)) { + + CacheConnection connection = new CacheConnection(props); + 
RedisURI uri = connection.buildRedisURI("localhost", 6379); + + // Verify URI properties + assertNotNull(uri); + assertEquals("localhost", uri.getHost()); + assertEquals(6379, uri.getPort()); + assertTrue(uri.isSsl()); + assertNotNull(uri.getCredentialsProvider()); + + // Trigger the credentials provider to create the token utility + uri.getCredentialsProvider().resolveCredentials().block(); + + // Verify URI properties + assertNotNull(uri); + assertEquals("localhost", uri.getHost()); + assertEquals(6379, uri.getPort()); + assertTrue(uri.isSsl()); + assertNotNull(uri.getCredentialsProvider()); // IAM credentials provider set + + // Verify ElastiCacheIamTokenUtility was constructed with correct parameters + // Verify token utility construction + assertEquals(1, mockedTokenUtility.constructed().size()); + ElastiCacheIamTokenUtility tokenUtility = mockedTokenUtility.constructed().get(0); + verify(tokenUtility).generateAuthenticationToken( + any(AwsCredentialsProvider.class), + eq(Region.US_EAST_1), + eq("localhost"), + eq(6379), + eq("testuser") + ); + } + } + + @Test + void testBuildRedisURI_TraditionalAuth() { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheUsername", "user"); + props.setProperty("cachePassword", "pass"); + + CacheConnection connection = new CacheConnection(props); + RedisURI uri = connection.buildRedisURI("localhost", 6379); + + assertNotNull(uri); + assertEquals("localhost", uri.getHost()); + assertEquals(6379, uri.getPort()); + assertEquals("user", uri.getUsername()); + assertEquals("pass", new String(uri.getPassword())); + } + + @Test + void testBuildRedisURI_NoAuth() { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + + CacheConnection connection = new CacheConnection(props); + RedisURI uri = connection.buildRedisURI("localhost", 6379); + + assertNotNull(uri); + assertEquals("localhost", uri.getHost()); + assertEquals(6379, uri.getPort()); + assertNull(uri.getUsername()); + assertNull(uri.getPassword()); + } + + @Test + void test_writeToCache() throws Exception { + CacheConnection spyConnection = spy(cacheConnection); + when(spyConnection.getClusterHealthStateFromCacheMonitor()).thenReturn(CacheMonitor.HealthState.HEALTHY); + doNothing().when(spyConnection).reportErrorToCacheMonitor(anyBoolean(), any(), any()); + doNothing().when(spyConnection).incrementInFlightSize(anyLong()); + doNothing().when(spyConnection).decrementInFlightSize(anyLong()); + + String key = "myQueryKey"; + byte[] value = "myValue".getBytes(StandardCharsets.UTF_8); + when(mockWriteConnPool.borrowObject()).thenReturn(mockConnection); + when(mockConnection.async()).thenReturn(mockAsyncCommands); + when(mockAsyncCommands.set(any(), any(), any())).thenReturn(mockCacheResult); + when(mockCacheResult.whenComplete(any(BiConsumer.class))).thenReturn(null); + spyConnection.writeToCache(key, value, 100); + verify(mockWriteConnPool).borrowObject(); + verify(mockConnection).async(); + verify(mockAsyncCommands).set(any(), any(), any()); + verify(mockCacheResult).whenComplete(any(BiConsumer.class)); + } + + @Test + void test_writeToCacheException() throws Exception { + CacheConnection spyConnection = spy(cacheConnection); + when(spyConnection.getClusterHealthStateFromCacheMonitor()).thenReturn(CacheMonitor.HealthState.HEALTHY); + doNothing().when(spyConnection).reportErrorToCacheMonitor(anyBoolean(), any(), any()); + doNothing().when(spyConnection).incrementInFlightSize(anyLong()); + 
doNothing().when(spyConnection).decrementInFlightSize(anyLong()); + + String key = "myQueryKey"; + byte[] value = "myValue".getBytes(StandardCharsets.UTF_8); + when(mockWriteConnPool.borrowObject()).thenReturn(mockConnection); + when(mockConnection.async()).thenReturn(mockAsyncCommands); + when(mockAsyncCommands.set(any(), any(), any())).thenThrow(new RuntimeException("test exception")); + spyConnection.writeToCache(key, value, 100); + verify(mockWriteConnPool).borrowObject(); + verify(mockConnection).async(); + verify(mockAsyncCommands).set(any(), any(), any()); + verify(mockWriteConnPool).invalidateObject(mockConnection); + } + + @Test + void testHandleCompletedCacheWrite() throws Exception { + CacheConnection spyConnection = spy(cacheConnection); + doNothing().when(spyConnection).decrementInFlightSize(anyLong()); + doNothing().when(spyConnection).reportErrorToCacheMonitor(anyBoolean(), any(), any()); + + // Success: decrement called, no error reported, connection returned + spyConnection.handleCompletedCacheWrite(mockConnection, 150L, null); + verify(spyConnection).decrementInFlightSize(150L); + verify(spyConnection, never()).reportErrorToCacheMonitor(anyBoolean(), any(), any()); + verify(mockWriteConnPool).returnObject(mockConnection); + + // Failure: decrement called, error reported, connection invalidated + RuntimeException writeError = new RuntimeException("Redis timeout"); + spyConnection.handleCompletedCacheWrite(mockConnection, 200L, writeError); + verify(spyConnection).decrementInFlightSize(200L); + verify(spyConnection).reportErrorToCacheMonitor(eq(true), eq(writeError), eq("WRITE")); + verify(mockWriteConnPool).invalidateObject(mockConnection); + + // Multiple operations: mixed success/failure + spyConnection.handleCompletedCacheWrite(mockConnection, 100L, null); + spyConnection.handleCompletedCacheWrite(mockConnection, 250L, new RuntimeException("lost")); + verify(spyConnection).decrementInFlightSize(100L); + verify(spyConnection).decrementInFlightSize(250L); + verify(mockWriteConnPool, times(2)).returnObject(mockConnection); + verify(mockWriteConnPool, times(2)).invalidateObject(mockConnection); + } + + @Test + void test_readFromCache() throws Exception { + CacheConnection spyConnection = spy(cacheConnection); + when(spyConnection.getClusterHealthStateFromCacheMonitor()).thenReturn(CacheMonitor.HealthState.HEALTHY); + doNothing().when(spyConnection).reportErrorToCacheMonitor(anyBoolean(), any(), any()); + + byte[] value = "myValue".getBytes(StandardCharsets.UTF_8); + when(mockReadConnPool.borrowObject()).thenReturn(mockConnection); + when(mockConnection.sync()).thenReturn(mockSyncCommands); + when(mockSyncCommands.get(any())).thenReturn(value); + byte[] result = spyConnection.readFromCache("myQueryKey"); + assertEquals(value, result); + verify(mockReadConnPool).borrowObject(); + verify(mockConnection).sync(); + verify(mockSyncCommands).get(any()); + verify(mockReadConnPool).returnObject(mockConnection); + } + + @Test + void test_readFromCacheException() throws Exception { + CacheConnection spyConnection = spy(cacheConnection); + when(spyConnection.getClusterHealthStateFromCacheMonitor()).thenReturn(CacheMonitor.HealthState.HEALTHY); + doNothing().when(spyConnection).reportErrorToCacheMonitor(anyBoolean(), any(), any()); + + when(mockReadConnPool.borrowObject()).thenReturn(mockConnection); + when(mockConnection.sync()).thenReturn(mockSyncCommands); + when(mockSyncCommands.get(any())).thenThrow(new RuntimeException("test")); + 
assertNull(spyConnection.readFromCache("myQueryKey")); + verify(mockReadConnPool).borrowObject(); + verify(mockConnection).sync(); + verify(mockSyncCommands).get(any()); + verify(mockReadConnPool).invalidateObject(mockConnection); + } + + @ParameterizedTest + @ValueSource(booleans = {false, true}) // false = CMD (standalone), true = CME (cluster) + void test_readAndWriteCache_BothModes(boolean isClusterMode) throws Exception { + // Setup connection with appropriate cluster mode + CacheConnection spyConnection = spy(cacheConnection); + spyConnection.setClusterMode(isClusterMode); + + when(spyConnection.getClusterHealthStateFromCacheMonitor()).thenReturn(CacheMonitor.HealthState.HEALTHY); + doNothing().when(spyConnection).reportErrorToCacheMonitor(anyBoolean(), any(), any()); + doNothing().when(spyConnection).incrementInFlightSize(anyLong()); + doNothing().when(spyConnection).decrementInFlightSize(anyLong()); + + String key = "testKey"; + byte[] value = "testValue".getBytes(StandardCharsets.UTF_8); + + // Test WRITE operation + if (isClusterMode) { + // Mock cluster connection for write + when(mockWriteConnPool.borrowObject()).thenReturn(mockClusterConnection); + when(mockClusterConnection.async()).thenReturn(mockClusterAsyncCommands); + when(mockClusterAsyncCommands.set(any(), any(), any())).thenReturn(mockCacheResult); + when(mockCacheResult.whenComplete(any(BiConsumer.class))).thenReturn(null); + } else { + // Mock standalone connection for write + when(mockWriteConnPool.borrowObject()).thenReturn(mockConnection); + when(mockConnection.async()).thenReturn(mockAsyncCommands); + when(mockAsyncCommands.set(any(), any(), any())).thenReturn(mockCacheResult); + when(mockCacheResult.whenComplete(any(BiConsumer.class))).thenReturn(null); + } + + spyConnection.writeToCache(key, value, 100); + verify(mockWriteConnPool).borrowObject(); + verify(spyConnection).incrementInFlightSize(anyLong()); + + // Test READ operation + if (isClusterMode) { + // Mock cluster connection for read + when(mockReadConnPool.borrowObject()).thenReturn(mockClusterConnection); + when(mockClusterConnection.sync()).thenReturn(mockClusterSyncCommands); + when(mockClusterSyncCommands.get(any())).thenReturn(value); + } else { + // Mock standalone connection for read + when(mockReadConnPool.borrowObject()).thenReturn(mockConnection); + when(mockConnection.sync()).thenReturn(mockSyncCommands); + when(mockSyncCommands.get(any())).thenReturn(value); + } + + byte[] result = spyConnection.readFromCache(key); + assertEquals(value, result); + verify(mockReadConnPool).borrowObject(); + + // Verify appropriate connection type was used + if (isClusterMode) { + verify(mockClusterConnection).async(); + verify(mockClusterConnection).sync(); + verify(mockClusterAsyncCommands).set(any(), any(), any()); + verify(mockClusterSyncCommands).get(any()); + } else { + verify(mockConnection).async(); + verify(mockConnection).sync(); + verify(mockAsyncCommands).set(any(), any(), any()); + verify(mockSyncCommands).get(any()); + } + } + + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void test_readAndWriteException_BothModes(boolean isClusterMode) throws Exception { + CacheConnection spyConnection = spy(cacheConnection); + spyConnection.setClusterMode(isClusterMode); + + when(spyConnection.getClusterHealthStateFromCacheMonitor()).thenReturn(CacheMonitor.HealthState.HEALTHY); + doNothing().when(spyConnection).reportErrorToCacheMonitor(anyBoolean(), any(), any()); + doNothing().when(spyConnection).incrementInFlightSize(anyLong()); + 
doNothing().when(spyConnection).decrementInFlightSize(anyLong()); + + String key = "testKey"; + byte[] value = "testValue".getBytes(StandardCharsets.UTF_8); + + // Test WRITE exception handling + if (isClusterMode) { + when(mockWriteConnPool.borrowObject()).thenReturn(mockClusterConnection); + when(mockClusterConnection.async()).thenReturn(mockClusterAsyncCommands); + when(mockClusterAsyncCommands.set(any(), any(), any())).thenThrow(new RuntimeException("cluster write error")); + } else { + when(mockWriteConnPool.borrowObject()).thenReturn(mockConnection); + when(mockConnection.async()).thenReturn(mockAsyncCommands); + when(mockAsyncCommands.set(any(), any(), any())).thenThrow(new RuntimeException("standalone write error")); + } + + spyConnection.writeToCache(key, value, 100); + verify(mockWriteConnPool).borrowObject(); + verify(mockWriteConnPool).invalidateObject(isClusterMode ? mockClusterConnection : mockConnection); + + // Test READ exception handling + if (isClusterMode) { + when(mockReadConnPool.borrowObject()).thenReturn(mockClusterConnection); + when(mockClusterConnection.sync()).thenReturn(mockClusterSyncCommands); + when(mockClusterSyncCommands.get(any())).thenThrow(new RuntimeException("cluster read error")); + } else { + when(mockReadConnPool.borrowObject()).thenReturn(mockConnection); + when(mockConnection.sync()).thenReturn(mockSyncCommands); + when(mockSyncCommands.get(any())).thenThrow(new RuntimeException("standalone read error")); + } + + assertNull(spyConnection.readFromCache(key)); + verify(mockReadConnPool).borrowObject(); + verify(mockReadConnPool).invalidateObject(isClusterMode ? mockClusterConnection : mockConnection); + } + + @Test + void test_cacheConnectionPoolSize_default() throws Exception { + clearStaticRegistry(); + + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheEndpointAddrRo", "localhost:6380"); + + CacheConnection connection = new CacheConnection(props); + // Bypass cluster detection for tests + connection.setClusterMode(false); + + // Create real pools (no network until borrow) + connection.triggerPoolInit(true); + connection.triggerPoolInit(false); + + GenericObjectPool> readPool = getInstancePool(connection,"readConnectionPool"); + GenericObjectPool> writePool = getInstancePool(connection, "writeConnectionPool"); + + assertNotNull(readPool, "read pool should be created"); + assertNotNull(writePool, "write pool should be created"); + + assertEquals(20, readPool.getMaxTotal()); + assertEquals(20, readPool.getMaxIdle()); + assertEquals(20, writePool.getMaxTotal()); + assertEquals(20, writePool.getMaxIdle()); + assertNotEquals(8, readPool.getMaxTotal()); // making sure it does not set the default values of Generic pool + assertNotEquals(8, writePool.getMaxIdle()); + } + + @Test + void test_cacheConnectionPoolSize_Initialization() throws Exception { + clearStaticRegistry(); + + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheEndpointAddrRo", "localhost:6380"); + props.setProperty("cacheConnectionPoolSize", "15"); + + CacheConnection connection = new CacheConnection(props); + // Bypass cluster detection for tests + connection.setClusterMode(false); + + // Create real pools (no network until borrow) + connection.triggerPoolInit(true); + connection.triggerPoolInit(false); + + GenericObjectPool> readPool = getInstancePool(connection,"readConnectionPool"); + GenericObjectPool> writePool = 
getInstancePool(connection, "writeConnectionPool"); + + assertNotNull(readPool, "read pool should be created"); + assertNotNull(writePool, "write pool should be created"); + + assertEquals(15, readPool.getMaxTotal()); + assertEquals(15, readPool.getMaxIdle()); + assertEquals(15, writePool.getMaxTotal()); + assertEquals(15, writePool.getMaxIdle()); + } + + @Test + void test_cacheConnectionTimeout_Initialization() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheConnectionTimeout", "5000"); + + CacheConnection connection = new CacheConnection(props); + Duration timeout = (Duration) getField(connection, "cacheConnectionTimeout"); + assertEquals(Duration.ofMillis(5000), timeout); + } + + @Test + void test_cacheConnectionTimeout_default() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + + CacheConnection connection = new CacheConnection(props); + Duration timeout = (Duration) getField(connection, "cacheConnectionTimeout"); + assertEquals(Duration.ofMillis(2000), timeout, "default should be 2000 ms"); + } + + @SuppressWarnings("unchecked") + private static GenericObjectPool> getInstancePool(CacheConnection connection, String fieldName) throws Exception { + Field f = CacheConnection.class.getDeclaredField(fieldName); + f.setAccessible(true); + return (GenericObjectPool>) f.get(connection); + } + + private static void clearStaticRegistry() throws Exception { + Field registryField = CacheConnection.class.getDeclaredField("endpointToPoolRegistry"); + registryField.setAccessible(true); + java.util.concurrent.ConcurrentHashMap registry = + (java.util.concurrent.ConcurrentHashMap) registryField.get(null); + registry.clear(); + } + + @Test + void test_cacheMonitorIntegration() throws Exception { + CacheConnection spyConnection = spy(cacheConnection); + doNothing().when(spyConnection).reportErrorToCacheMonitor(anyBoolean(), any(), any()); + doNothing().when(spyConnection).incrementInFlightSize(anyLong()); + doNothing().when(spyConnection).decrementInFlightSize(anyLong()); + + String key = "myQueryKey"; + byte[] value = "myValue".getBytes(StandardCharsets.UTF_8); + + // DEGRADED state: operations bypassed + when(spyConnection.getClusterHealthStateFromCacheMonitor()).thenReturn(CacheMonitor.HealthState.DEGRADED); + spyConnection.writeToCache(key, value, 100); + assertNull(spyConnection.readFromCache(key)); + verify(mockWriteConnPool, never()).borrowObject(); + verify(mockReadConnPool, never()).borrowObject(); + + // HEALTHY state: operations proceed + when(spyConnection.getClusterHealthStateFromCacheMonitor()).thenReturn(CacheMonitor.HealthState.HEALTHY); + when(mockWriteConnPool.borrowObject()).thenReturn(mockConnection); + when(mockConnection.async()).thenReturn(mockAsyncCommands); + when(mockAsyncCommands.set(any(), any(), any())).thenReturn(mockCacheResult); + when(mockCacheResult.whenComplete(any(BiConsumer.class))).thenReturn(null); + spyConnection.writeToCache(key, value, 100); + verify(mockWriteConnPool).borrowObject(); + verify(spyConnection).incrementInFlightSize(anyLong()); + + // Error reporting on read failure + when(mockReadConnPool.borrowObject()).thenReturn(mockConnection); + when(mockConnection.sync()).thenReturn(mockSyncCommands); + RuntimeException testException = new RuntimeException("Connection failed"); + when(mockSyncCommands.get(any())).thenThrow(testException); + assertNull(spyConnection.readFromCache(key)); + 
verify(spyConnection).reportErrorToCacheMonitor(eq(false), eq(testException), eq("READ")); + + // failWhenCacheDown: throws SQLException + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("failWhenCacheDown", "true"); + CacheConnection failConnection = spy(new CacheConnection(props)); + failConnection.setConnectionPools(mockReadConnPool, mockWriteConnPool); + when(failConnection.getClusterHealthStateFromCacheMonitor()).thenReturn(CacheMonitor.HealthState.DEGRADED); + SQLException exception = assertThrows(SQLException.class, () -> failConnection.readFromCache(key)); + assertTrue(exception.getMessage().contains("Cache cluster is in DEGRADED state")); + } + + @Test + void test_multiEndpoint_PoolReuseAndIsolation() throws Exception { + clearStaticRegistry(); + + // Test 1: Same endpoints should reuse pools (first wins) + Properties props1 = new Properties(); + props1.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props1.setProperty("cacheEndpointAddrRo", "localhost:6380"); + props1.setProperty("cacheConnectionPoolSize", "10"); + + Properties props2 = new Properties(); + props2.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props2.setProperty("cacheEndpointAddrRo", "localhost:6380"); + props2.setProperty("cacheConnectionPoolSize", "15"); + + CacheConnection connection1 = new CacheConnection(props1); + connection1.setClusterMode(false); + CacheConnection connection2 = new CacheConnection(props2); + connection2.setClusterMode(false); + + connection1.triggerPoolInit(true); + connection1.triggerPoolInit(false); + connection2.triggerPoolInit(true); + connection2.triggerPoolInit(false); + + GenericObjectPool<?> readPool1 = getInstancePool(connection1, "readConnectionPool"); + GenericObjectPool<?> readPool2 = getInstancePool(connection2, "readConnectionPool"); + GenericObjectPool<?> writePool1 = getInstancePool(connection1, "writeConnectionPool"); + GenericObjectPool<?> writePool2 = getInstancePool(connection2, "writeConnectionPool"); + + assertSame(readPool1, readPool2, "Read pools should be the same instance"); + assertSame(writePool1, writePool2, "Write pools should be the same instance"); + assertEquals(10, readPool1.getMaxTotal(), "Pool size should be 10 (first initialized)"); + assertEquals(10, writePool1.getMaxTotal(), "Pool size should be 10 (first initialized)"); + + // Test 2: Different endpoints should have isolated pools + Properties props3 = new Properties(); + props3.setProperty("cacheEndpointAddrRw", "localhost:7379"); + props3.setProperty("cacheEndpointAddrRo", "localhost:7380"); + props3.setProperty("cacheConnectionPoolSize", "20"); + + CacheConnection connection3 = new CacheConnection(props3); + connection3.setClusterMode(false); + connection3.triggerPoolInit(true); + connection3.triggerPoolInit(false); + + GenericObjectPool<?> readPool3 = getInstancePool(connection3, "readConnectionPool"); + GenericObjectPool<?> writePool3 = getInstancePool(connection3, "writeConnectionPool"); + + assertNotSame(readPool1, readPool3, "Read pools should be different instances"); + assertNotSame(writePool1, writePool3, "Write pools should be different instances"); + assertEquals(20, readPool3.getMaxTotal()); + assertEquals(20, writePool3.getMaxTotal()); + + // Test 3: Same RW endpoint, different RO endpoints (or no RO) + Properties props4 = new Properties(); + props4.setProperty("cacheEndpointAddrRw", "localhost:6379"); + + CacheConnection connection4 = new CacheConnection(props4); + connection4.setClusterMode(false); +
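// props4 sets no cacheEndpointAddrRo: the RW pool keyed on localhost:6379 should be reused, while the read pool is registered separately. +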
connection4.triggerPoolInit(false); + connection4.triggerPoolInit(true); + + GenericObjectPool<?> writePool4 = getInstancePool(connection4, "writeConnectionPool"); + GenericObjectPool<?> readPool4 = getInstancePool(connection4, "readConnectionPool"); + + assertSame(writePool1, writePool4, "Write pools should be shared for same RW endpoint"); + assertEquals(10, writePool4.getMaxTotal(), "Connection pool size should not be changed."); + assertNotSame(readPool1, readPool4, "Read pools should be different for different RO endpoints"); + } + + @Test + void test_multiEndpoint_ConcurrentInitialization() throws Exception { + clearStaticRegistry(); + + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheConnectionPoolSize", "10"); + + CacheConnection connection1 = new CacheConnection(props); + connection1.setClusterMode(false); + CacheConnection connection2 = new CacheConnection(props); + connection2.setClusterMode(false); + CacheConnection connection3 = new CacheConnection(props); + connection3.setClusterMode(false); + + // Simulate concurrent initialization + Thread t1 = new Thread(() -> connection1.triggerPoolInit(false)); + Thread t2 = new Thread(() -> connection2.triggerPoolInit(false)); + Thread t3 = new Thread(() -> connection3.triggerPoolInit(false)); + + t1.start(); + t2.start(); + t3.start(); + + t1.join(); + t2.join(); + t3.join(); + + // All should reference the same pool + GenericObjectPool<?> pool1 = getInstancePool(connection1, "writeConnectionPool"); + GenericObjectPool<?> pool2 = getInstancePool(connection2, "writeConnectionPool"); + GenericObjectPool<?> pool3 = getInstancePool(connection3, "writeConnectionPool"); + + assertSame(pool1, pool2); + assertSame(pool2, pool3); + assertEquals(10, pool1.getMaxTotal()); + } + + private Object getField(Object obj, String fieldName) throws Exception { + Field field = obj.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + return field.get(obj); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheMonitorTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheMonitorTest.java new file mode 100644 index 000000000..0e0646e5e --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheMonitorTest.java @@ -0,0 +1,553 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package software.amazon.jdbc.plugin.cache; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import io.lettuce.core.RedisCommandExecutionException; +import io.lettuce.core.RedisConnectionException; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.time.Duration; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.util.telemetry.TelemetryGauge; + +public class CacheMonitorTest { + private Properties props; + private AutoCloseable closeable; + + @Mock private TelemetryFactory mockTelemetryFactory; + @Mock private TelemetryCounter mockStateTransitionCounter; + @Mock private TelemetryCounter mockHealthCheckSuccessCounter; + @Mock private TelemetryCounter mockHealthCheckFailureCounter; + @Mock private TelemetryCounter mockErrorCounter; + @Mock private TelemetryGauge mockConsecutiveSuccessGauge; + @Mock private TelemetryGauge mockConsecutiveFailureGauge; + + @Mock private CachePingConnection mockRwPingConnection; + @Mock private CachePingConnection mockRoPingConnection; + + @BeforeEach + void setUp() throws Exception { + closeable = MockitoAnnotations.openMocks(this); + props = new Properties(); + props.setProperty("wrapperPlugins", "dataRemoteCache"); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheEndpointAddrRo", "localhost:6380"); + + // Setup telemetry mocks + when(mockTelemetryFactory.createCounter("JdbcCacheStateTransitionCount")) + .thenReturn(mockStateTransitionCounter); + when(mockTelemetryFactory.createCounter("JdbcCacheHealthCheckSuccessCount")) + .thenReturn(mockHealthCheckSuccessCounter); + when(mockTelemetryFactory.createCounter("JdbcCacheHealthCheckFailureCount")) + .thenReturn(mockHealthCheckFailureCounter); + when(mockTelemetryFactory.createCounter("JdbcCacheErrorCount")) + .thenReturn(mockErrorCounter); + when(mockTelemetryFactory.createGauge(eq("JdbcCacheConsecutiveSuccessCount"), any())) + .thenReturn(mockConsecutiveSuccessGauge); + when(mockTelemetryFactory.createGauge(eq("JdbcCacheConsecutiveFailureCount"), any())) + .thenReturn(mockConsecutiveFailureGauge); + + // Reset singleton state between tests + resetCacheMonitorState(); + } + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + } + + private void registerCluster(String rwEndpoint, String roEndpoint) throws Exception { + long inFlightLimit = CacheMonitor.CACHE_IN_FLIGHT_WRITE_SIZE_LIMIT.getLong(props); + boolean healthCheckInHealthy = CacheMonitor.CACHE_HEALTH_CHECK_IN_HEALTHY_STATE.getBoolean(props); + + CacheMonitor.registerCluster(inFlightLimit, healthCheckInHealthy, null, rwEndpoint, roEndpoint, + false, Duration.ofSeconds(5), false, null, null, null, null, null, false, false); + + CacheMonitor.ClusterHealthState cluster = getCluster(rwEndpoint, roEndpoint); + CacheMonitor instance = (CacheMonitor) getStaticField("instance"); + if (instance != null) { + instance.setPingConnections(cluster, mockRwPingConnection, + roEndpoint != null ? 
mockRoPingConnection : null); + } + } + + private void registerClusterWithTelemetry(String rwEndpoint, String roEndpoint) throws Exception { + long inFlightLimit = CacheMonitor.CACHE_IN_FLIGHT_WRITE_SIZE_LIMIT.getLong(props); + boolean healthCheckInHealthy = CacheMonitor.CACHE_HEALTH_CHECK_IN_HEALTHY_STATE.getBoolean(props); + + CacheMonitor.registerCluster(inFlightLimit, healthCheckInHealthy, mockTelemetryFactory, rwEndpoint, roEndpoint, + false, Duration.ofSeconds(5), false, null, null, null, null, null, false, false); + + CacheMonitor.ClusterHealthState cluster = getCluster(rwEndpoint, roEndpoint); + CacheMonitor instance = (CacheMonitor) getStaticField("instance"); + if (instance != null) { + instance.setPingConnections(cluster, mockRwPingConnection, + roEndpoint != null ? mockRoPingConnection : null); + } + } + + /** + * Reset CacheMonitor singleton state between tests using reflection + * to prevent test pollution from static fields. + */ + private void resetCacheMonitorState() throws Exception { + setStaticField("instance", null); + setStaticField("monitorThreadStarted", false); + + @SuppressWarnings("unchecked") + Map<String, CacheMonitor.ClusterHealthState> clusterStates = + (Map<String, CacheMonitor.ClusterHealthState>) getStaticField("clusterStates"); + if (clusterStates != null) { + clusterStates.clear(); + } + + setStaticField("stateTransitionCounter", null); + setStaticField("healthCheckSuccessCounter", null); + setStaticField("healthCheckFailureCounter", null); + setStaticField("errorCounter", null); + setStaticField("consecutiveSuccessGauge", null); + setStaticField("consecutiveFailureGauge", null); + } + + private static Object getStaticField(String fieldName) throws Exception { + Field field = CacheMonitor.class.getDeclaredField(fieldName); + field.setAccessible(true); + return field.get(null); + } + + private static void setStaticField(String fieldName, Object value) throws Exception { + Field field = CacheMonitor.class.getDeclaredField(fieldName); + field.setAccessible(true); + field.set(null, value); + } + + private void setInstanceField(Object instance, String fieldName, Object value) throws Exception { + Field field = instance.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + field.set(instance, value); + } + + @SuppressWarnings("unchecked") + private Map<String, CacheMonitor.ClusterHealthState> getClusterStates() throws Exception { + return (Map<String, CacheMonitor.ClusterHealthState>) getStaticField("clusterStates"); + } + + private CacheMonitor.ClusterHealthState getCluster(String rwEndpoint, String roEndpoint) throws Exception { + Map<String, CacheMonitor.ClusterHealthState> clusterStates = getClusterStates(); + String key = CacheMonitor.ClusterHealthState.generateClusterKey(rwEndpoint, roEndpoint); + return clusterStates.get(key); + } + + @Test + void testRegisterCluster() throws Exception { + // Test 1: RW-only endpoint + registerCluster("localhost:6379", null); + assertEquals(1, getClusterStates().size()); + assertEquals(CacheMonitor.HealthState.HEALTHY, + CacheMonitor.getClusterState("localhost:6379", null)); + + CacheMonitor.ClusterHealthState cluster = getCluster("localhost:6379", null); + assertNotNull(cluster); + assertEquals("localhost:6379", cluster.rwEndpoint); + assertNull(cluster.roEndpoint); + assertEquals(CacheMonitor.HealthState.HEALTHY, cluster.rwHealthState); + assertEquals(0, cluster.consecutiveRwSuccesses); + assertEquals(0, cluster.consecutiveRwFailures); + assertEquals(0L, cluster.inFlightWriteSizeBytes.get()); + + Object instance = getStaticField("instance"); + assertNotNull(instance); + Boolean monitorThreadStarted = (Boolean) getStaticField("monitorThreadStarted"); + assertFalse(monitorThreadStarted); + + // Test 2:
Dual endpoints + registerCluster("localhost:6380", "localhost:6381"); + assertEquals(2, getClusterStates().size()); + cluster = getCluster("localhost:6380", "localhost:6381"); + assertNotNull(cluster); + assertEquals("localhost:6380", cluster.rwEndpoint); + assertEquals("localhost:6381", cluster.roEndpoint); + assertEquals(CacheMonitor.HealthState.HEALTHY, cluster.rwHealthState); + assertEquals(CacheMonitor.HealthState.HEALTHY, cluster.roHealthState); + + // Test 3: Duplicate registration should not create duplicate + registerCluster("localhost:6380", "localhost:6381"); + assertEquals(2, getClusterStates().size()); + + // Test 4: Same RW/RO endpoint should normalize RO to null + registerCluster("localhost:6382", "localhost:6382"); + cluster = getCluster("localhost:6382", "localhost:6382"); + assertNotNull(cluster); + assertNull(cluster.roEndpoint); + assertEquals(3, getClusterStates().size()); + } + + @Test + void testReportError() throws Exception { + registerClusterWithTelemetry("localhost:6379", "localhost:6380"); + CacheMonitor.ClusterHealthState cluster = getCluster("localhost:6379", "localhost:6380"); + + // Test 1: Recoverable error transitions to SUSPECT + CacheMonitor.reportError("localhost:6379", "localhost:6380", true, + new RedisConnectionException("Connection refused"), "SET"); + assertEquals(CacheMonitor.HealthState.SUSPECT, cluster.rwHealthState); + assertEquals(CacheMonitor.HealthState.HEALTHY, cluster.roHealthState); + verify(mockErrorCounter, times(1)).inc(); + verify(mockStateTransitionCounter, times(1)).inc(); + + // Test 2: Non-recoverable error doesn't change state + CacheMonitor.reportError("localhost:6379", "localhost:6380", true, + new RuntimeException("Serialization failed"), "SET"); + assertEquals(CacheMonitor.HealthState.SUSPECT, cluster.rwHealthState); + verify(mockErrorCounter, times(2)).inc(); + verify(mockStateTransitionCounter, times(1)).inc(); + + // Test 3: RO endpoint error transitions independently + CacheMonitor.reportError("localhost:6379", "localhost:6380", false, + new RedisCommandExecutionException("READONLY"), "SET"); + assertEquals(CacheMonitor.HealthState.SUSPECT, cluster.rwHealthState); + assertEquals(CacheMonitor.HealthState.SUSPECT, cluster.roHealthState); + verify(mockErrorCounter, times(3)).inc(); + verify(mockStateTransitionCounter, times(2)).inc(); + + // Test 4: Unregistered cluster logs warning without metrics + CacheMonitor.reportError("localhost:9999", null, true, + new RedisConnectionException("Connection refused"), "SET"); + verify(mockErrorCounter, times(3)).inc(); + verify(mockStateTransitionCounter, times(2)).inc(); + } + + @Test + void testInFlightSize_MemoryPressureScenarios() throws Exception { + props.setProperty(CacheMonitor.CACHE_IN_FLIGHT_WRITE_SIZE_LIMIT.name, "1000"); + registerClusterWithTelemetry("localhost:6379", null); + CacheMonitor.ClusterHealthState cluster = getCluster("localhost:6379", null); + + // Test 1: Below limit remains healthy + CacheMonitor.incrementInFlightSizeStatic("localhost:6379", null, 500); + assertEquals(CacheMonitor.HealthState.HEALTHY, cluster.rwHealthState); + assertEquals(500L, cluster.inFlightWriteSizeBytes.get()); + + // Test 2: Exceeds limit transitions to degraded + CacheMonitor.incrementInFlightSizeStatic("localhost:6379", null, 1000); + assertEquals(CacheMonitor.HealthState.DEGRADED, cluster.rwHealthState); + assertEquals(1500L, cluster.inFlightWriteSizeBytes.get()); + verify(mockStateTransitionCounter, times(1)).inc(); + + // Test 3: Decrement reduces size + 
CacheMonitor.decrementInFlightSizeStatic("localhost:6379", null, 900); + assertEquals(600L, cluster.inFlightWriteSizeBytes.get()); + + // Test 4: Decrement never goes negative + CacheMonitor.decrementInFlightSizeStatic("localhost:6379", null, 1000); + assertEquals(0L, cluster.inFlightWriteSizeBytes.get()); + } + + @Test + void testClusterStateAggregation() throws Exception { + // Test 1: Single endpoint state transitions + registerCluster("localhost:6379", null); + CacheMonitor.ClusterHealthState cluster = getCluster("localhost:6379", null); + assertEquals(CacheMonitor.HealthState.HEALTHY, + CacheMonitor.getClusterState("localhost:6379", null)); + + cluster.transitionToState(CacheMonitor.HealthState.SUSPECT, true, "test_setup", null); + assertEquals(CacheMonitor.HealthState.SUSPECT, + CacheMonitor.getClusterState("localhost:6379", null)); + + cluster.transitionToState(CacheMonitor.HealthState.DEGRADED, true, "test_setup", null); + assertEquals(CacheMonitor.HealthState.DEGRADED, + CacheMonitor.getClusterState("localhost:6379", null)); + + // Test 2: Dual endpoints + registerCluster("localhost:6379", "localhost:6380"); + cluster = getCluster("localhost:6379", "localhost:6380"); + assertEquals(CacheMonitor.HealthState.HEALTHY, + CacheMonitor.getClusterState("localhost:6379", "localhost:6380")); + + cluster.transitionToState(CacheMonitor.HealthState.DEGRADED, true, "test_setup", null); + assertEquals(CacheMonitor.HealthState.DEGRADED, + CacheMonitor.getClusterState("localhost:6379", "localhost:6380")); + + cluster.transitionToState(CacheMonitor.HealthState.HEALTHY, true, "test_setup", null); + cluster.transitionToState(CacheMonitor.HealthState.DEGRADED, false, "test_setup", null); + assertEquals(CacheMonitor.HealthState.DEGRADED, + CacheMonitor.getClusterState("localhost:6379", "localhost:6380")); + + cluster.transitionToState(CacheMonitor.HealthState.SUSPECT, true, "test_setup", null); + cluster.transitionToState(CacheMonitor.HealthState.HEALTHY, false, "test_setup", null); + assertEquals(CacheMonitor.HealthState.SUSPECT, + CacheMonitor.getClusterState("localhost:6379", "localhost:6380")); + + // Test 3: Unregistered cluster defaults to healthy + assertEquals(CacheMonitor.HealthState.HEALTHY, + CacheMonitor.getClusterState("nonexistent:9999", null)); + } + + @Test + void testExecutePing_StateTransitions() throws Exception { + registerClusterWithTelemetry("localhost:6379", null); + CacheMonitor.ClusterHealthState cluster = getCluster("localhost:6379", null); + CacheMonitor instance = (CacheMonitor) getStaticField("instance"); + + // Success: maintains healthy state + when(mockRwPingConnection.isOpen()).thenReturn(true); + when(mockRwPingConnection.ping()).thenReturn(true); + invokeExecutePing(instance, cluster, true); + assertEquals(CacheMonitor.HealthState.HEALTHY, cluster.rwHealthState); + assertEquals(1, cluster.consecutiveRwSuccesses); + verify(mockHealthCheckSuccessCounter, times(1)).inc(); + + // Failure: HEALTHY → SUSPECT + when(mockRwPingConnection.ping()).thenReturn(false); + invokeExecutePing(instance, cluster, true); + assertEquals(CacheMonitor.HealthState.SUSPECT, cluster.rwHealthState); + verify(mockHealthCheckFailureCounter, times(1)).inc(); + verify(mockStateTransitionCounter, times(1)).inc(); + + // Three consecutive failures: SUSPECT → DEGRADED + invokeExecutePing(instance, cluster, true); + invokeExecutePing(instance, cluster, true); + invokeExecutePing(instance, cluster, true); + assertEquals(CacheMonitor.HealthState.DEGRADED, cluster.rwHealthState); + 
verify(mockStateTransitionCounter, times(2)).inc(); + + // Recovery: Three successes from SUSPECT → HEALTHY + cluster.transitionToState(CacheMonitor.HealthState.SUSPECT, true, "test_setup", null); + when(mockRwPingConnection.ping()).thenReturn(true); + invokeExecutePing(instance, cluster, true); + invokeExecutePing(instance, cluster, true); + invokeExecutePing(instance, cluster, true); + assertEquals(CacheMonitor.HealthState.HEALTHY, cluster.rwHealthState); + verify(mockStateTransitionCounter, times(3)).inc(); + } + + @Test + void testExecutePing_EdgeCases() throws Exception { + // Dual endpoints track independently + registerClusterWithTelemetry("localhost:6379", "localhost:6380"); + CacheMonitor.ClusterHealthState cluster = getCluster("localhost:6379", "localhost:6380"); + CacheMonitor instance = (CacheMonitor) getStaticField("instance"); + + when(mockRwPingConnection.isOpen()).thenReturn(true); + when(mockRwPingConnection.ping()).thenReturn(false); + when(mockRoPingConnection.isOpen()).thenReturn(true); + when(mockRoPingConnection.ping()).thenReturn(true); + + invokeExecutePing(instance, cluster, true); + invokeExecutePing(instance, cluster, false); + assertEquals(CacheMonitor.HealthState.SUSPECT, cluster.rwHealthState); + assertEquals(CacheMonitor.HealthState.HEALTHY, cluster.roHealthState); + + // Connection closed = failure (no ping call) + reset(mockRwPingConnection); + when(mockRwPingConnection.isOpen()).thenReturn(false); + invokeExecutePing(instance, cluster, true); + verify(mockRwPingConnection, never()).ping(); + + // Exception during ping = failure + when(mockRwPingConnection.isOpen()).thenReturn(true); + when(mockRwPingConnection.ping()).thenThrow(new RuntimeException("timeout")); + invokeExecutePing(instance, cluster, true); + assertEquals(CacheMonitor.HealthState.SUSPECT, cluster.rwHealthState); + + // Recovery from DEGRADED when memory clears + resetCacheMonitorState(); + props.setProperty(CacheMonitor.CACHE_IN_FLIGHT_WRITE_SIZE_LIMIT.name, "1000"); + registerClusterWithTelemetry("localhost:6379", null); + cluster = getCluster("localhost:6379", null); + instance = (CacheMonitor) getStaticField("instance"); + cluster.transitionToState(CacheMonitor.HealthState.DEGRADED, true, "test_setup", null); + cluster.inFlightWriteSizeBytes.set(500); + + reset(mockRwPingConnection); + when(mockRwPingConnection.isOpen()).thenReturn(true); + when(mockRwPingConnection.ping()).thenReturn(true); + invokeExecutePing(instance, cluster, true); + invokeExecutePing(instance, cluster, true); + invokeExecutePing(instance, cluster, true); + assertEquals(CacheMonitor.HealthState.HEALTHY, cluster.rwHealthState); + } + + private void invokeExecutePing(CacheMonitor instance, CacheMonitor.ClusterHealthState cluster, boolean isRw) throws Exception { + Method method = CacheMonitor.class.getDeclaredMethod("executePing", CacheMonitor.ClusterHealthState.class, boolean.class); + method.setAccessible(true); + method.invoke(instance, cluster, isRw); + } + + @Test + void testRun_MonitoringBehavior() throws Exception { + when(mockRwPingConnection.isOpen()).thenReturn(true); + when(mockRwPingConnection.ping()).thenReturn(true); + + // HEALTHY: skips ping by default + registerClusterWithTelemetry("localhost:6379", null); + CacheMonitor instance = (CacheMonitor) getStaticField("instance"); + CacheMonitor spy = spy(instance); + AtomicInteger sleepCount = new AtomicInteger(0); + CacheMonitor finalSpy = spy; + doAnswer(inv -> { + if (sleepCount.incrementAndGet() >= 2) setInstanceField(finalSpy, "stopped", true); + 
return null; + }).when(spy).sleep(anyLong()); + spy.run(); + verify(mockRwPingConnection, never()).ping(); + + // HEALTHY with health check enabled: executes ping + resetCacheMonitorState(); + props.setProperty(CacheMonitor.CACHE_HEALTH_CHECK_IN_HEALTHY_STATE.name, "true"); + registerClusterWithTelemetry("localhost:6379", null); + instance = (CacheMonitor) getStaticField("instance"); + spy = spy(instance); + sleepCount.set(0); + CacheMonitor finalSpy1 = spy; + doAnswer(inv -> { + if (sleepCount.incrementAndGet() >= 2) setInstanceField(finalSpy1, "stopped", true); + return null; + }).when(spy).sleep(anyLong()); + spy.run(); + verify(mockRwPingConnection, times(2)).ping(); + + // SUSPECT/DEGRADED: always pings + resetCacheMonitorState(); + props.remove("cacheHealthCheckInHealthyState"); + registerClusterWithTelemetry("localhost:6379", null); + CacheMonitor.ClusterHealthState cluster = getCluster("localhost:6379", null); + cluster.transitionToState(CacheMonitor.HealthState.SUSPECT, true, "test_setup", null); + instance = (CacheMonitor) getStaticField("instance"); + spy = spy(instance); + sleepCount.set(0); + CacheMonitor finalSpy2 = spy; + doAnswer(inv -> { + if (sleepCount.incrementAndGet() >= 2) setInstanceField(finalSpy2, "stopped", true); + return null; + }).when(spy).sleep(anyLong()); + spy.run(); + verify(mockRwPingConnection, times(4)).ping(); + + // Dual endpoints: pings both + resetCacheMonitorState(); + registerClusterWithTelemetry("localhost:6379", "localhost:6380"); + cluster = getCluster("localhost:6379", "localhost:6380"); + cluster.transitionToState(CacheMonitor.HealthState.SUSPECT, true, "test_setup", null); + instance = (CacheMonitor) getStaticField("instance"); + reset(mockRwPingConnection, mockRoPingConnection); + when(mockRwPingConnection.isOpen()).thenReturn(true); + when(mockRwPingConnection.ping()).thenReturn(true); + when(mockRoPingConnection.isOpen()).thenReturn(true); + when(mockRoPingConnection.ping()).thenReturn(true); + spy = spy(instance); + sleepCount.set(0); + CacheMonitor finalSpy3 = spy; + doAnswer(inv -> { + if (sleepCount.incrementAndGet() >= 2) setInstanceField(finalSpy3, "stopped", true); + return null; + }).when(spy).sleep(anyLong()); + spy.run(); + verify(mockRwPingConnection, times(2)).ping(); + verify(mockRoPingConnection, times(2)).ping(); + + // Exception during ping: continues monitoring + resetCacheMonitorState(); + registerClusterWithTelemetry("localhost:6379", null); + cluster = getCluster("localhost:6379", null); + cluster.transitionToState(CacheMonitor.HealthState.SUSPECT, true, "test_setup", null); + instance = (CacheMonitor) getStaticField("instance"); + reset(mockRwPingConnection); + when(mockRwPingConnection.isOpen()).thenReturn(true); + when(mockRwPingConnection.ping()) + .thenThrow(new RuntimeException("Test exception")) + .thenReturn(true); + spy = spy(instance); + sleepCount.set(0); + CacheMonitor finalSpy4 = spy; + doAnswer(inv -> { + if (sleepCount.incrementAndGet() >= 2) setInstanceField(finalSpy4, "stopped", true); + return null; + }).when(spy).sleep(anyLong()); + spy.run(); + verify(mockRwPingConnection, times(2)).ping(); + } + + @Test + void testPrivateHelperMethods() throws Exception { + // Test 1: classifyError categorizes exceptions correctly + Method classifyMethod = CacheMonitor.class.getDeclaredMethod("classifyError", Throwable.class); + classifyMethod.setAccessible(true); + + assertEquals(CacheMonitor.ErrorCategory.CONNECTION, + classifyMethod.invoke(null, new RedisConnectionException("Connection refused"))); + 
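// Command-level rejections (READONLY, WRONGTYPE) classify as COMMAND; replies signalling server-side resource trouble (OOM, CLUSTERDOWN) classify as RESOURCE. +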
assertEquals(CacheMonitor.ErrorCategory.COMMAND, + classifyMethod.invoke(null, new RedisCommandExecutionException("READONLY"))); + assertEquals(CacheMonitor.ErrorCategory.COMMAND, + classifyMethod.invoke(null, new RedisCommandExecutionException("WRONGTYPE"))); + assertEquals(CacheMonitor.ErrorCategory.RESOURCE, + classifyMethod.invoke(null, new RedisCommandExecutionException("OOM"))); + assertEquals(CacheMonitor.ErrorCategory.RESOURCE, + classifyMethod.invoke(null, new RedisCommandExecutionException("CLUSTERDOWN"))); + assertEquals(CacheMonitor.ErrorCategory.RESOURCE, + classifyMethod.invoke(null, new RedisCommandExecutionException((String) null))); + assertEquals(CacheMonitor.ErrorCategory.CONNECTION, + classifyMethod.invoke(null, new io.lettuce.core.RedisException("Generic error"))); + assertEquals(CacheMonitor.ErrorCategory.DATA, + classifyMethod.invoke(null, new RuntimeException("Serialization failed"))); + + // Test 2: isRecoverableError determines recoverability + Method recoverableMethod = CacheMonitor.class.getDeclaredMethod("isRecoverableError", CacheMonitor.ErrorCategory.class); + recoverableMethod.setAccessible(true); + + assertTrue((Boolean) recoverableMethod.invoke(null, CacheMonitor.ErrorCategory.CONNECTION)); + assertTrue((Boolean) recoverableMethod.invoke(null, CacheMonitor.ErrorCategory.COMMAND)); + assertTrue((Boolean) recoverableMethod.invoke(null, CacheMonitor.ErrorCategory.RESOURCE)); + assertFalse((Boolean) recoverableMethod.invoke(null, CacheMonitor.ErrorCategory.DATA)); + + // Test 3: ping method handles various connection states + registerCluster("localhost:6379", null); + CacheMonitor.ClusterHealthState cluster = getCluster("localhost:6379", null); + CacheMonitor instance = (CacheMonitor) getStaticField("instance"); + + Method pingMethod = CacheMonitor.class.getDeclaredMethod("ping", CacheMonitor.ClusterHealthState.class, boolean.class); + pingMethod.setAccessible(true); + + cluster.rwPingConnection = null; + assertFalse((Boolean) pingMethod.invoke(instance, cluster, true)); + + cluster.rwPingConnection = mockRwPingConnection; + when(mockRwPingConnection.isOpen()).thenReturn(false); + assertFalse((Boolean) pingMethod.invoke(instance, cluster, true)); + + when(mockRwPingConnection.isOpen()).thenReturn(true); + when(mockRwPingConnection.ping()).thenReturn(true); + assertTrue((Boolean) pingMethod.invoke(instance, cluster, true)); + + when(mockRwPingConnection.ping()).thenThrow(new RuntimeException("Ping failed")); + assertFalse((Boolean) pingMethod.invoke(instance, cluster, true)); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheSupplierTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheSupplierTest.java new file mode 100644 index 000000000..458ae96cc --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheSupplierTest.java @@ -0,0 +1,158 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package software.amazon.jdbc.plugin.cache; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +import org.junit.jupiter.api.Test; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +public class CacheSupplierTest { + + @Test + void testMemoizeWithExpiration_ValidParameters() { + Supplier<String> delegate = () -> "test-value"; + + Supplier<String> cached = CachedSupplier.memoizeWithExpiration(delegate, 1, TimeUnit.SECONDS); + + assertNotNull(cached); + assertEquals("test-value", cached.get()); + } + + @Test + void testMemoizeWithExpiration_NullDelegate() { + assertThrows(NullPointerException.class, () -> + CachedSupplier.memoizeWithExpiration(null, 1, TimeUnit.SECONDS)); + } + + @Test + void testMemoizeWithExpiration_NullTimeUnit() { + Supplier<String> delegate = () -> "test"; + + assertThrows(NullPointerException.class, () -> + CachedSupplier.memoizeWithExpiration(delegate, 1, null)); + } + + @Test + void testMemoizeWithExpiration_ZeroDuration() { + Supplier<String> delegate = () -> "test"; + + assertThrows(IllegalArgumentException.class, () -> + CachedSupplier.memoizeWithExpiration(delegate, 0, TimeUnit.SECONDS)); + } + + @Test + void testMemoizeWithExpiration_NegativeDuration() { + Supplier<String> delegate = () -> "test"; + + assertThrows(IllegalArgumentException.class, () -> + CachedSupplier.memoizeWithExpiration(delegate, -1, TimeUnit.SECONDS)); + } + + @Test + void testCaching_DelegateCalledOnce() { + Supplier<String> mockDelegate = mock(Supplier.class); + when(mockDelegate.get()).thenReturn("cached-value"); + + Supplier<String> cached = CachedSupplier.memoizeWithExpiration(mockDelegate, 1, TimeUnit.SECONDS); + + // Call multiple times quickly + assertEquals("cached-value", cached.get()); + assertEquals("cached-value", cached.get()); + assertEquals("cached-value", cached.get()); + + // Delegate should only be called once due to caching + verify(mockDelegate, times(1)).get(); + } + + @Test + void testExpiration_DelegateCalledAgainAfterExpiry() throws InterruptedException { + Supplier<String> mockDelegate = mock(Supplier.class); + when(mockDelegate.get()).thenReturn("value1", "value2"); + + Supplier<String> cached = CachedSupplier.memoizeWithExpiration(mockDelegate, 50, TimeUnit.MILLISECONDS); + + // First call + assertEquals("value1", cached.get()); + verify(mockDelegate, times(1)).get(); + + // Wait for expiration + Thread.sleep(100); + + // Second call after expiration + assertEquals("value2", cached.get()); + verify(mockDelegate, times(2)).get(); + } + + @Test + void testConcurrentAccess() throws InterruptedException { + Supplier<String> mockDelegate = mock(Supplier.class); + when(mockDelegate.get()).thenReturn("concurrent-value"); + + Supplier<String> cached = CachedSupplier.memoizeWithExpiration(mockDelegate, 5, TimeUnit.SECONDS); + + // Simulate concurrent access + Thread[] threads = new Thread[10]; + String[] results = new String[10]; + + for (int i = 0; i < 10; i++) { + final int index = i; + threads[i] = new Thread(() -> results[index] = cached.get()); + threads[i].start(); + } + + // Wait for all threads + for (Thread thread : threads) { + thread.join(); + } + + // All should get the same cached value + for (String result : results) { + assertEquals("concurrent-value", result); + } + + // Delegate should only be called once despite concurrent access + verify(mockDelegate, times(1)).get(); + } + + @Test + void testExpirationNanos_EdgeCase() { + Supplier<Long> timeSupplier = () -> System.nanoTime(); + + Supplier<Long> cached = CachedSupplier.memoizeWithExpiration(timeSupplier, 1,
TimeUnit.NANOSECONDS); + + Long first = cached.get(); + Long second = cached.get(); + + // Due to very short expiration, second call might get different value + assertNotNull(first); + assertNotNull(second); + } + + @Test + void testPrivateConstructor() { + // Verify utility class has private constructor + assertThrows(Exception.class, () -> { + java.lang.reflect.Constructor constructor = + CachedSupplier.class.getDeclaredConstructor(); + constructor.setAccessible(true); + constructor.newInstance(); + }); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java new file mode 100644 index 000000000..eaa1587a2 --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java @@ -0,0 +1,872 @@ +package software.amazon.jdbc.plugin.cache; + +import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.when; + +import java.sql.*; +import java.sql.Date; +import java.time.*; +import java.util.*; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import java.net.URL; +import java.net.MalformedURLException; + +import java.math.BigDecimal; + +public class CachedResultSetTest { + private CachedResultSet testResultSet; + @Mock ResultSet mockResultSet; + @Mock ResultSetMetaData mockResultSetMetadata; + private AutoCloseable closeable; + private static final Calendar estCal = Calendar.getInstance(TimeZone.getTimeZone("America/New_York")); + private final TimeZone defaultTimeZone = TimeZone.getDefault(); + + // Column values: label, name, typeName, type, displaySize, precision, tableName, + // scale, schemaName, isAutoIncrement, isCaseSensitive, isCurrency, isDefinitelyWritable, + // isNullable, isReadOnly, isSearchable, isSigned, isWritable + private static final Object [][] testColumnMetadata = { + {"fieldNull", "fieldNull", "String", Types.VARCHAR, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, + {"fieldInt", "fieldInt", "Integer", Types.INTEGER, 10, 2, "table", 1, "public", true, false, false, false, 0, false, true, true, true}, + {"fieldString", "fieldString", "String", Types.VARCHAR, 10, 2, "table", 1, "public", false, false, false, false, 0, false, true, false, true}, + {"fieldBoolean", "fieldBoolean", "Boolean", Types.BOOLEAN, 10, 2, "table", 1, "public", false, false, false, false, 0, false, true, false, true}, + {"fieldByte", "fieldByte", "Byte", Types.TINYINT, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, + {"fieldShort", "fieldShort", "Short", Types.SMALLINT, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, + {"fieldLong", "fieldLong", "Long", Types.BIGINT, 10, 2, "table", 1, "public", false, false, false, false, 1, false, true, false, false}, + {"fieldFloat", "fieldFloat", "Float", Types.REAL, 10, 2, "table", 1, "public", false, false, false, false, 0, true, true, false, false}, + {"fieldDouble", "fieldDouble", "Double", Types.DOUBLE, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, + {"fieldBigDecimal", "fieldBigDecimal", "BigDecimal", Types.DECIMAL, 10, 2, "table", 1, "public", false, false, false, false, 0, true, true, false, false}, + {"fieldDate", "fieldDate", 
"Date", Types.DATE, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, + {"fieldTime", "fieldTime", "Time", Types.TIME, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, + {"fieldDateTime", "fieldDateTime", "Timestamp", Types.TIMESTAMP, 10, 2, "table", 1, "public", false, false, false, false, 0, true, true, false, false}, + {"fieldSqlXml", "fieldSqlXml", "SqlXml", Types.SQLXML, 100, 1, "table", 1, "public", false, false, false, false, 0, true, true, false, false} + }; + + private static final Object [][] testColumnValues = { + {null, null}, + {1, 123456}, + {"John Doe", "Tony Stark"}, + {true, false}, + {(byte)100, (byte)70}, // Letter d and F in ASCII + {(short)55, (short)135}, + {2^33L, -2^35L}, + {3.14159f, -233.14159f}, + {2345.23345d, -2344355.4543d}, + {new BigDecimal("15.33"), new BigDecimal("-12.45")}, + {Date.valueOf("2025-03-15"), Date.valueOf("1102-01-15")}, + {Time.valueOf("22:54:00"), Time.valueOf("01:10:00")}, + {Timestamp.valueOf("2025-03-15 22:54:00"), Timestamp.valueOf("1950-01-18 21:50:05")}, + {new CachedSQLXML("A"), new CachedSQLXML("Value AValue B")} + }; + + private void mockGetMetadataFields(int column, int testMetadataCol) throws SQLException { + when(mockResultSetMetadata.getCatalogName(column)).thenReturn(""); + when(mockResultSetMetadata.getColumnClassName(column)).thenReturn("MyClass" + testMetadataCol); + when(mockResultSetMetadata.getColumnLabel(column)).thenReturn((String) testColumnMetadata[testMetadataCol][0]); + when(mockResultSetMetadata.getColumnName(column)).thenReturn((String) testColumnMetadata[testMetadataCol][1]); + when(mockResultSetMetadata.getColumnTypeName(column)).thenReturn((String) testColumnMetadata[testMetadataCol][2]); + when(mockResultSetMetadata.getColumnType(column)).thenReturn((Integer) testColumnMetadata[testMetadataCol][3]); + when(mockResultSetMetadata.getColumnDisplaySize(column)).thenReturn((Integer) testColumnMetadata[testMetadataCol][4]); + when(mockResultSetMetadata.getPrecision(column)).thenReturn((Integer) testColumnMetadata[testMetadataCol][5]); + when(mockResultSetMetadata.getTableName(column)).thenReturn((String) testColumnMetadata[testMetadataCol][6]); + when(mockResultSetMetadata.getScale(column)).thenReturn((Integer) testColumnMetadata[testMetadataCol][7]); + when(mockResultSetMetadata.getSchemaName(column)).thenReturn((String) testColumnMetadata[testMetadataCol][8]); + when(mockResultSetMetadata.isAutoIncrement(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][9]); + when(mockResultSetMetadata.isCaseSensitive(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][10]); + when(mockResultSetMetadata.isCurrency(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][11]); + when(mockResultSetMetadata.isDefinitelyWritable(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][12]); + when(mockResultSetMetadata.isNullable(column)).thenReturn((Integer) testColumnMetadata[testMetadataCol][13]); + when(mockResultSetMetadata.isReadOnly(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][14]); + when(mockResultSetMetadata.isSearchable(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][15]); + when(mockResultSetMetadata.isSigned(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][16]); + when(mockResultSetMetadata.isWritable(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][17]); + } + + @BeforeEach + void setUp() { + closeable = 
MockitoAnnotations.openMocks(this); + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")); + } + + @AfterEach + void cleanUp() { + TimeZone.setDefault(defaultTimeZone); + } + + void setUpDefaultTestResultSet() throws SQLException { + // Create the default CachedResultSet for testing + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(testColumnMetadata.length); + for (int i = 0; i < testColumnMetadata.length; i++) { + mockGetMetadataFields(1+i, i); + when(mockResultSet.getObject(1+i)).thenReturn(testColumnValues[i][0], testColumnValues[i][1]); + } + when(mockResultSet.next()).thenReturn(true, true, false); + testResultSet = new CachedResultSet(mockResultSet); + } + + private void verifyDefaultMetadata(ResultSet rs) throws SQLException { + ResultSetMetaData md = rs.getMetaData(); + for (int i = 0; i < md.getColumnCount(); i++) { + assertEquals("", md.getCatalogName(i+1)); + assertEquals("MyClass" + i, md.getColumnClassName(i+1)); + assertEquals(testColumnMetadata[i][0], md.getColumnLabel(i+1)); + assertEquals(testColumnMetadata[i][1], md.getColumnName(i+1)); + assertEquals(testColumnMetadata[i][2], md.getColumnTypeName(i+1)); + assertEquals(testColumnMetadata[i][3], md.getColumnType(i+1)); + assertEquals(testColumnMetadata[i][4], md.getColumnDisplaySize(i+1)); + assertEquals(testColumnMetadata[i][5], md.getPrecision(i+1)); + assertEquals(testColumnMetadata[i][6], md.getTableName(i+1)); + assertEquals(testColumnMetadata[i][7], md.getScale(i+1)); + assertEquals(testColumnMetadata[i][8], md.getSchemaName(i+1)); + assertEquals(testColumnMetadata[i][9], md.isAutoIncrement(i+1)); + assertEquals(testColumnMetadata[i][10], md.isCaseSensitive(i+1)); + assertEquals(testColumnMetadata[i][11], md.isCurrency(i+1)); + assertEquals(testColumnMetadata[i][12], md.isDefinitelyWritable(i+1)); + assertEquals(testColumnMetadata[i][13], md.isNullable(i+1)); + assertEquals(testColumnMetadata[i][14], md.isReadOnly(i+1)); + assertEquals(testColumnMetadata[i][15], md.isSearchable(i+1)); + assertEquals(testColumnMetadata[i][16], md.isSigned(i+1)); + assertEquals(testColumnMetadata[i][17], md.isWritable(i+1)); + } + } + + private void verifyDefaultRow(ResultSet rs, int row) throws SQLException { + assertFalse(rs.wasNull()); + assertNull(rs.getObject(1)); // fieldNull + assertEquals(1, rs.findColumn("fieldNull")); + assertTrue(rs.wasNull()); + assertEquals((int) testColumnValues[1][row], rs.getInt(2)); // fieldInt + assertFalse(rs.wasNull()); + assertEquals((int) testColumnValues[1][row], rs.getInt("fieldInt")); + assertEquals(2, rs.findColumn("fieldInt")); + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[2][row], rs.getString(3)); // fieldString + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[2][row], rs.getString("fieldString")); + assertEquals(3, rs.findColumn("fieldString")); + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[3][row], rs.getBoolean(4)); // fieldBoolean + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[3][row], rs.getBoolean("fieldBoolean")); + assertEquals(4, rs.findColumn("fieldBoolean")); + assertFalse(rs.wasNull()); + assertEquals((byte) testColumnValues[4][row], rs.getByte(5)); // fieldByte + assertFalse(rs.wasNull()); + assertEquals((byte) testColumnValues[4][row], rs.getByte("fieldByte")); + assertEquals(5, rs.findColumn("fieldByte")); + assertFalse(rs.wasNull()); + assertEquals((short) testColumnValues[5][row], rs.getShort(6)); // fieldShort + 
assertFalse(rs.wasNull()); + assertEquals((short) testColumnValues[5][row], rs.getShort("fieldShort")); + assertEquals(6, rs.findColumn("fieldShort")); + assertFalse(rs.wasNull()); + assertNull(rs.getObject("fieldNull")); + assertTrue(rs.wasNull()); + assertEquals((Long) testColumnValues[6][row], rs.getLong(7)); // fieldLong + assertFalse(rs.wasNull()); + assertEquals((Long) testColumnValues[6][row], rs.getLong("fieldLong")); + assertEquals(7, rs.findColumn("fieldLong")); + assertFalse(rs.wasNull()); + assertEquals((float) testColumnValues[7][row], rs.getFloat(8), 0); // fieldFloat + assertFalse(rs.wasNull()); + assertEquals((float) testColumnValues[7][row], rs.getFloat("fieldFloat"), 0); + assertEquals(8, rs.findColumn("fieldFloat")); + assertFalse(rs.wasNull()); + assertEquals((double) testColumnValues[8][row], rs.getDouble(9)); // fieldDouble + assertFalse(rs.wasNull()); + assertEquals((double) testColumnValues[8][row], rs.getDouble("fieldDouble")); + assertEquals(9, rs.findColumn("fieldDouble")); + assertFalse(rs.wasNull()); + assertEquals(0, rs.getBigDecimal(10).compareTo((BigDecimal) testColumnValues[9][row])); // fieldBigDecimal + assertFalse(rs.wasNull()); + assertEquals(0, rs.getBigDecimal("fieldBigDecimal").compareTo((BigDecimal) testColumnValues[9][row])); + assertEquals(10, rs.findColumn("fieldBigDecimal")); + assertFalse(rs.wasNull()); + assertNull(rs.getObject(1)); // fieldNull + assertTrue(rs.wasNull()); + assertEquals(testColumnValues[10][row], rs.getDate(11)); // fieldDate + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[10][row], rs.getDate("fieldDate")); + assertEquals(11, rs.findColumn("fieldDate")); + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[11][row], rs.getTime(12)); // fieldTime + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[11][row], rs.getTime("fieldTime")); + assertEquals(12, rs.findColumn("fieldTime")); + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[12][row], rs.getTimestamp(13)); // fieldDateTime + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[12][row], rs.getTimestamp("fieldDateTime")); + assertEquals(13, rs.findColumn("fieldDateTime")); + assertFalse(rs.wasNull()); + String sqlXmlString = ((SQLXML)testColumnValues[13][row]).getString(); + assertEquals(sqlXmlString, rs.getSQLXML(14).getString()); // fieldSqlXml + assertFalse(rs.wasNull()); + assertEquals(sqlXmlString, rs.getSQLXML("fieldSqlXml").getString()); + assertEquals(14, rs.findColumn("fieldSqlXml")); + assertFalse(rs.wasNull()); + verifyNonexistingField(rs); + } + + private void verifyNonexistingField(ResultSet rs) { + try { + rs.getObject("nonExistingField"); + throw new IllegalStateException("Expected an exception because the column doesn't exist"); + } catch (SQLException e) { + // Expected an exception if the column doesn't exist + } + try { + rs.findColumn("nonExistingField"); + throw new IllegalStateException("Expected an exception because the column doesn't exist"); + } catch (SQLException e) { + // Expected an exception if the column doesn't exist + } + } + + @Test + void test_basic_cached_result_set() throws Exception { + // Basic verification of the test result set + setUpDefaultTestResultSet(); + verifyDefaultMetadata(testResultSet); + assertEquals(0, testResultSet.getRow()); + assertTrue(testResultSet.next()); + assertEquals(1, testResultSet.getRow()); + verifyDefaultRow(testResultSet, 0); + assertTrue(testResultSet.next()); + assertEquals(2, testResultSet.getRow()); + verifyDefaultRow(testResultSet, 1); +
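// Only two rows were mocked, so the third next() returns false and getRow() reports 0 once the cursor is exhausted. +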
assertFalse(testResultSet.next()); + assertEquals(0, testResultSet.getRow()); + assertNull(testResultSet.getWarnings()); + testResultSet.clearWarnings(); + assertNull(testResultSet.getWarnings()); + testResultSet.beforeFirst(); + // Test serialization and de-serialization of the result set + byte[] serialized_data = testResultSet.serializeIntoByteArray(); + ResultSet rs = CachedResultSet.deserializeFromByteArray(serialized_data); + verifyDefaultMetadata(rs); + assertTrue(rs.next()); + verifyDefaultRow(rs, 0); + assertTrue(rs.next()); + verifyDefaultRow(rs, 1); + assertFalse(rs.next()); + assertNull(rs.getWarnings()); + rs.relative(-10); // We should be before the start of the rows + assertTrue(rs.isBeforeFirst()); + assertEquals(0, rs.getRow()); + rs.relative(10); // We should be after the end of the rows + assertTrue(rs.isAfterLast()); + assertEquals(0, rs.getRow()); + rs.absolute(-10); // We should be before the start of the rows + assertTrue(rs.isBeforeFirst()); + assertFalse(rs.absolute(100)); // Jump to after the end of the rows + assertTrue(rs.isAfterLast()); + assertEquals(0, rs.getRow()); + assertFalse(rs.absolute(0)); // Go to the beginning of rows + assertTrue(rs.isBeforeFirst()); + assertTrue(rs.next()); // We are at first row + verifyDefaultRow(rs, 0); + rs.relative(1); // Advances to next row + verifyDefaultRow(rs, 1); + assertTrue(rs.previous()); // Go back to first row + verifyDefaultRow(rs, 0); + assertFalse(rs.previous()); + assertTrue(rs.absolute(2)); // Jump to second row + verifyDefaultRow(rs, 1); + assertTrue(rs.first()); // go to first row + verifyDefaultRow(rs, 0); + assertEquals(1, rs.getRow()); + assertTrue(rs.last()); // go to last row + verifyDefaultRow(rs, 1); + assertEquals(2, rs.getRow()); + } + + @Test + void test_get_special_bigDecimal() throws SQLException { + // Create the default CachedResultSet for testing + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 9); + when(mockResultSet.getObject(1)).thenReturn( + 12450.567, + -132.45, + "142.346", + "invalid", + null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, false); + CachedResultSet rs = new CachedResultSet(mockResultSet); + + assertTrue(rs.next()); + assertEquals(0, rs.getBigDecimal(1).compareTo(new BigDecimal("12450.567"))); + + assertTrue(rs.next()); + assertEquals(0, rs.getBigDecimal(1).compareTo(new BigDecimal("-132.45"))); + assertTrue(rs.next()); + assertEquals(0, rs.getBigDecimal(1).compareTo(new BigDecimal("142.346"))); + assertTrue(rs.next()); + try { + rs.getBigDecimal(1); + fail("Invalid value should cause a test failure"); + } catch (IllegalArgumentException e) { + // pass + } + // Value is null + assertTrue(rs.next()); + assertNull(rs.getBigDecimal(1)); + } + + @Test + void test_get_special_timestamp() throws SQLException { + // Create the default CachedResultSet for testing + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 12); + when(mockResultSet.getObject(1)).thenReturn( + 1504844311000L, + LocalDateTime.of(1981, 3, 10, 1, 10, 20), + OffsetDateTime.parse("2025-08-10T10:00:00+03:00"), + ZonedDateTime.parse("2024-07-30T10:00:00+02:00[Europe/Berlin]"), + "2015-03-15 12:50:04", + "invalidDateTime", + null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, true, true, false); + CachedResultSet cachedRs = new 
CachedResultSet(mockResultSet); + + // Timestamp from a number + assertTrue(cachedRs.next()); + assertEquals(new Timestamp(1504844311000L), cachedRs.getTimestamp(1)); + // Timestamp from LocalDateTime + assertTrue(cachedRs.next()); + assertEquals(Timestamp.valueOf("1981-03-10 01:10:20"), cachedRs.getTimestamp(1)); + assertEquals(Timestamp.valueOf("1981-03-09 22:10:20"), cachedRs.getTimestamp(1, estCal)); + // Timestamp from OffsetDateTime (containing time zone info) + assertTrue(cachedRs.next()); + assertEquals(Timestamp.valueOf("2025-08-10 00:00:00"), cachedRs.getTimestamp(1)); + assertEquals(Timestamp.valueOf("2025-08-10 00:00:00"), cachedRs.getTimestamp(1, estCal)); + // Timestamp from ZonedDateTime (containing time zone info) + assertTrue(cachedRs.next()); + assertEquals(Timestamp.valueOf("2024-07-30 01:00:00"), cachedRs.getTimestamp(1)); + assertEquals(Timestamp.valueOf("2024-07-30 01:00:00"), cachedRs.getTimestamp(1, estCal)); + // Timestamp from String + assertTrue(cachedRs.next()); + assertEquals(Timestamp.valueOf("2015-03-15 12:50:04"), cachedRs.getTimestamp(1)); + assertEquals(Timestamp.valueOf("2015-03-15 12:50:04"), cachedRs.getTimestamp(1, estCal)); + assertTrue(cachedRs.next()); + try { + cachedRs.getTimestamp(1); + fail("Invalid timestamp should cause a test failure"); + } catch (IllegalArgumentException e) { + // pass + } + // Timestamp is null + assertTrue(cachedRs.next()); + assertNull(cachedRs.getTimestamp(1)); + } + + @Test + void test_get_special_time() throws SQLException { + // Create the default CachedResultSet for testing + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 11); + when(mockResultSet.getObject(1)).thenReturn( + 4362000L, + LocalTime.of(10, 20, 30), + OffsetTime.of(12, 15, 30, 0, ZoneOffset.UTC), + new Timestamp(1755621000000L), // Date and time (GMT): Tuesday, August 19, 2025 4:30:00 PM + new Timestamp(1735713000000L), // Date and time (GMT): Wednesday, January 1, 2025 6:30:00 AM + new Timestamp(0L), // 1970-01-01 00:00:00 UTC (epoch) + Timestamp.valueOf(LocalDateTime.now().plusYears(1).withHour(9).withMinute(30).withSecond(0).withNano(0)), // Future Date: next year same date at 9:30 AM + "15:34:20", + "InvalidTime", + null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, true, + true, true, true, true, false); + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Time from a number + assertTrue(cachedRs.next()); + assertEquals(new Time(4362000L), cachedRs.getTime(1)); + // Time from LocalTime + assertTrue(cachedRs.next()); + assertEquals(Time.valueOf("10:20:30"), cachedRs.getTime(1)); + assertEquals(Time.valueOf("07:20:30"), cachedRs.getTime(1, estCal)); + // Time from OffsetTime + assertTrue(cachedRs.next()); + // OffsetTime.of(12, 15, 30, 0, ZoneOffset.UTC) converted to default timezone + OffsetTime offsetTimeUtc = OffsetTime.of(12, 15, 30, 0, ZoneOffset.UTC); + ZonedDateTime expectedDefaultTz = offsetTimeUtc.atDate(LocalDate.now()) + .atZoneSameInstant(defaultTimeZone.toZoneId()); + assertEquals(Time.valueOf(expectedDefaultTz.toLocalTime()), cachedRs.getTime(1)); + // OffsetTime converted to EST timezone + ZonedDateTime expectedEstTz = offsetTimeUtc.atDate(LocalDate.now()) + .atZoneSameInstant(estCal.getTimeZone().toZoneId()); + assertEquals(Time.valueOf(expectedEstTz.toLocalTime()), cachedRs.getTime(1, estCal)); + // Time from Timestamp + assertTrue(cachedRs.next()); +
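// 1755621000000L is 2025-08-19 16:30:00 UTC, i.e. 09:30 wall-clock in the test default zone America/Los_Angeles set in setUp(). +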
Timestamp timestampOne = new Timestamp(1755621000000L); + // Compare underlying millis + assertEquals(timestampOne.getTime(), cachedRs.getTime(1).getTime()); + // Compare logical wall-clock time + assertEquals(LocalTime.of(9, 30, 0), cachedRs.getTime(1).toLocalTime()); + assertEquals(LocalTime.of(6, 30, 0), cachedRs.getTime(1, estCal).toLocalTime()); + // Time from Timestamp Edge Case + assertTrue(cachedRs.next()); + Timestamp timestampTwo = new Timestamp(1735713000000L); + assertEquals(timestampTwo.getTime(), cachedRs.getTime(1).getTime()); + assertEquals(LocalTime.of(22, 30, 0), cachedRs.getTime(1).toLocalTime()); + assertEquals(LocalTime.of(19, 30, 0), cachedRs.getTime(1, estCal).toLocalTime()); + // Epoch time of 0 + assertTrue(cachedRs.next()); + assertEquals(new Time(0), cachedRs.getTime(1)); + assertEquals(0L, cachedRs.getTime(1).getTime()); + assertEquals(LocalTime.of(16, 0, 0), cachedRs.getTime(1).toLocalTime()); + assertEquals(LocalTime.of(13, 0, 0), cachedRs.getTime(1, estCal).toLocalTime()); + // Future date + assertTrue(cachedRs.next()); + Timestamp futureTimestamp = new Timestamp(Timestamp.valueOf(LocalDateTime.now().plusYears(1).withHour(9).withMinute(30).withSecond(0).withNano(0)).getTime()); + assertEquals(futureTimestamp.getTime(), cachedRs.getTime(1).getTime()); + assertEquals(LocalTime.of(9, 30, 0), cachedRs.getTime(1).toLocalTime()); + assertEquals(LocalTime.of(6, 30, 0), cachedRs.getTime(1, estCal).toLocalTime()); + // Time from String + assertTrue(cachedRs.next()); + assertEquals(Time.valueOf("15:34:20"), cachedRs.getTime(1)); + assertEquals(Time.valueOf("15:34:20"), cachedRs.getTime(1, estCal)); + assertTrue(cachedRs.next()); + try { + cachedRs.getTime(1); + fail("Invalid time should cause a test failure"); + } catch (IllegalArgumentException e) { + // pass + } + // Time is null + assertTrue(cachedRs.next()); + assertNull(cachedRs.getTime(1)); + } + + @Test + void test_get_special_date() throws SQLException { + // Create the default CachedResultSet for testing + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 10); + when(mockResultSet.getObject(1)).thenReturn( + 1515944311000L, + -1000000000L, + LocalDate.of(2010, 10, 30), + new Timestamp(1755621000000L), // Date and time (GMT): Tuesday, August 19, 2025 4:30:00 PM + new Timestamp(1735713000000L), // Date and time (GMT): Wednesday, January 1, 2025 6:30:00 AM + new Timestamp(1755673200000L), // Date and time (GMT): Wednesday, August 20, 2025 7:00:00 AM --> PDT Aug 20 12AM + new Timestamp(1735718400000L), // Date and time (GMT): Wednesday, January 1, 2025 8:00:00 AM --> PST Jan 1 12AM + new Timestamp(0L), // 1970-01-01 00:00:00 UTC (epoch) + "2025-03-15", + "InvalidDate", + null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, true, true, + true, true, true, true, false); + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Date from a number + assertTrue(cachedRs.next()); + Date date = cachedRs.getDate(1); + assertEquals(new Date(1515944311000L), date); + assertTrue(cachedRs.next()); + assertEquals(new Date(-1000000000L), cachedRs.getDate(1)); + + // Date from LocalDate + assertTrue(cachedRs.next()); + assertEquals(Date.valueOf("2010-10-30"), cachedRs.getDate(1)); + assertEquals(Date.valueOf("2010-10-29"), cachedRs.getDate(1, estCal)); + // Date from Timestamp + assertTrue(cachedRs.next()); + Timestamp tsForDate1 = new Timestamp(1755621000000L); + assertEquals(new 
Date(tsForDate1.getTime()), cachedRs.getDate(1)); + assertEquals(LocalDate.of(2025, 8, 19), cachedRs.getDate(1).toLocalDate()); + assertEquals(LocalDate.of(2025, 8, 19), cachedRs.getDate(1, estCal).toLocalDate()); + assertTrue(cachedRs.next()); + Timestamp tsForDate2 = new Timestamp(1735713000000L); + assertEquals(new Date(tsForDate2.getTime()), cachedRs.getDate(1)); + assertEquals(LocalDate.of(2024, 12, 31), cachedRs.getDate(1).toLocalDate()); + assertEquals(LocalDate.of(2024, 12, 31), cachedRs.getDate(1, estCal).toLocalDate()); + // Date from Timestamp Edge Case + assertTrue(cachedRs.next()); + Timestamp tsForDate3 = new Timestamp(1755673200000L); + assertEquals(new Date(tsForDate3.getTime()), cachedRs.getDate(1)); + assertEquals(LocalDate.of(2025,8,20), cachedRs.getDate(1).toLocalDate()); + assertEquals(LocalDate.of(2025,8,19), cachedRs.getDate(1, estCal).toLocalDate()); + assertTrue(cachedRs.next()); + Timestamp tsForDate4 = new Timestamp(1735718400000L); + assertEquals(new Date(tsForDate4.getTime()), cachedRs.getDate(1)); + assertEquals(LocalDate.of(2025,1,1), cachedRs.getDate(1).toLocalDate()); + assertEquals(LocalDate.of(2024,12,31), cachedRs.getDate(1, estCal).toLocalDate()); + assertTrue(cachedRs.next()); + Timestamp tsForDate5 = new Timestamp(0L); + assertEquals(new Date(tsForDate5.getTime()), cachedRs.getDate(1)); + assertEquals(new Date(0L), cachedRs.getDate(1)); + assertEquals(LocalDate.of(1969,12,31), cachedRs.getDate(1).toLocalDate()); + assertEquals(LocalDate.of(1969,12,31), cachedRs.getDate(1, estCal).toLocalDate()); + // Date from String + assertTrue(cachedRs.next()); + assertEquals(Date.valueOf("2025-03-15"), cachedRs.getDate(1)); + assertEquals(Date.valueOf("2025-03-15"), cachedRs.getDate(1, estCal)); + assertTrue(cachedRs.next()); + try { + cachedRs.getDate(1); + fail("Invalid date should cause a test failure"); + } catch (IllegalArgumentException e) { + // pass + } + // Date is null + assertTrue(cachedRs.next()); + assertNull(cachedRs.getDate(1)); + } + + @Test + void test_get_nstring() throws SQLException { + // Setup single column with String metadata + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 2); + when(mockResultSet.getObject(1)).thenReturn("test string", 123, null); + when(mockResultSet.next()).thenReturn(true, true, true, false); + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test string value - both index and label versions + assertTrue(cachedRs.next()); + assertEquals("test string", cachedRs.getNString(1)); + assertFalse(cachedRs.wasNull()); + assertEquals("test string", cachedRs.getNString("fieldString")); + assertFalse(cachedRs.wasNull()); + + // Test number conversion + assertTrue(cachedRs.next()); + assertEquals("123", cachedRs.getNString(1)); + assertFalse(cachedRs.wasNull()); + + // Test null handling + assertTrue(cachedRs.next()); + assertNull(cachedRs.getNString(1)); + assertTrue(cachedRs.wasNull()); + } + + @Test + void test_get_bytes() throws SQLException { + // Setup single column with byte array metadata + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 4); + // Test data + byte[] testBytes = {1, 2, 3, 4, 5}; + when(mockResultSet.getObject(1)).thenReturn(testBytes, "not bytes", 123, null); + when(mockResultSet.next()).thenReturn(true, true, true, true, false); + CachedResultSet cachedRs = new 
CachedResultSet(mockResultSet); + + // Test bytes values - both index and label versions + assertTrue(cachedRs.next()); + assertArrayEquals(testBytes, cachedRs.getBytes(1)); + assertFalse(cachedRs.wasNull()); + assertArrayEquals(testBytes, cachedRs.getBytes("fieldByte")); + assertFalse(cachedRs.wasNull()); + + // Test non-byte array input (should convert to bytes) + assertTrue(cachedRs.next()); + assertArrayEquals("not bytes".getBytes(), cachedRs.getBytes(1)); + assertFalse(cachedRs.wasNull()); + + // Test number input (should convert to bytes) + assertTrue(cachedRs.next()); + assertArrayEquals("123".getBytes(), cachedRs.getBytes(1)); + assertFalse(cachedRs.wasNull()); + + // Test null handling + assertTrue(cachedRs.next()); + assertNull(cachedRs.getBytes(1)); + assertTrue(cachedRs.wasNull()); + } + + @Test + void test_get_boolean() throws SQLException { + // Setup single column with boolean metadata + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 3); + // Test data: boolean, numbers, strings, null + when(mockResultSet.getObject(1)).thenReturn( + true, false, 0, 1, -5, "true", "false", "invalid", null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, true, true, true, true, false); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test actual boolean values - both index and label versions + assertTrue(cachedRs.next()); + assertTrue(cachedRs.getBoolean(1)); + assertFalse(cachedRs.wasNull()); + assertTrue(cachedRs.getBoolean("fieldBoolean")); + + assertTrue(cachedRs.next()); + assertFalse(cachedRs.getBoolean(1)); + assertFalse(cachedRs.wasNull()); + + // Test number conversions: 0 = false, non-zero = true + assertTrue(cachedRs.next()); + assertFalse(cachedRs.getBoolean(1)); // 0 → false + + assertTrue(cachedRs.next()); + assertTrue(cachedRs.getBoolean(1)); // 1 → true + + assertTrue(cachedRs.next()); + assertTrue(cachedRs.getBoolean(1)); // -5 → true + + // Test string conversions + assertTrue(cachedRs.next()); + assertTrue(cachedRs.getBoolean(1)); // "true" → true + + assertTrue(cachedRs.next()); + assertFalse(cachedRs.getBoolean(1)); // "false" → false + + assertTrue(cachedRs.next()); + assertFalse(cachedRs.getBoolean(1)); // "invalid" → false (parseBoolean) + + // Test null handling + assertTrue(cachedRs.next()); + assertFalse(cachedRs.getBoolean(1)); // null → false + assertTrue(cachedRs.wasNull()); + } + + @Test + void test_get_URL() throws SQLException { + // Setup single column with string metadata (URLs stored as strings) + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 2); + + // Test data: URL object, valid URL string, invalid URL string, null + // URL object setup + URL testUrl = null; + try { + testUrl = new URL("https://example.com"); + } catch (MalformedURLException e) { + fail("Test setup failed"); + } + + when(mockResultSet.getObject(1)).thenReturn( + testUrl, "https://valid.com", "invalid-url", null); + when(mockResultSet.next()).thenReturn(true, true, true, true, false); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test actual URL object - both index and label versions + assertTrue(cachedRs.next()); + assertEquals(testUrl, cachedRs.getURL(1)); + assertFalse(cachedRs.wasNull()); + assertEquals(testUrl, cachedRs.getURL("fieldString")); + + // Test valid URL string conversion + 
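// (a stored String value is converted to java.net.URL on read; malformed strings surface as SQLException below) + 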
assertTrue(cachedRs.next()); + URL validURL = null; + try { + validURL = new URL("https://valid.com"); + } catch (MalformedURLException e) { + fail("Failed setting up new valid URL"); + } + assertEquals(validURL, cachedRs.getURL(1)); + assertFalse(cachedRs.wasNull()); + + // Test invalid URL string (should throw SQLException) + assertTrue(cachedRs.next()); + assertThrows(SQLException.class, () -> cachedRs.getURL(1)); + + // Test null handling + assertTrue(cachedRs.next()); + assertNull(cachedRs.getURL(1)); + assertTrue(cachedRs.wasNull()); + } + + @Test + void test_get_sql_xml() throws SQLException { + String longXml = + "<product>\n" + + " <manufacturer>TechCorp</manufacturer>\n" + + " <specs>\n" + + " <cpu>Intel i7</cpu>\n" + + " <ram>16GB</ram>\n" + + " <storage>512GB SSD</storage>\n" + + " </specs>\n" + + " <price>1200.00</price>\n" + + "</product>"; + SQLXML testXml = new CachedSQLXML("<book><title>PostgreSQL Guide</title><author>John Doe</author></book>"); + SQLXML testXml2 = new CachedSQLXML(longXml); + SQLXML invalidXml = new CachedSQLXML("<a>A"); + // Setup single column with SQLXML metadata (XML stored as strings) + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 13); + when(mockResultSet.getObject(1)).thenReturn(testXml, testXml2, invalidXml, "invalid-xml", null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, false); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test actual SQLXML objects - both index and label versions + assertTrue(cachedRs.next()); + assertEquals(testXml.getString(), cachedRs.getSQLXML(1).getString()); + assertFalse(cachedRs.wasNull()); + assertEquals(testXml.getString(), cachedRs.getSQLXML("fieldSqlXml").getString()); + + assertTrue(cachedRs.next()); + assertEquals(testXml2.getString(), cachedRs.getSQLXML(1).getString()); + assertFalse(cachedRs.wasNull()); + assertEquals(testXml2.getString(), cachedRs.getSQLXML("fieldSqlXml").getString()); + + assertTrue(cachedRs.next()); + assertEquals(invalidXml.getString(), cachedRs.getSQLXML(1).getString()); + assertFalse(cachedRs.wasNull()); + assertEquals(invalidXml.getString(), cachedRs.getSQLXML("fieldSqlXml").getString()); + + assertTrue(cachedRs.next()); + assertEquals("invalid-xml", cachedRs.getSQLXML(1).getString()); + assertEquals("invalid-xml", cachedRs.getSQLXML("fieldSqlXml").getString()); + assertFalse(cachedRs.wasNull()); + + assertTrue(cachedRs.next()); + assertNull(cachedRs.getSQLXML(1)); + assertTrue(cachedRs.wasNull()); + assertNull(cachedRs.getSQLXML("fieldSqlXml")); + assertTrue(cachedRs.wasNull()); + + assertFalse(cachedRs.next()); + } + + + @Test + void test_get_object_with_index_and_type() throws SQLException { + // Setup single column with string metadata (mixed data types) + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 2); + + // Test data: string, integer, boolean, null + when(mockResultSet.getObject(1)).thenReturn("test", 123, true, null); + when(mockResultSet.next()).thenReturn(true, true, true, true, false); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test valid type conversions + assertTrue(cachedRs.next()); + assertEquals("test", cachedRs.getObject(1, String.class)); + assertFalse(cachedRs.wasNull()); + + assertTrue(cachedRs.next()); + assertEquals(Integer.valueOf(123), cachedRs.getObject(1, Integer.class)); + assertFalse(cachedRs.wasNull()); + + assertTrue(cachedRs.next()); + assertEquals(Boolean.TRUE, cachedRs.getObject(1, Boolean.class)); + 
assertFalse(cachedRs.wasNull()); + + // Test null handling + assertTrue(cachedRs.next()); + assertNull(cachedRs.getObject(1, String.class)); + assertTrue(cachedRs.wasNull()); + + // Test invalid type conversion (should throw ClassCastException) + cachedRs.beforeFirst(); + // Cursor is reset, so next() lands on the first row again + assertTrue(cachedRs.next()); + assertThrows(ClassCastException.class, () -> cachedRs.getObject(1, Integer.class)); + } + + @Test + void test_get_object_with_label_and_type() throws SQLException { + // Setup single column with string metadata (mixed data types) + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 2); + + // Test data: string, integer, boolean, HashSet (unsupported type), null + HashSet<String> testSet = new HashSet<>(); + testSet.add("item1"); + testSet.add("item2"); + + when(mockResultSet.getObject(1)).thenReturn("test", 123, true, testSet, null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, false); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test valid type conversions + assertTrue(cachedRs.next()); + assertEquals("test", cachedRs.getObject("fieldString", String.class)); + assertFalse(cachedRs.wasNull()); + + assertTrue(cachedRs.next()); + assertEquals(Integer.valueOf(123), cachedRs.getObject("fieldString", Integer.class)); + assertFalse(cachedRs.wasNull()); + + assertTrue(cachedRs.next()); + assertEquals(Boolean.TRUE, cachedRs.getObject("fieldString", Boolean.class)); + assertFalse(cachedRs.wasNull()); + + // Test unsupported data type (HashSet) - should work with getObject() + assertTrue(cachedRs.next()); + HashSet<String> retrievedSet = cachedRs.getObject("fieldString", HashSet.class); + assertEquals(testSet, retrievedSet); + assertFalse(cachedRs.wasNull()); + + // Test null handling + assertTrue(cachedRs.next()); + assertNull(cachedRs.getObject("fieldString", String.class)); + assertTrue(cachedRs.wasNull()); + + // Test invalid type conversion (should throw ClassCastException) + cachedRs.beforeFirst(); + // Cursor is reset, so next() lands on the first row again + assertTrue(cachedRs.next()); + assertThrows(ClassCastException.class, () -> cachedRs.getObject(1, Integer.class)); + } + + @Test + void test_unwrap() throws SQLException { + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 2); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test valid unwrap to ResultSet interface + ResultSet unwrappedResultSet = cachedRs.unwrap(ResultSet.class); + assertSame(cachedRs, unwrappedResultSet); + + // Test valid unwrap to CachedResultSet class + CachedResultSet unwrappedCachedResultSet = cachedRs.unwrap(CachedResultSet.class); + assertSame(cachedRs, unwrappedCachedResultSet); + + // Test invalid unwrap attempts should throw SQLException + assertThrows(SQLException.class, () -> cachedRs.unwrap(String.class)); + assertThrows(SQLException.class, () -> cachedRs.unwrap(Integer.class)); + } + + @Test + void test_is_wrapper_for() throws SQLException { + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 2); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test valid wrapper checks + assertTrue(cachedRs.isWrapperFor(ResultSet.class)); + assertTrue(cachedRs.isWrapperFor(CachedResultSet.class)); + + // Test invalid wrapper checks + 
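// (unlike unwrap(), which throws SQLException for these, isWrapperFor() just returns false) + 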
assertFalse(cachedRs.isWrapperFor(String.class)); + assertFalse(cachedRs.isWrapperFor(Integer.class)); + + // Test null class parameter + assertFalse(cachedRs.isWrapperFor(null)); + } +} + diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedSQLXMLTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedSQLXMLTest.java new file mode 100644 index 000000000..7340ac1b0 --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedSQLXMLTest.java @@ -0,0 +1,174 @@ +package software.amazon.jdbc.plugin.cache; + +import org.junit.jupiter.api.Test; +import org.w3c.dom.*; +import org.xml.sax.Attributes; +import org.xml.sax.InputSource; +import org.xml.sax.XMLReader; +import org.xml.sax.helpers.DefaultHandler; +import java.io.InputStream; +import java.io.Reader; +import java.sql.SQLException; +import java.sql.SQLXML; + +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.stream.XMLStreamReader; +import javax.xml.transform.Source; +import javax.xml.transform.dom.DOMSource; +import javax.xml.transform.sax.SAXSource; +import javax.xml.transform.stax.StAXSource; +import javax.xml.transform.stream.StreamSource; + +import static org.junit.jupiter.api.Assertions.*; + +public class CachedSQLXMLTest { + + @Test + void test_basic_XML() throws Exception { + String xml = "<root><a>Value A</a><b>Value B</b></root>"; + SQLXML sqlxml = new CachedSQLXML(xml); + assertEquals(xml, sqlxml.getString()); + + // Test binary stream + byte[] array = new byte[100]; + InputStream stream = sqlxml.getBinaryStream(); + assertEquals(xml.length(), stream.available()); + assertTrue(stream.read(array) > 0); + assertEquals(xml, new String(array, 0, xml.length())); + stream.close(); + + // Test character stream + char[] chars = new char[100]; + Reader reader = sqlxml.getCharacterStream(); + assertTrue(reader.read(chars) > 0); + assertEquals(xml, new String(chars, 0, xml.length())); + reader.close(); + + // Test free() + sqlxml.free(); + assertThrows(SQLException.class, sqlxml::getString); + assertThrows(SQLException.class, sqlxml::getCharacterStream); + assertThrows(SQLException.class, sqlxml::getBinaryStream); + assertThrows(SQLException.class, () -> sqlxml.getSource(DOMSource.class)); + } + + private void validateDOMElement(Document document, String elementName, String elementValue) { + NodeList elements = document.getElementsByTagName(elementName); + assertEquals(1, elements.getLength()); + Element element = (Element) elements.item(0); + assertEquals(elementName, element.getNodeName()); + assertEquals(elementValue, element.getTextContent()); + } + + private void validateSimpleDocument(Document document) { + Element rootElement = document.getDocumentElement(); + assertEquals("product", rootElement.getNodeName()); + NodeList elements = document.getElementsByTagName("product"); + assertEquals(1, elements.getLength()); // exactly one product element + elements = document.getElementsByTagName("specs"); + assertEquals(1, elements.getLength()); // exactly one specs element + validateDOMElement(document, "manufacturer", "TechCorp"); + validateDOMElement(document, "cpu", "Intel i7"); + validateDOMElement(document, "ram", "16GB"); + validateDOMElement(document, "storage", "512GB SSD"); + validateDOMElement(document, "price", "1200.00"); + } + + static private void validateDocElements(String name, String value) { + if (name.equalsIgnoreCase("manufacturer")) { + assertEquals("TechCorp", value); + } else if (name.equalsIgnoreCase("cpu")) { + assertEquals("Intel 
i7", value); + } else if (name.equalsIgnoreCase("ram")) { + assertEquals("16GB", value); + } else if (name.equalsIgnoreCase("storage")) { + assertEquals("512GB SSD", value); + } else if (name.equalsIgnoreCase("price")) { + assertEquals("1200.00", value); + } + } + + static private class XmlReaderContentHandler extends DefaultHandler { + private StringBuilder currentValue; + + @Override + public void startElement(String uri, String localName, String qName, Attributes attributes) { + currentValue = new StringBuilder(); // Reset for each new element + } + + @Override + public void endElement(String uri, String localName, String qName) { + // Verify the element's value + String value = currentValue.toString().trim(); + validateDocElements(qName, value); + } + + @Override + public void characters(char[] ch, int start, int length) { + currentValue.append(ch, start, length); + } + } + + @Test + void test_getSource_XML() throws Exception { + // Test parsing a more complex XML via getSource() + String xml = " \n" + + "\n" + + " TechCorp\n\n" + + "\n" + + " Intel i7\n" + + " 16GB\n" + + " 512GB SSD\n" + + "\n" + + " 1200.00\n" + + "\n"; + SQLXML sqlxml = new CachedSQLXML(xml); + assertEquals(xml, sqlxml.getString()); + + // DOM source + DOMSource domSource = sqlxml.getSource(null); + Node node = domSource.getNode(); + assertEquals(Node.DOCUMENT_NODE, node.getNodeType()); + validateSimpleDocument((Document) node); + domSource = sqlxml.getSource(DOMSource.class); + node = domSource.getNode(); + assertEquals(Node.DOCUMENT_NODE, node.getNodeType()); + validateSimpleDocument((Document) node); + + // SAX source + SAXSource src = sqlxml.getSource(SAXSource.class); + XMLReader xmlReader = src.getXMLReader(); + xmlReader.setContentHandler(new XmlReaderContentHandler()); + xmlReader.parse(src.getInputSource()); + + // Streams source + StreamSource xmlSource = sqlxml.getSource(StreamSource.class); + DocumentBuilder db = DocumentBuilderFactory.newInstance().newDocumentBuilder(); + Document doc = db.parse(new InputSource(xmlSource.getReader())); + doc.getDocumentElement().normalize(); + validateSimpleDocument(doc); + + // StAX Source + StAXSource staxSource = sqlxml.getSource(StAXSource.class); + XMLStreamReader sReader = staxSource.getXMLStreamReader(); + String elementName = ""; + StringBuilder elementValue = new StringBuilder(); + while (sReader.hasNext()) { + int event = sReader.next(); + if (event == XMLStreamReader.START_ELEMENT) { + elementName = sReader.getLocalName(); + } else if (event == XMLStreamReader.CHARACTERS) { + elementValue.append(sReader.getText()); + } else if (event == XMLStreamReader.END_ELEMENT) { + validateDocElements(elementName, elementValue.toString().trim()); + elementName = ""; + elementValue = new StringBuilder(); + } + } + sReader.close(); // Close the reader when done + + // Invalid source class + assertThrows(SQLException.class, () -> sqlxml.getSource(Source.class)); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/DataCacheConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPluginTest.java similarity index 89% rename from wrapper/src/test/java/software/amazon/jdbc/plugin/DataCacheConnectionPluginTest.java rename to wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPluginTest.java index 46e27337f..cd307d2bf 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/DataCacheConnectionPluginTest.java +++ 
b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPluginTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package software.amazon.jdbc.plugin; +package software.amazon.jdbc.plugin.cache; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.anyString; @@ -36,7 +36,7 @@ import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; -class DataCacheConnectionPluginTest { +class DataLocalCacheConnectionPluginTest { private static final Properties props = new Properties(); @@ -55,8 +55,8 @@ class DataCacheConnectionPluginTest { @BeforeEach void setUp() throws SQLException { closeable = MockitoAnnotations.openMocks(this); - props.setProperty(DataCacheConnectionPlugin.DATA_CACHE_TRIGGER_CONDITION.name, "foo"); - DataCacheConnectionPlugin.clearCache(); + props.setProperty(DataLocalCacheConnectionPlugin.DATA_CACHE_TRIGGER_CONDITION.name, "foo"); + DataLocalCacheConnectionPlugin.clearCache(); when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); @@ -82,7 +82,7 @@ void cleanUp() throws Exception { void test_execute_withEmptyCache() throws SQLException { final String methodName = "Statement.executeQuery"; - final DataCacheConnectionPlugin plugin = new DataCacheConnectionPlugin(mockPluginService, props); + final DataLocalCacheConnectionPlugin plugin = new DataLocalCacheConnectionPlugin(mockPluginService, props); final ResultSet rs = plugin.execute( ResultSet.class, @@ -99,7 +99,7 @@ void test_execute_withEmptyCache() throws SQLException { void test_execute_withCache() throws Exception { final String methodName = "Statement.executeQuery"; - final DataCacheConnectionPlugin plugin = new DataCacheConnectionPlugin(mockPluginService, props); + final DataLocalCacheConnectionPlugin plugin = new DataLocalCacheConnectionPlugin(mockPluginService, props); when(mockCallable.call()).thenReturn(mockResult1, mockResult2); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java new file mode 100644 index 000000000..54e3802a2 --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java @@ -0,0 +1,665 @@ +package software.amazon.jdbc.plugin.cache; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.*; +import java.util.Optional; +import java.util.Properties; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.states.SessionStateService; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import 
software.amazon.jdbc.util.telemetry.TelemetryFactory; + import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; + + public class DataRemoteCachePluginTest { + private Properties props; + private final String methodName = "Statement.executeQuery"; + private AutoCloseable closeable; + + private DataRemoteCachePlugin plugin; + @Mock PluginService mockPluginService; + @Mock TelemetryFactory mockTelemetryFactory; + @Mock TelemetryCounter mockCacheHitCounter; + @Mock TelemetryCounter mockCacheMissCounter; + @Mock TelemetryCounter mockTotalQueryCounter; + @Mock TelemetryCounter mockMalformedHintCounter; + @Mock TelemetryCounter mockCacheBypassCounter; + @Mock TelemetryContext mockTelemetryContext; + @Mock ResultSet mockResult1; + @Mock Statement mockStatement; + @Mock PreparedStatement mockPreparedStatement; + @Mock ResultSetMetaData mockMetaData; + @Mock Connection mockConnection; + @Mock SessionStateService mockSessionStateService; + @Mock DatabaseMetaData mockDbMetadata; + @Mock CacheConnection mockCacheConn; + @Mock JdbcCallable<ResultSet, SQLException> mockCallable; + + @BeforeEach + void setUp() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + props = new Properties(); + props.setProperty("wrapperPlugins", "dataRemoteCache"); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.createCounter("JdbcCachedQueryCount")).thenReturn(mockCacheHitCounter); + when(mockTelemetryFactory.createCounter("JdbcCacheMissCount")).thenReturn(mockCacheMissCounter); + when(mockTelemetryFactory.createCounter("JdbcCacheTotalQueryCount")).thenReturn(mockTotalQueryCounter); + when(mockTelemetryFactory.createCounter("JdbcCacheMalformedQueryHint")).thenReturn(mockMalformedHintCounter); + when(mockTelemetryFactory.createCounter("JdbcCacheBypassCount")).thenReturn(mockCacheBypassCounter); + when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); + when(mockResult1.getMetaData()).thenReturn(mockMetaData); + when(mockMetaData.getColumnCount()).thenReturn(1); + when(mockMetaData.getColumnLabel(1)).thenReturn("fooName"); + } + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + } + + @Test + void test_getTTLFromQueryHint() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Null and empty query hint content are not cacheable + assertNull(plugin.getTtlForQuery(null)); + assertNull(plugin.getTtlForQuery("")); + assertNull(plugin.getTtlForQuery(" ")); + // Valid CACHE_PARAM cases - these are the hint contents after /*+ and before */ + assertEquals(300, plugin.getTtlForQuery("CACHE_PARAM(ttl=300s)")); + assertEquals(100, plugin.getTtlForQuery("CACHE_PARAM(ttl=100s)")); + assertEquals(35, plugin.getTtlForQuery("CACHE_PARAM(ttl=35s)")); + + // Case insensitive + assertEquals(200, plugin.getTtlForQuery("cache_param(ttl=200s)")); + assertEquals(150, plugin.getTtlForQuery("Cache_Param(ttl=150s)")); + assertEquals(200, plugin.getTtlForQuery("cache_param(tTl=200s)")); + assertEquals(150, plugin.getTtlForQuery("Cache_Param(ttl=150S)")); + assertEquals(200, plugin.getTtlForQuery("cache_param(TTL=200S)")); + + // CACHE_PARAM anywhere in hint content (mixed with other hint directives) + assertEquals(250, plugin.getTtlForQuery("INDEX(table1 idx1) CACHE_PARAM(ttl=250s)")); + assertEquals(200, plugin.getTtlForQuery("CACHE_PARAM(ttl=200s) USE_NL(t1 t2)")); + assertEquals(180, 
plugin.getTtlForQuery("FIRST_ROWS(10) CACHE_PARAM(ttl=180s) PARALLEL(4)")); + assertEquals(200, plugin.getTtlForQuery("foo=bar,CACHE_PARAM(ttl=200s),baz=qux")); + + // Whitespace handling + assertEquals(400, plugin.getTtlForQuery("CACHE_PARAM( ttl=400s )")); + assertEquals(500, plugin.getTtlForQuery("CACHE_PARAM(ttl = 500s)")); + assertEquals(200, plugin.getTtlForQuery("CACHE_PARAM( ttl = 200s , key = test )")); + + // Invalid cases - no CACHE_PARAM in hint content + assertNull(plugin.getTtlForQuery("INDEX(table1 idx1)")); + assertNull(plugin.getTtlForQuery("FIRST_ROWS(100)")); + assertNull(plugin.getTtlForQuery("cachettl=300s")); // old format + assertNull(plugin.getTtlForQuery("NO_CACHE")); + + // Missing parentheses + assertNull(plugin.getTtlForQuery("CACHE_PARAM ttl=300s")); + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=300s")); + + // Multiple parameters (future-proofing) + assertEquals(300, plugin.getTtlForQuery("CACHE_PARAM(ttl=300s, key=test)")); + + // Large TTL values should work + assertEquals(999999, plugin.getTtlForQuery("CACHE_PARAM(ttl=999999s)")); + assertEquals(86400, plugin.getTtlForQuery("CACHE_PARAM(ttl=86400s)")); // 24 hours + } + + @Test + void test_getTTLFromQueryHint_MalformedHints() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Test malformed cases + assertNull(plugin.getTtlForQuery("CACHE_PARAM()")); + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=abc)")); + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=300)")); // missing 's' + + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=)")); + assertNull(plugin.getTtlForQuery("CACHE_PARAM(invalid_format)")); + + // Invalid TTL values (negative and zero) does not count toward malformed hints + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=0s)")); + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=-10s)")); + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=-1s)")); + + // Verify counter was incremented 8 times (5 original + 3 new) + verify(mockMalformedHintCounter, times(5)).inc(); + } + + @Test + void test_execute_noCaching() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Query is not cacheable + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockCallable.call()).thenReturn(mockResult1); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{"select * from mytable where ID = 2"}); + + // Mock result set containing 1 row + when(mockResult1.next()).thenReturn(true, true, false); + when(mockResult1.getObject(1)).thenReturn("bar1", "bar1"); + compareResults(mockResult1, rs); + verify(mockPluginService).isInTransaction(); + verify(mockCallable).call(); + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheBypassCounter, times(1)).inc(); + verify(mockCacheMissCounter, never()).inc(); + // Verify TelemetryContext behavior for no-caching scenario + verify(mockTelemetryFactory).openTelemetryContext("jdbc-database-query", TelemetryTraceLevel.TOP_LEVEL); + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryContext).closeContext(); + } + + @Test + void test_execute_emptyQuery_noCaching() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Query is not 
cacheable + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockCallable.call()).thenReturn(mockResult1); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{}); + + // Mock result set containing 1 row + when(mockResult1.next()).thenReturn(true, true, false); + when(mockResult1.getObject(1)).thenReturn("bar1", "bar1"); + compareResults(mockResult1, rs); + verify(mockPluginService).isInTransaction(); + verify(mockCallable).call(); + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheBypassCounter, times(1)).inc(); + verify(mockCacheMissCounter, never()).inc(); + // Verify TelemetryContext behavior for no-caching scenario + verify(mockTelemetryFactory).openTelemetryContext("jdbc-database-query", TelemetryTraceLevel.TOP_LEVEL); + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryContext).closeContext(); + } + + @Test + void test_execute_emptyPreparedStatement_noCaching() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Query is not cacheable + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockPreparedStatement.toString()).thenReturn("", null); + when(mockCallable.call()).thenReturn(mockResult1); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockPreparedStatement, + methodName, mockCallable, new String[]{}); + // Mock result set containing 1 row + when(mockResult1.next()).thenReturn(true, true, false, true, true, false); + when(mockResult1.getObject(1)).thenReturn("bar1", "bar1", "bar1", "bar1"); + compareResults(mockResult1, rs); + + rs = plugin.execute(ResultSet.class, SQLException.class, mockPreparedStatement, + methodName, mockCallable, new String[]{}); + // Mock result set containing 1 row + compareResults(mockResult1, rs); + + verify(mockPluginService, times(2)).isInTransaction(); + verify(mockCallable, times(2)).call(); + verify(mockTotalQueryCounter, times(2)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheBypassCounter, times(2)).inc(); + verify(mockCacheMissCounter, never()).inc(); + // Verify TelemetryContext behavior for no-caching scenario + verify(mockTelemetryFactory, times(2)).openTelemetryContext("jdbc-database-query", TelemetryTraceLevel.TOP_LEVEL); + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + //verify(mockPreparedStatement, times(2)).toString(); + verify(mockTelemetryContext, times(2)).closeContext(); + } + + @Test + void test_execute_noCachingLongQuery() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Query is not cacheable + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockCallable.call()).thenReturn(mockResult1); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{"/* CACHE_PARAM(ttl=20s) */ select * from T " + RandomStringUtils.randomAlphanumeric(16350)}); + + // Mock result set containing 1 row + when(mockResult1.next()).thenReturn(true, true, false); + when(mockResult1.getObject(1)).thenReturn("bar1", "bar1"); + compareResults(mockResult1, rs); + verify(mockCallable).call(); + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + 
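// The query text exceeds the cache-key size limit (about 16KB, per the large-query bypass test below), + // so the cache is bypassed even though the CACHE_PARAM hint itself is valid: + 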
verify(mockCacheBypassCounter, times(1)).inc(); + verify(mockCacheMissCounter, never()).inc(); + // Verify TelemetryContext behavior for no-caching scenario + verify(mockTelemetryFactory).openTelemetryContext("jdbc-database-query", TelemetryTraceLevel.TOP_LEVEL); + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryContext).closeContext(); + } + + @Test + void test_execute_cachingMissAndHit() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // First execution below is a cache miss; the cached copy then serves the second execution + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); + when(mockSessionStateService.getCatalog()).thenReturn(Optional.empty()).thenReturn(Optional.of("mysql")); + when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()); + when(mockConnection.getCatalog()).thenReturn("mysql"); + when(mockConnection.getSchema()).thenReturn(null); + when(mockDbMetadata.getUserName()).thenReturn("user1@1.1.1.1"); + when(mockCacheConn.readFromCache("mysql_null_user1_select * from A")).thenReturn(null); + when(mockCallable.call()).thenReturn(mockResult1); + + // Result set contains 1 row + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("bar1"); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{"/*+CACHE_PARAM(ttl=50s)*/ select * from A"}); + + // Cached result set contains 1 row + assertTrue(rs.next()); + assertEquals("bar1", rs.getString("fooName")); + assertFalse(rs.next()); + + rs.beforeFirst(); + byte[] serializedTestResultSet = ((CachedResultSet)rs).serializeIntoByteArray(); + when(mockCacheConn.readFromCache("mysql_null_user1_select * from A")).thenReturn(serializedTestResultSet); + + ResultSet rs2 = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{" /*+CACHE_PARAM(ttl=50s)*/select * from A"}); + + assertTrue(rs2.next()); + assertEquals("bar1", rs2.getString("fooName")); + assertFalse(rs2.next()); + verify(mockPluginService, times(3)).getCurrentConnection(); + verify(mockPluginService, times(2)).isInTransaction(); + verify(mockCacheConn, times(2)).readFromCache("mysql_null_user1_select * from A"); + verify(mockPluginService, times(3)).getSessionStateService(); + verify(mockSessionStateService, times(3)).getCatalog(); + verify(mockSessionStateService, times(3)).getSchema(); + verify(mockConnection).getCatalog(); + verify(mockConnection).getSchema(); + verify(mockSessionStateService).setCatalog("mysql"); + verify(mockDbMetadata).getUserName(); + verify(mockCallable).call(); + verify(mockCacheConn).writeToCache(eq("mysql_null_user1_select * from A"), any(), eq(50)); + verify(mockTotalQueryCounter, times(2)).inc(); + verify(mockCacheMissCounter, times(1)).inc(); + verify(mockCacheHitCounter, times(1)).inc(); + verify(mockCacheBypassCounter, never()).inc(); + // Verify TelemetryContext behavior for cache miss and hit scenario + // First call: Cache miss + Database call + verify(mockTelemetryFactory, times(2)).openTelemetryContext(eq("jdbc-cache-lookup"), eq(TelemetryTraceLevel.TOP_LEVEL)); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), 
eq(TelemetryTraceLevel.TOP_LEVEL)); + // Cache context calls: 1 miss (setSuccess(false)) + 1 hit (setSuccess(true)) + verify(mockTelemetryContext, times(1)).setSuccess(false); // Cache miss + verify(mockTelemetryContext, times(1)).setSuccess(true); // Cache hit + // Context closure: 2 cache contexts + 1 database context = 3 total + verify(mockTelemetryContext, times(3)).closeContext(); + } + + @Test + void test_cachingMissAndHit_preparedStatement() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Query is a cache miss + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); + when(mockSessionStateService.getCatalog()).thenReturn(Optional.empty()).thenReturn(Optional.of("mysql")); + when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()); + when(mockConnection.getCatalog()).thenReturn("mysql"); + when(mockConnection.getSchema()).thenReturn(null); + when(mockDbMetadata.getUserName()).thenReturn("user1@1.1.1.1"); + when(mockCacheConn.readFromCache("mysql_null_user1_select * from A")).thenReturn(null); + when(mockCallable.call()).thenReturn(mockResult1); + + // Result set contains 1 row + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("bar1"); + when(mockPreparedStatement.toString()).thenReturn("/* CACHE_PARAM(ttl=50s) */ select * from A"); + + // First execution: cache miss; the result gets written to the cache + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockPreparedStatement, + methodName, mockCallable, new String[]{}); + + // Cached result set contains 1 row + assertTrue(rs.next()); + assertEquals("bar1", rs.getString("fooName")); + assertFalse(rs.next()); + + rs.beforeFirst(); + byte[] serializedTestResultSet = ((CachedResultSet)rs).serializeIntoByteArray(); + when(mockCacheConn.readFromCache("mysql_null_user1_select * from A")).thenReturn(serializedTestResultSet); + + // Second execution: cache hit served from the serialized result set + ResultSet rs2 = plugin.execute(ResultSet.class, SQLException.class, mockPreparedStatement, + methodName, mockCallable, new String[]{}); + + assertTrue(rs2.next()); + assertEquals("bar1", rs2.getString("fooName")); + assertFalse(rs2.next()); + verify(mockPluginService, times(3)).getCurrentConnection(); + verify(mockPluginService, times(2)).isInTransaction(); + verify(mockCacheConn, times(2)).readFromCache("mysql_null_user1_select * from A"); + verify(mockPluginService, times(3)).getSessionStateService(); + verify(mockSessionStateService, times(3)).getCatalog(); + verify(mockSessionStateService, times(3)).getSchema(); + verify(mockConnection).getCatalog(); + verify(mockConnection).getSchema(); + verify(mockSessionStateService).setCatalog("mysql"); + verify(mockDbMetadata).getUserName(); + verify(mockCallable).call(); + verify(mockCacheConn).writeToCache(eq("mysql_null_user1_select * from A"), any(), eq(50)); + verify(mockTotalQueryCounter, times(2)).inc(); + verify(mockCacheMissCounter, times(1)).inc(); + verify(mockCacheHitCounter, times(1)).inc(); + verify(mockCacheBypassCounter, never()).inc(); + // Verify TelemetryContext behavior for cache miss and hit scenario + // First call: Cache miss + Database call + verify(mockTelemetryFactory, times(2)).openTelemetryContext(eq("jdbc-cache-lookup"), eq(TelemetryTraceLevel.TOP_LEVEL)); + verify(mockTelemetryFactory, 
times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + // Cache context calls: 1 miss (setSuccess(false)) + 1 hit (setSuccess(true)) + verify(mockTelemetryContext, times(1)).setSuccess(false); // Cache miss + verify(mockTelemetryContext, times(1)).setSuccess(true); // Cache hit + // Context closure: 2 cache contexts + 1 database context = 3 total + verify(mockTelemetryContext, times(3)).closeContext(); + } + + @Test + void test_transaction_cacheQuery() throws Exception { + props.setProperty("user", "dbuser"); + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Query is cacheable + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.isInTransaction()).thenReturn(true); + when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); + when(mockSessionStateService.getCatalog()).thenReturn(Optional.empty()); + when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()); + when(mockConnection.getCatalog()).thenReturn("postgres"); + when(mockConnection.getSchema()).thenReturn("public"); + when(mockCallable.call()).thenReturn(mockResult1); + + // Result set contains 1 row + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("bar1"); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{"/*+ CACHE_PARAM(ttl=300s) */ select * from T"}); + + // Cached result set contains 1 row + assertTrue(rs.next()); + assertEquals("bar1", rs.getString("fooName")); + assertFalse(rs.next()); + verify(mockPluginService).getCurrentConnection(); + verify(mockPluginService).isInTransaction(); + verify(mockPluginService).getSessionStateService(); + verify(mockSessionStateService).getSchema(); + verify(mockSessionStateService).getCatalog(); + verify(mockConnection).getSchema(); + verify(mockConnection).getCatalog(); + verify(mockSessionStateService).setSchema("public"); + verify(mockSessionStateService).setCatalog("postgres"); + verify(mockDbMetadata, never()).getUserName(); + verify(mockCacheConn, never()).readFromCache(anyString()); + verify(mockCallable).call(); + verify(mockCacheConn).writeToCache(eq("postgres_public_dbuser_select * from T"), any(), eq(300)); + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheMissCounter, never()).inc(); + verify(mockCacheBypassCounter, times(1)).inc(); + // Verify TelemetryContext behavior for transaction scenario + // In transaction: No cache lookup attempted, only database call + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + // Context closure: Only 1 database context + verify(mockTelemetryContext, times(1)).closeContext(); + } + + @Test + void test_transaction_cacheQuery_multiple_query_params() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Query is cacheable + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.isInTransaction()).thenReturn(true); + when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + 
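// Cache keys in these tests follow the <catalog>_<schema>_<user>_<queryText> pattern + 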
when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); + when(mockSessionStateService.getCatalog()).thenReturn(Optional.empty()); + when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()); + when(mockDbMetadata.getUserName()).thenReturn("dbuser"); + when(mockConnection.getCatalog()).thenReturn(null); + when(mockConnection.getSchema()).thenReturn("mysql"); + when(mockCallable.call()).thenReturn(mockResult1); + + // Result set contains 1 row + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("bar1"); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, methodName, mockCallable, new String[]{"/*+ CACHE_PARAM(ttl=300s, otherParam=abc) */ select * from T"}); + + // Cached result set contains 1 row + assertTrue(rs.next()); + assertEquals("bar1", rs.getString("fooName")); + assertFalse(rs.next()); + verify(mockPluginService).getCurrentConnection(); + verify(mockPluginService).isInTransaction(); + verify(mockPluginService).getSessionStateService(); + verify(mockSessionStateService).getCatalog(); + verify(mockSessionStateService).getSchema(); + verify(mockConnection).getSchema(); + verify(mockConnection).getCatalog(); + verify(mockSessionStateService).setSchema("mysql"); + verify(mockDbMetadata).getUserName(); + verify(mockCacheConn, never()).readFromCache(anyString()); + verify(mockCallable).call(); + verify(mockCacheConn).writeToCache(eq("null_mysql_dbuser_select * from T"), any(), eq(300)); + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheMissCounter, never()).inc(); + verify(mockCacheBypassCounter, times(1)).inc(); + // Verify TelemetryContext behavior for transaction scenario + // In transaction: No cache lookup attempted, only database call + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + // Context closure: Only 1 database context + verify(mockTelemetryContext, times(1)).closeContext(); + } + + @Test + void test_transaction_noCaching() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Query is not cacheable + when(mockPluginService.isInTransaction()).thenReturn(true); + when(mockCallable.call()).thenReturn(mockResult1); + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + "Statement.execute", mockCallable, new String[]{"delete from mytable"}); + + // Mock result set containing 1 row + when(mockResult1.next()).thenReturn(true, true, false); + when(mockResult1.getObject(1)).thenReturn("bar1", "bar1"); + compareResults(mockResult1, rs); + verify(mockCacheConn, never()).readFromCache(anyString()); + verify(mockCallable).call(); + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheMissCounter, never()).inc(); + verify(mockCacheBypassCounter, times(1)).inc(); + // Verify TelemetryContext behavior for transaction scenario + // In transaction: No cache lookup attempted, only database call + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + // Context closure: Only 1 database context + 
verify(mockTelemetryContext, times(1)).closeContext(); + } + + @Test + void test_JdbcCacheBypassCount_malformed_hint() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Setup - not in transaction with malformed cache hint + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockCallable.call()).thenReturn(mockResult1); + + // Query with malformed cache hint - should increment both malformed and bypass counters + String queryWithMalformedHint = "/*+ CACHE_PARAM(ttl=invalid) */ SELECT * FROM users WHERE id = 123"; + plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{queryWithMalformedHint}); + // Verify malformed counter incremented first + verify(mockMalformedHintCounter, times(1)).inc(); + // Verify bypass counter incremented (because configuredQueryTtl becomes null) + verify(mockCacheBypassCounter, times(1)).inc(); + // Verify hit/miss counters were NOT incremented; only the total query counter was + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheMissCounter, never()).inc(); + // Verify TelemetryContext behavior for the malformed-hint bypass scenario + // No cache lookup attempted, only the database call + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + // Context closure: Only 1 database context + verify(mockTelemetryContext, times(1)).closeContext(); + } + + @Test + void test_JdbcCacheBypassCount_double_bypass_prevention() throws Exception { + props.setProperty("user", "testuser"); + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Setup - query that meets MULTIPLE bypass conditions + when(mockPluginService.isInTransaction()).thenReturn(true); // Bypass condition #1 + when(mockCallable.call()).thenReturn(mockResult1); + + // Mock result set for caching + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("testdata"); + + // Query that is BOTH too large AND in transaction - double bypass conditions + String largeQueryInTransaction = "/*+ CACHE_PARAM(ttl=300s) */ SELECT * FROM table WHERE data = '" + + RandomStringUtils.randomAlphanumeric(16384) + "'"; // >16KB AND in transaction + + // Execute + plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{largeQueryInTransaction}); + + // Verify bypass counter incremented EXACTLY ONCE (not twice) + verify(mockCacheBypassCounter, times(1)).inc(); + + // Verify hit/miss counters were NOT incremented; only the total query counter was + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheMissCounter, never()).inc(); + + // Verify malformed counter not called (hint is valid, just large query) + verify(mockMalformedHintCounter, never()).inc(); + // Verify TelemetryContext behavior for transaction scenario + // In transaction: No cache lookup attempted, only database call + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + // Context closure: Only 1 database context + verify(mockTelemetryContext, times(1)).closeContext(); + } + + @Test + void 
test_execute_multipleCacheHits() throws Exception { + props.setProperty("user", "user"); + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); + when(mockSessionStateService.getCatalog()).thenReturn(Optional.empty()); + when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()).thenReturn(Optional.of("public")); + when(mockConnection.getSchema()).thenReturn("public"); + when(mockConnection.getCatalog()).thenReturn(null); + when(mockCacheConn.readFromCache("null_public_user_select * from A")).thenReturn(null); + when(mockCallable.call()).thenReturn(mockResult1); + + // Result set contains 1 row + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("bar1"); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{"/*+CACHE_PARAM(ttl=50s)*/ select * from A"}); + + // Cached result set contains 1 row + assertTrue(rs.next()); + assertEquals("bar1", rs.getString("fooName")); + assertFalse(rs.next()); + + rs.beforeFirst(); + byte[] serializedTestResultSet = ((CachedResultSet)rs).serializeIntoByteArray(); + when(mockCacheConn.readFromCache("null_public_user_select * from A")).thenReturn(serializedTestResultSet); + + for (int i = 0; i < 10; i++) { + ResultSet cur_rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{" /*+CACHE_PARAM(ttl=50s)*/select * from A"}); + + assertTrue(cur_rs.next()); + assertEquals("bar1", cur_rs.getString("fooName")); + assertFalse(cur_rs.next()); + } + + verify(mockPluginService, times(12)).getCurrentConnection(); + verify(mockPluginService, times(11)).isInTransaction(); + verify(mockCacheConn, times(11)).readFromCache("null_public_user_select * from A"); + verify(mockPluginService, times(12)).getSessionStateService(); + verify(mockSessionStateService, times(12)).getCatalog(); + verify(mockSessionStateService, times(12)).getSchema(); + verify(mockConnection).getSchema(); + verify(mockConnection).getCatalog(); + verify(mockSessionStateService).setSchema("public"); + verify(mockDbMetadata, never()).getUserName(); + verify(mockCallable).call(); + verify(mockCacheConn).writeToCache(eq("null_public_user_select * from A"), any(), eq(50)); + verify(mockTotalQueryCounter, times(11)).inc(); + verify(mockCacheMissCounter, times(1)).inc(); + verify(mockCacheHitCounter, times(10)).inc(); + verify(mockCacheBypassCounter, never()).inc(); + // Verify TelemetryContext behavior for cache miss and hit scenario + verify(mockTelemetryFactory, times(11)).openTelemetryContext(eq("jdbc-cache-lookup"), eq(TelemetryTraceLevel.TOP_LEVEL)); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + verify(mockTelemetryContext, times(1)).setSuccess(false); // Cache miss + verify(mockTelemetryContext, times(10)).setSuccess(true); // Cache hit + // Context closure: 11 cache contexts + 1 database context = 12 total + verify(mockTelemetryContext, times(12)).closeContext(); + } + + void compareResults(final ResultSet expected, final ResultSet actual) throws SQLException { + // Result sets in these tests have a single column, so compare column 1 row by row + while (expected.next() && 
diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/ElastiCacheIamTokenUtilityTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/ElastiCacheIamTokenUtilityTest.java
new file mode 100644
index 000000000..4eeb24e19
--- /dev/null
+++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/ElastiCacheIamTokenUtilityTest.java
@@ -0,0 +1,193 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package software.amazon.jdbc.plugin.iam;
+
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.ArgumentMatchers.*;
+import static org.mockito.Mockito.*;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import software.amazon.awssdk.auth.credentials.AwsCredentials;
+import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.auth.signer.Aws4Signer;
+import software.amazon.awssdk.auth.signer.params.Aws4PresignerParams;
+import software.amazon.awssdk.http.SdkHttpFullRequest;
+import software.amazon.awssdk.regions.Region;
+
+import java.time.Duration;
+import java.time.Instant;
+import java.util.concurrent.CompletableFuture;
+
+public class ElastiCacheIamTokenUtilityTest {
+  @Mock private AwsCredentialsProvider mockCredentialsProvider;
+  @Mock private AwsCredentials mockCredentials;
+  @Mock private Aws4Signer mockSigner;
+  @Mock private SdkHttpFullRequest mockSignedRequest;
+
+  private AutoCloseable closeable;
+  private ElastiCacheIamTokenUtility tokenUtility;
+  private final Instant fixedInstant = Instant.parse("2025-01-01T12:00:00Z");
+
+  @BeforeEach
+  void setUp() {
+    closeable = MockitoAnnotations.openMocks(this);
+  }
+
+  @AfterEach
+  void tearDown() throws Exception {
+    closeable.close();
+  }
+
+  @Test
+  void testConstructor_WithCacheName() {
+    tokenUtility = new ElastiCacheIamTokenUtility("test-cache");
+    assertNotNull(tokenUtility);
+  }
+
+  @Test
+  void testConstructor_WithCacheNameAndFixedInstant() {
+    tokenUtility = new ElastiCacheIamTokenUtility("test-cache", fixedInstant, mockSigner);
+    assertNotNull(tokenUtility);
+  }
+
+  @Test
+  void testConstructor_NullCacheName() {
+    assertThrows(NullPointerException.class, () ->
+        new ElastiCacheIamTokenUtility(null));
+  }
+
+  @Test
+  void testConstructor_NullCacheNameWithInstant() {
+    assertThrows(NullPointerException.class, () ->
+        new ElastiCacheIamTokenUtility(null, fixedInstant, mockSigner));
+  }
+
+  @Test
+  void testGenerateAuthenticationToken_RegularCache() {
+    // Set up the mock credentials provider to return mockCredentials
+    when(mockCredentialsProvider.resolveIdentity())
+        .thenReturn((CompletableFuture) CompletableFuture.completedFuture(mockCredentials));
+
+    // Add custom presign behavior to capture and verify arguments
+    when(mockSigner.presign(any(SdkHttpFullRequest.class), any(Aws4PresignerParams.class)))
+        .thenAnswer(invocation -> {
+          SdkHttpFullRequest request = invocation.getArgument(0);
+          Aws4PresignerParams presignParams = invocation.getArgument(1);
+
+          // Verify SdkHttpFullRequest: the host is the cache name, not the endpoint hostname
+          assertEquals("test-cache", request.host());
+          assertEquals("/", request.encodedPath());
+          assertEquals("connect", request.rawQueryParameters().get("Action").get(0));
+          assertEquals("testuser", request.rawQueryParameters().get("User").get(0));
+          assertFalse(request.rawQueryParameters().containsKey("ResourceType"));
+
+          // Verify Aws4PresignerParams
+          assertEquals("elasticache", presignParams.signingName());
+          assertEquals(Region.US_EAST_1, presignParams.signingRegion());
+          assertEquals(mockCredentials, presignParams.awsCredentials());
+
+          // Tokens expire 30 seconds short of the 15-minute presign limit
+          Instant expectedExpiration = fixedInstant.plus(Duration.ofSeconds(15 * 60 - 30));
+          assertEquals(expectedExpiration, presignParams.expirationTime().get());
+          assertEquals(fixedInstant, presignParams.signingClockOverride().get().instant());
+
+          return mockSignedRequest;
+        });
+
+    when(mockSignedRequest.getUri()).thenReturn(java.net.URI.create("http://test-cache/result"));
+
+    tokenUtility = new ElastiCacheIamTokenUtility("test-cache", fixedInstant, mockSigner);
+
+    String token = tokenUtility.generateAuthenticationToken(
+        mockCredentialsProvider, Region.US_EAST_1, "test-cache.cache.amazonaws.com", 6379, "testuser");
+
+    // The token is the presigned URI with the scheme stripped
+    assertEquals("test-cache/result", token);
+    verify(mockSigner).presign(any(SdkHttpFullRequest.class), any(Aws4PresignerParams.class));
+  }
+
+  @Test
+  void testGenerateAuthenticationToken_ServerlessCache() {
+    // Set up the mock credentials provider to return mockCredentials
+    when(mockCredentialsProvider.resolveIdentity())
+        .thenReturn((CompletableFuture) CompletableFuture.completedFuture(mockCredentials));
+
+    // Add custom presign behavior to capture and verify arguments
+    when(mockSigner.presign(any(SdkHttpFullRequest.class), any(Aws4PresignerParams.class)))
+        .thenAnswer(invocation -> {
+          SdkHttpFullRequest request = invocation.getArgument(0);
+          Aws4PresignerParams presignParams = invocation.getArgument(1);
+
+          // Verify SdkHttpFullRequest: serverless caches carry an extra ResourceType parameter
+          assertEquals("test-cache", request.host());
+          assertEquals("/", request.encodedPath());
+          assertEquals("connect", request.rawQueryParameters().get("Action").get(0));
+          assertEquals("testuser", request.rawQueryParameters().get("User").get(0));
+          assertEquals("ServerlessCache", request.rawQueryParameters().get("ResourceType").get(0));
+
+          // Verify Aws4PresignerParams
+          assertEquals("elasticache", presignParams.signingName());
+          assertEquals(Region.US_EAST_1, presignParams.signingRegion());
+          assertEquals(mockCredentials, presignParams.awsCredentials());
+
+          Instant expectedExpiration = fixedInstant.plus(Duration.ofSeconds(15 * 60 - 30));
+          assertEquals(expectedExpiration, presignParams.expirationTime().get());
+          assertEquals(fixedInstant, presignParams.signingClockOverride().get().instant());
+
+          return mockSignedRequest;
+        });
+
+    when(mockSignedRequest.getUri()).thenReturn(java.net.URI.create("http://test-cache.serverless.cache.amazonaws.com/result"));
+
+    tokenUtility = new ElastiCacheIamTokenUtility("test-cache", fixedInstant, mockSigner);
+
+    String token = tokenUtility.generateAuthenticationToken(
+        mockCredentialsProvider, Region.US_EAST_1, "test-cache.serverless.cache.amazonaws.com", 6379, "testuser");
+
+    assertEquals("test-cache.serverless.cache.amazonaws.com/result", token);
+    verify(mockSigner).presign(any(SdkHttpFullRequest.class), any(Aws4PresignerParams.class));
+  }
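+
+  // Editor's note (illustrative inference, not the utility's actual source): taken
+  // together, the two tests above imply that serverless caches are detected from the
+  // endpoint hostname and flagged via an extra query parameter, along these lines:
+  //
+  //   if (hostname.contains(".serverless.")) {
+  //     requestBuilder.putRawQueryParameter("ResourceType", "ServerlessCache");
+  //   }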
+
+  @Test
+  void testGenerateAuthenticationToken_NullCacheName() throws ReflectiveOperationException {
+    tokenUtility = new ElastiCacheIamTokenUtility("test-cache", fixedInstant, mockSigner);
+
+    // Use reflection to null out cacheName and exercise the validation path
+    java.lang.reflect.Field field = ElastiCacheIamTokenUtility.class.getDeclaredField("cacheName");
+    field.setAccessible(true);
+    field.set(tokenUtility, null);
+
+    assertThrows(IllegalArgumentException.class, () ->
+        tokenUtility.generateAuthenticationToken(
+            mockCredentialsProvider, Region.US_EAST_1, "test-host", 6379, "testuser"));
+  }
+
+  @Test
+  void testGenerateAuthenticationToken_NullHostname() {
+    tokenUtility = new ElastiCacheIamTokenUtility("test-cache", fixedInstant, mockSigner);
+
+    assertThrows(IllegalArgumentException.class, () ->
+        tokenUtility.generateAuthenticationToken(
+            mockCredentialsProvider, Region.US_EAST_1, null, 6379, "testuser"));
+  }
+}
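Taken together, these tests pin down the token contract: a SigV4-presigned "connect" request whose URI, with the scheme stripped, serves as the Redis password for an IAM-enabled ElastiCache user. As a closing illustration, here is a hedged sketch of how such a token might be consumed with Lettuce (the Redis client this PR adds as a dependency). The cache name, endpoint, and user are placeholders, and whether the plugin wires things up exactly this way is not shown in this diff.

    import io.lettuce.core.RedisClient;
    import io.lettuce.core.RedisURI;
    import io.lettuce.core.api.StatefulRedisConnection;
    import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
    import software.amazon.awssdk.regions.Region;
    import software.amazon.jdbc.plugin.iam.ElastiCacheIamTokenUtility;

    public class ElastiCacheIamAuthExample {
      public static void main(String[] args) {
        // Generate a short-lived IAM auth token (valid just under 15 minutes)
        ElastiCacheIamTokenUtility tokenUtility = new ElastiCacheIamTokenUtility("test-cache");
        String token = tokenUtility.generateAuthenticationToken(
            DefaultCredentialsProvider.create(), Region.US_EAST_1,
            "test-cache.cache.amazonaws.com", 6379, "testuser");

        RedisURI uri = RedisURI.builder()
            .withHost("test-cache.cache.amazonaws.com")
            .withPort(6379)
            .withSsl(true)                          // ElastiCache IAM auth requires TLS
            .withAuthentication("testuser", token)  // the token acts as the password
            .build();

        RedisClient client = RedisClient.create(uri);
        try (StatefulRedisConnection<String, String> connection = client.connect()) {
          System.out.println(connection.sync().ping());
        } finally {
          client.shutdown();
        }
      }
    }

Because the token expires in under 15 minutes, production code would typically supply a refreshing credentials source (for example, Lettuce's RedisCredentialsProvider) rather than a fixed password string.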