diff --git a/ebean-datasource-api/src/main/java/io/ebean/datasource/DataSourceBuilder.java b/ebean-datasource-api/src/main/java/io/ebean/datasource/DataSourceBuilder.java index 3a53a25..a508c6f 100644 --- a/ebean-datasource-api/src/main/java/io/ebean/datasource/DataSourceBuilder.java +++ b/ebean-datasource-api/src/main/java/io/ebean/datasource/DataSourceBuilder.java @@ -379,6 +379,11 @@ default DataSourceBuilder listener(DataSourcePoolListener listener) { @Deprecated(forRemoval = true) DataSourceBuilder setListener(DataSourcePoolListener listener); + /** + * Set the connection initializer to use. + */ + DataSourceBuilder connectionInitializer(NewConnectionInitializer connectionListener); + /** * Set a SQL statement used to test the database is accessible. *

@@ -933,6 +938,11 @@ default String driverClassName() { */ DataSourcePoolListener getListener(); + /** + * Return the new connection listener to use. + */ + NewConnectionInitializer getConnectionInitializer(); + /** * Return a SQL statement used to test the database is accessible. *

diff --git a/ebean-datasource-api/src/main/java/io/ebean/datasource/DataSourceConfig.java b/ebean-datasource-api/src/main/java/io/ebean/datasource/DataSourceConfig.java index ac06b3e..83068a0 100644 --- a/ebean-datasource-api/src/main/java/io/ebean/datasource/DataSourceConfig.java +++ b/ebean-datasource-api/src/main/java/io/ebean/datasource/DataSourceConfig.java @@ -84,6 +84,7 @@ public class DataSourceConfig implements DataSourceBuilder.Settings { private List initSql; private DataSourceAlert alert; private DataSourcePoolListener listener; + private NewConnectionInitializer connectionInitializer; private Properties clientInfo; private String applicationName; private boolean shutdownOnJvmExit; @@ -477,6 +478,17 @@ public DataSourceConfig setListener(DataSourcePoolListener listener) { return this; } + @Override + public NewConnectionInitializer getConnectionInitializer() { + return connectionInitializer; + } + + @Override + public DataSourceBuilder connectionInitializer(NewConnectionInitializer connectionInitializer) { + this.connectionInitializer = connectionInitializer; + return this; + } + @Override public String getHeartbeatSql() { return heartbeatSql; diff --git a/ebean-datasource-api/src/main/java/io/ebean/datasource/NewConnectionInitializer.java b/ebean-datasource-api/src/main/java/io/ebean/datasource/NewConnectionInitializer.java new file mode 100644 index 0000000..31990df --- /dev/null +++ b/ebean-datasource-api/src/main/java/io/ebean/datasource/NewConnectionInitializer.java @@ -0,0 +1,26 @@ +package io.ebean.datasource; + +import java.sql.Connection; + +/** + * A {@link DataSourcePool} listener which allows you to hook on the create connections process of the pool. + */ +public interface NewConnectionInitializer { + + /** + * Called after a connection has been created, before any initialization. + * + * @param connection the created connection + */ + default void preInitialize(Connection connection) { + } + + /** + * Called after a connection has been initialized (after onCreatedConnection) and all settings applied. + * + * @param connection the created connection + */ + default void postInitialize(Connection connection) { + } + +} diff --git a/ebean-datasource/src/main/java/io/ebean/datasource/pool/ConnectionPool.java b/ebean-datasource/src/main/java/io/ebean/datasource/pool/ConnectionPool.java index ce438ea..bcd3c95 100644 --- a/ebean-datasource/src/main/java/io/ebean/datasource/pool/ConnectionPool.java +++ b/ebean-datasource/src/main/java/io/ebean/datasource/pool/ConnectionPool.java @@ -46,6 +46,7 @@ interface Heartbeat { */ private final DataSourceAlert notify; private final DataSourcePoolListener poolListener; + private final NewConnectionInitializer connectionInitializer; private final List initSql; private final String user; private final String schema; @@ -109,6 +110,7 @@ interface Heartbeat { this.name = name; this.notify = params.getAlert(); this.poolListener = params.getListener(); + this.connectionInitializer = params.getConnectionInitializer(); this.autoCommit = params.isAutoCommit(); this.readOnly = params.isReadOnly(); this.failOnStart = params.isFailOnStart(); @@ -434,6 +436,9 @@ private void testConnection() { * Initializes the connection we got from the driver. */ private Connection initConnection(Connection conn) throws SQLException { + if (connectionInitializer != null) { + connectionInitializer.preInitialize(conn); + } conn.setAutoCommit(autoCommit); // isolation level is set globally for all connections (at least for H2) and // you will need admin rights - so we do not change it, if it already matches. @@ -470,6 +475,9 @@ private Connection initConnection(Connection conn) throws SQLException { } } } + if (connectionInitializer != null) { + connectionInitializer.postInitialize(conn); + } return conn; } diff --git a/ebean-datasource/src/test/java/io/ebean/datasource/pool/ConnectionPoolNewConnectionListenerTest.java b/ebean-datasource/src/test/java/io/ebean/datasource/pool/ConnectionPoolNewConnectionListenerTest.java new file mode 100644 index 0000000..2c71087 --- /dev/null +++ b/ebean-datasource/src/test/java/io/ebean/datasource/pool/ConnectionPoolNewConnectionListenerTest.java @@ -0,0 +1,113 @@ +package io.ebean.datasource.pool; + +import io.ebean.datasource.DataSourceConfig; +import io.ebean.datasource.NewConnectionInitializer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.sql.Connection; +import java.util.HashMap; +import static org.junit.jupiter.api.Assertions.*; + +class ConnectionPoolNewConnectionListenerTest { + + private ConnectionPool pool; + + private final HashMap createdConnections = new HashMap<>(); + private final HashMap afterConnections = new HashMap<>(); + + ConnectionPoolNewConnectionListenerTest() { + pool = createPool(); + } + + + private ConnectionPool createPool() { + + DataSourceConfig config = new DataSourceConfig(); + config.setDriver("org.h2.Driver"); + config.setUrl("jdbc:h2:mem:tests"); + config.setUsername("sa"); + config.setPassword(""); + config.setMinConnections(1); + config.setMaxConnections(5); + config.connectionInitializer(new NewConnectionInitializer() { + @Override + public void preInitialize(Connection connection) { + synchronized (createdConnections) { + createdConnections.put(connection, 1 + createdConnections.getOrDefault(connection, 0)); + createdConnections.notifyAll(); + } + } + + @Override + public void postInitialize(Connection connection) { + synchronized (afterConnections) { + afterConnections.put(connection, 1 + afterConnections.getOrDefault(connection, 0)); + afterConnections.notifyAll(); + } + } + }); + + return new ConnectionPool("initialize", config); + } + + @AfterEach + public void after() { + pool.shutdown(); + } + + @Test + public void initializeNewConnectionTest() { + // Min connections is 1 so one should be created on pool initialization + synchronized (createdConnections) { + assertEquals(1, createdConnections.size()); + assertEquals(1, afterConnections.size()); + } + + try (Connection connection = pool.getConnection()) { + assertNotNull(connection); + synchronized (createdConnections) { + assertEquals(1, createdConnections.size()); + } + synchronized (afterConnections) { + assertEquals(1, afterConnections.size()); + } + } catch (Exception e) { + fail(e.getMessage()); + } + + try (Connection connection = pool.getConnection()) { + assertNotNull(connection); + synchronized (createdConnections) { + assertEquals(1, createdConnections.size()); + } + synchronized (afterConnections) { + assertEquals(1, afterConnections.size()); + } + + try (Connection connection2 = pool.getConnection()) { + assertNotNull(connection2); + synchronized (createdConnections) { + assertEquals(2, createdConnections.size()); + } + synchronized (afterConnections) { + assertEquals(2, afterConnections.size()); + } + } + } catch (Exception e) { + fail(e.getMessage()); + } + synchronized (createdConnections) { + for (var entry : createdConnections.entrySet()) { + // It should be always 1 + assertEquals(1, entry.getValue()); + } + } + synchronized (afterConnections) { + for (var entry : afterConnections.entrySet()) { + // It should be always 1 + assertEquals(1, entry.getValue()); + } + } + } +}