diff --git a/ebean-datasource-api/src/main/java/io/ebean/datasource/DataSourceConnection.java b/ebean-datasource-api/src/main/java/io/ebean/datasource/DataSourceConnection.java new file mode 100644 index 0000000..4e11b52 --- /dev/null +++ b/ebean-datasource-api/src/main/java/io/ebean/datasource/DataSourceConnection.java @@ -0,0 +1,11 @@ +package io.ebean.datasource; + +import java.sql.Connection; + +/** + * @author Roland Praml, Foconis Analytics GmbH + */ +public interface DataSourceConnection extends Connection { + + void clearPreparedStatementCache(); +} diff --git a/ebean-datasource-api/src/main/java/io/ebean/datasource/DataSourcePoolListener.java b/ebean-datasource-api/src/main/java/io/ebean/datasource/DataSourcePoolListener.java index 5471c51..cd4cbe5 100644 --- a/ebean-datasource-api/src/main/java/io/ebean/datasource/DataSourcePoolListener.java +++ b/ebean-datasource-api/src/main/java/io/ebean/datasource/DataSourcePoolListener.java @@ -1,6 +1,7 @@ package io.ebean.datasource; import java.sql.Connection; +import java.sql.SQLException; /** @@ -17,11 +18,27 @@ public interface DataSourcePoolListener { /** * Called after a connection has been retrieved from the connection pool */ + default void onAfterBorrowConnection(DataSourcePool pool, DataSourceConnection connection) { + onAfterBorrowConnection(connection); + } + + /** + * @deprecated implement {@link #onAfterBorrowConnection(DataSourcePool, DataSourceConnection)} + */ + @Deprecated default void onAfterBorrowConnection(Connection connection) {} /** * Called before a connection will be put back to the connection pool */ + default void onBeforeReturnConnection(DataSourcePool pool, DataSourceConnection connection) { + onBeforeReturnConnection(connection); + } + + /** + * @deprecated implement {@link #onBeforeReturnConnection(DataSourcePool, DataSourceConnection)} + */ + @Deprecated default void onBeforeReturnConnection(Connection connection) {} diff --git a/ebean-datasource/pom.xml b/ebean-datasource/pom.xml index 37909be..de08a50 100644 --- a/ebean-datasource/pom.xml +++ b/ebean-datasource/pom.xml @@ -1,5 +1,6 @@ - + 4.0.0 io.ebean @@ -26,6 +27,13 @@ test + + org.junit.jupiter + junit-jupiter-params + 5.10.2 + test + + io.ebean ebean-test-containers diff --git a/ebean-datasource/src/main/java/io/ebean/datasource/pool/ConnectionPool.java b/ebean-datasource/src/main/java/io/ebean/datasource/pool/ConnectionPool.java index 2a5ae03..c1992d3 100644 --- a/ebean-datasource/src/main/java/io/ebean/datasource/pool/ConnectionPool.java +++ b/ebean-datasource/src/main/java/io/ebean/datasource/pool/ConnectionPool.java @@ -559,7 +559,7 @@ void removeClosedConnection(PooledConnection pooledConnection) { */ private void returnTheConnection(PooledConnection pooledConnection, boolean forceClose) { if (poolListener != null && !forceClose) { - poolListener.onBeforeReturnConnection(pooledConnection); + poolListener.onBeforeReturnConnection(this, pooledConnection); } queue.returnPooledConnection(pooledConnection, forceClose); } @@ -631,7 +631,7 @@ private PooledConnection getPooledConnection() throws SQLException { c.setStackTrace(Thread.currentThread().getStackTrace()); } if (poolListener != null) { - poolListener.onAfterBorrowConnection(c); + poolListener.onAfterBorrowConnection(this, c); } return c; } diff --git a/ebean-datasource/src/main/java/io/ebean/datasource/pool/PooledConnection.java b/ebean-datasource/src/main/java/io/ebean/datasource/pool/PooledConnection.java index 32c6c14..7d0e8e9 100644 --- a/ebean-datasource/src/main/java/io/ebean/datasource/pool/PooledConnection.java +++ b/ebean-datasource/src/main/java/io/ebean/datasource/pool/PooledConnection.java @@ -1,5 +1,7 @@ package io.ebean.datasource.pool; +import io.ebean.datasource.DataSourceConnection; + import java.sql.*; import java.util.ArrayList; import java.util.Map; @@ -17,7 +19,7 @@ * It has caching of Statements and PreparedStatements. Remembers the last * statement that was executed. Keeps statistics on how long it is in use. */ -final class PooledConnection extends ConnectionDelegator { +final class PooledConnection extends ConnectionDelegator implements DataSourceConnection { private static final String IDLE_CONNECTION_ACCESSED_ERROR = "Pooled Connection has been accessed whilst idle in the pool, via method: "; @@ -974,4 +976,13 @@ private String stackTraceAsString(StackTraceElement[] stackTrace) { return filteredList.toString(); } + @Override + public void clearPreparedStatementCache() { + lock.lock(); + try { + pstmtCache.clear(); + } finally { + lock.unlock(); + } + } } diff --git a/ebean-datasource/src/test/java/io/ebean/datasource/tcdriver/TrustedContextListener.java b/ebean-datasource/src/test/java/io/ebean/datasource/tcdriver/TrustedContextListener.java new file mode 100644 index 0000000..f5b7ae9 --- /dev/null +++ b/ebean-datasource/src/test/java/io/ebean/datasource/tcdriver/TrustedContextListener.java @@ -0,0 +1,38 @@ +package io.ebean.datasource.tcdriver; + +import io.ebean.datasource.DataSourceConnection; +import io.ebean.datasource.DataSourcePool; +import io.ebean.datasource.DataSourcePoolListener; + +import java.sql.SQLException; + +/** + * Listener, that sets up TrustedConnection properly + * + * @author Roland Praml, Foconis Analytics GmbH + */ +public class TrustedContextListener implements DataSourcePoolListener { + private ThreadLocal user = new ThreadLocal<>(); + private ThreadLocal pass = new ThreadLocal<>(); + private ThreadLocal schema = new ThreadLocal<>(); + + @Override + public void onAfterBorrowConnection(DataSourcePool pool, DataSourceConnection connection) { + try { + TrustedDb2Connection trustedDb2Connection = connection.unwrap(TrustedDb2Connection.class); + if (trustedDb2Connection.switchUser(user.get(), pass.get())) { + trustedDb2Connection.setSchema(schema.get()); + connection.clearPreparedStatementCache(); + } + //System.out.println("Switched to " + user.get() + ", Schema: " + schema.get()); + } catch (SQLException e) { + throw new RuntimeException(e); // TODO: Allow throwing sqlException here + } + } + + public void setContext(String user, String pass, String schema) { + this.user.set(user); + this.pass.set(pass); + this.schema.set(schema); + } +} diff --git a/ebean-datasource/src/test/java/io/ebean/datasource/tcdriver/TrustedDb2Connection.java b/ebean-datasource/src/test/java/io/ebean/datasource/tcdriver/TrustedDb2Connection.java new file mode 100644 index 0000000..ee0689f --- /dev/null +++ b/ebean-datasource/src/test/java/io/ebean/datasource/tcdriver/TrustedDb2Connection.java @@ -0,0 +1,363 @@ +package io.ebean.datasource.tcdriver; + +import com.ibm.db2.jcc.DB2Connection; + +import java.sql.Array; +import java.sql.Blob; +import java.sql.CallableStatement; +import java.sql.Clob; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.NClob; +import java.sql.PreparedStatement; +import java.sql.SQLClientInfoException; +import java.sql.SQLException; +import java.sql.SQLWarning; +import java.sql.SQLXML; +import java.sql.Savepoint; +import java.sql.ShardingKey; +import java.sql.Statement; +import java.sql.Struct; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; +import java.util.concurrent.Executor; + +/** + * A Wrapper for DB2Connection that holds connection and cookie. TODO: May inherit from ConnectionDelegator later + * + * @author Noemi Praml, Foconis Analytics GmbH + */ +class TrustedDb2Connection implements Connection { + + private final DB2Connection delegate; + + private final byte[] cookie; + private String user; + private String password; + + TrustedDb2Connection(DB2Connection delegate, byte[] cookie) { + this.delegate = delegate; + this.cookie = cookie; + } + + boolean switchUser(String user, String password) throws SQLException { + if (!Objects.equals(user, this.user) || !Objects.equals(password, this.password)) { + // reusing connection destroys all preparedStatements + delegate.reuseDB2Connection(cookie, user, password, null, null, null, new Properties()); + this.user = user; + this.password = password; + return true; + } + return false; + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + if (iface == TrustedDb2Connection.class) { + return true; + } + return delegate.isWrapperFor(iface); + } + + @Override + public T unwrap(Class iface) throws SQLException { + if (iface == TrustedDb2Connection.class) { + return (T) this; + } + return delegate.unwrap(iface); + } + + @Override + public void setShardingKey(ShardingKey shardingKey) throws SQLException { + delegate.setShardingKey(shardingKey); + } + + @Override + public void setShardingKey(ShardingKey shardingKey, ShardingKey superShardingKey) throws SQLException { + delegate.setShardingKey(shardingKey, superShardingKey); + } + + @Override + public boolean setShardingKeyIfValid(ShardingKey shardingKey, int timeout) throws SQLException { + return delegate.setShardingKeyIfValid(shardingKey, timeout); + } + + @Override + public boolean setShardingKeyIfValid(ShardingKey shardingKey, ShardingKey superShardingKey, int timeout) throws SQLException { + return delegate.setShardingKeyIfValid(shardingKey, superShardingKey, timeout); + } + + @Override + public void endRequest() throws SQLException { + delegate.endRequest(); + } + + @Override + public void beginRequest() throws SQLException { + delegate.beginRequest(); + } + + @Override + public int getNetworkTimeout() throws SQLException { + return delegate.getNetworkTimeout(); + } + + @Override + public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException { + delegate.setNetworkTimeout(executor, milliseconds); + } + + @Override + public void abort(Executor executor) throws SQLException { + delegate.abort(executor); + } + + @Override + public String getSchema() throws SQLException { + return delegate.getSchema(); + } + + @Override + public void setSchema(String schema) throws SQLException { + delegate.setSchema(schema); + } + + @Override + public Struct createStruct(String typeName, Object[] attributes) throws SQLException { + return delegate.createStruct(typeName, attributes); + } + + @Override + public Properties getClientInfo() throws SQLException { + return delegate.getClientInfo(); + } + + @Override + public String getClientInfo(String name) throws SQLException { + return delegate.getClientInfo(name); + } + + @Override + public void setClientInfo(Properties properties) throws SQLClientInfoException { + delegate.setClientInfo(properties); + } + + @Override + public void setClientInfo(String name, String value) throws SQLClientInfoException { + delegate.setClientInfo(name, value); + } + + @Override + public boolean isValid(int timeout) throws SQLException { + return delegate.isValid(timeout); + } + + @Override + public SQLXML createSQLXML() throws SQLException { + return delegate.createSQLXML(); + } + + @Override + public NClob createNClob() throws SQLException { + return delegate.createNClob(); + } + + @Override + public Blob createBlob() throws SQLException { + return delegate.createBlob(); + } + + @Override + public Clob createClob() throws SQLException { + return delegate.createClob(); + } + + @Override + public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException { + return delegate.prepareStatement(sql, columnNames); + } + + @Override + public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException { + return delegate.prepareStatement(sql, columnIndexes); + } + + @Override + public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException { + return delegate.prepareStatement(sql, autoGeneratedKeys); + } + + @Override + public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) + throws SQLException { + return delegate.prepareCall(sql, resultSetType, resultSetConcurrency, resultSetHoldability); + } + + @Override + public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) + throws SQLException { + return delegate.prepareStatement(sql, resultSetType, resultSetConcurrency, resultSetHoldability); + } + + @Override + public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException { + return delegate.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability); + } + + @Override + public void releaseSavepoint(Savepoint savepoint) throws SQLException { + delegate.releaseSavepoint(savepoint); + } + + @Override + public void rollback(Savepoint savepoint) throws SQLException { + delegate.rollback(savepoint); + } + + @Override + public Savepoint setSavepoint(String name) throws SQLException { + return delegate.setSavepoint(name); + } + + @Override + public Savepoint setSavepoint() throws SQLException { + return delegate.setSavepoint(); + } + + @Override + public int getHoldability() throws SQLException { + return delegate.getHoldability(); + } + + @Override + public void setHoldability(int holdability) throws SQLException { + delegate.setHoldability(holdability); + } + + @Override + public void setTypeMap(Map> map) throws SQLException { + delegate.setTypeMap(map); + } + + @Override + public Map> getTypeMap() throws SQLException { + return delegate.getTypeMap(); + } + + @Override + public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLException { + return delegate.prepareCall(sql, resultSetType, resultSetConcurrency); + } + + @Override + public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) throws SQLException { + return delegate.prepareStatement(sql, resultSetType, resultSetConcurrency); + } + + @Override + public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException { + return delegate.createStatement(resultSetType, resultSetConcurrency); + } + + @Override + public void clearWarnings() throws SQLException { + delegate.clearWarnings(); + } + + @Override + public SQLWarning getWarnings() throws SQLException { + return delegate.getWarnings(); + } + + @Override + public int getTransactionIsolation() throws SQLException { + return delegate.getTransactionIsolation(); + } + + @Override + public void setTransactionIsolation(int level) throws SQLException { + delegate.setTransactionIsolation(level); + } + + @Override + public String getCatalog() throws SQLException { + return delegate.getCatalog(); + } + + @Override + public void setCatalog(String catalog) throws SQLException { + delegate.setCatalog(catalog); + } + + @Override + public boolean isReadOnly() throws SQLException { + return delegate.isReadOnly(); + } + + @Override + public void setReadOnly(boolean readOnly) throws SQLException { + delegate.setReadOnly(readOnly); + } + + @Override + public DatabaseMetaData getMetaData() throws SQLException { + return delegate.getMetaData(); + } + + @Override + public boolean isClosed() throws SQLException { + return delegate.isClosed(); + } + + @Override + public void close() throws SQLException { + delegate.close(); + } + + @Override + public void rollback() throws SQLException { + delegate.rollback(); + } + + @Override + public void commit() throws SQLException { + delegate.commit(); + } + + @Override + public boolean getAutoCommit() throws SQLException { + return delegate.getAutoCommit(); + } + + @Override + public void setAutoCommit(boolean autoCommit) throws SQLException { + delegate.setAutoCommit(autoCommit); + } + + @Override + public String nativeSQL(String sql) throws SQLException { + return delegate.nativeSQL(sql); + } + + @Override + public CallableStatement prepareCall(String sql) throws SQLException { + return delegate.prepareCall(sql); + } + + @Override + public PreparedStatement prepareStatement(String sql) throws SQLException { + return delegate.prepareStatement(sql); + } + + @Override + public Statement createStatement() throws SQLException { + return delegate.createStatement(); + } + + @Override + public Array createArrayOf(String typeName, Object[] elements) throws SQLException { + return delegate.createArrayOf(typeName, elements); + } + +} diff --git a/ebean-datasource/src/test/java/io/ebean/datasource/tcdriver/TrustedDb2Driver.java b/ebean-datasource/src/test/java/io/ebean/datasource/tcdriver/TrustedDb2Driver.java new file mode 100644 index 0000000..177330c --- /dev/null +++ b/ebean-datasource/src/test/java/io/ebean/datasource/tcdriver/TrustedDb2Driver.java @@ -0,0 +1,147 @@ +package io.ebean.datasource.tcdriver; + +import com.ibm.db2.jcc.DB2Connection; +import com.ibm.db2.jcc.DB2Driver; +import com.ibm.db2.jcc.DB2PooledConnection; + +import java.net.URI; +import java.net.URISyntaxException; +import java.sql.Connection; +import java.sql.Driver; +import java.sql.DriverManager; +import java.sql.DriverPropertyInfo; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.util.Properties; + +/** + * There is no way to create a trusted connection by a JDBC connection string + * (like jdbc:db2:...trusted=true), so this is a simple driver wrapper, that + * allows us to get a DB2 trusted connection via JDBC-API + * (e.g. jdbc:db2trusted://localhost:40005/database:someProp=someValue;) + * + * @author Noemi Praml, Foconis Analytics GmbH + */ +public class TrustedDb2Driver implements Driver { + + private DB2Driver delegate = new DB2Driver(); + + static { + try { + new DB2Driver(); + DriverManager.registerDriver(new TrustedDb2Driver()); + } catch (SQLException e) { + // eat + } + } + + @Override + public Connection connect(String url, Properties info) throws SQLException { + Properties properties = new Properties(); + properties.putAll(info); + + HostPortDb result; + + try { + result = HostPortDb.parse(url, properties); + } catch (URISyntaxException e) { + throw new SQLException("Invalid url: " + url); + } + + com.ibm.db2.jcc.DB2ConnectionPoolDataSource ds1 = + new com.ibm.db2.jcc.DB2ConnectionPoolDataSource(); + ds1.setServerName(result.host); + ds1.setPortNumber(result.port); + ds1.setDatabaseName(result.dbName); + ds1.setDriverType(4); + ds1.setUser(properties.getProperty("user")); + ds1.setPassword(properties.getProperty("password")); + + Object[] objects = ds1.getDB2TrustedPooledConnection(properties.getProperty("user"), properties.getProperty("password"), properties); + DB2PooledConnection pooledCon = (DB2PooledConnection) objects[0]; + byte[] cookie = (byte[]) objects[1]; + + return new TrustedDb2Connection((DB2Connection) pooledCon.getConnection(), cookie); + } + + + @Override + public boolean acceptsURL(String url) throws SQLException { + return (url != null && (url.startsWith("jdbc:db2trusted:"))); + } + + @Override + public DriverPropertyInfo[] getPropertyInfo(String url, Properties info) throws SQLException { + return new DriverPropertyInfo[0]; + } + + @Override + public int getMajorVersion() { + return delegate.getMajorVersion(); + } + + @Override + public int getMinorVersion() { + return delegate.getMinorVersion(); + } + + @Override + public boolean jdbcCompliant() { + return delegate.jdbcCompliant(); + } + + @Override + public java.util.logging.Logger getParentLogger() throws SQLFeatureNotSupportedException { + return delegate.getParentLogger(); + } + + + /** + * Helper, that parses the JDBC-URL like jdbc:db2trusted://localhost:40005/migtest:currentSchema=METRICSTASK2; + * in host/port/db (similar! to DB2 syntax) + */ + private static class HostPortDb { + public final String host; + public final int port; + public final String dbName; + + + public HostPortDb(String host, int port, String dbName) { + this.host = host; + this.port = port; + this.dbName = dbName; + } + + static HostPortDb parse(String url, Properties properties) throws URISyntaxException { + assert url.startsWith("jdbc:"); + URI uri = new URI(url.substring(5)); + + String host = uri.getHost(); + int port = uri.getPort(); + if (port == 0) { + port = 50000; + } + + String path = uri.getPath(); + if (path.startsWith("/")) { + path = path.substring(1); + } + int colon = path.indexOf(':'); + String dbName = colon == -1 ? path : path.substring(0, colon); + + + if (colon != -1) { + String propertiesString = path.substring(colon + 1); + + String[] keyValuePairs = propertiesString.split(";"); + for (String pair : keyValuePairs) { + String[] keyValue = pair.split("=", 2); + if (keyValue.length == 2) { + properties.setProperty(keyValue[0].trim(), keyValue[1].trim()); + } + } + } + return new HostPortDb(host, port, dbName); + } + } +} diff --git a/ebean-datasource/src/test/java/io/ebean/datasource/test/Db2TrustedContextTest.java b/ebean-datasource/src/test/java/io/ebean/datasource/test/Db2TrustedContextTest.java new file mode 100644 index 0000000..297e906 --- /dev/null +++ b/ebean-datasource/src/test/java/io/ebean/datasource/test/Db2TrustedContextTest.java @@ -0,0 +1,439 @@ +package io.ebean.datasource.test; + +import io.ebean.datasource.DataSourceBuilder; +import io.ebean.datasource.DataSourcePool; +import io.ebean.datasource.tcdriver.TrustedContextListener; +import io.ebean.datasource.tcdriver.TrustedDb2Driver; +import io.ebean.test.containers.Db2Container; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * This test class shows a competitition between ONE connection pool that uses a DB2 + * trusted context and two connection pools. + */ +@Disabled("DB2 container start is slow - run manually") +class Db2TrustedContextTest { + + private static Db2Container container; + + private static Method dockerSuMethod = getSuMethod(); + + private static TrustedContextListener listener = new TrustedContextListener(); + private static final Random rand = new Random(); + private static List summary = new ArrayList<>(); + + static { + new TrustedDb2Driver(); + } + + /** + * Unfortunately, container.dockerSu is protected. So we use some reflection in the meantime + */ + private static Method getSuMethod() { + try { + Method m = Db2Container.class.getDeclaredMethod("dockerSu", String.class, String.class); + m.setAccessible(true); + return m; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** + * Unfortunately, container.dockerSu is protected. So we use some reflection in the meantime + */ + static void dockerSu(String user, String cmd) { + System.out.println("dockerSu: " + user + ", " + cmd); + try { + List ret = (List) dockerSuMethod.invoke(container, user, cmd); + System.out.println("OK: " + ret); + } catch (InvocationTargetException ite) { + System.err.println("FAIL: " + ite.getCause().getMessage()); + } catch (Exception e) { + System.err.println("FAIL: "); + e.printStackTrace(); + } + } + + /** + * Setup the DB2 docker with trusted context support + */ + @BeforeAll + static void before() throws InvocationTargetException, IllegalAccessException { + container = Db2Container.builder("11.5.6.0a") + .port(55505) + .containerName("trusted_context") + .dbName("unit") + .user("unit") + .password("unit") + // to change collation, charset and other parameters like pagesize: + .createOptions("USING CODESET UTF-8 TERRITORY DE COLLATE USING IDENTITY PAGESIZE 32768") + .configOptions("USING STRING_UNITS SYSTEM") + .build(); + + container.start(); + + //setupTrustedContext("172.16.0.1"); // TODO: This will change per machine! + } + + @AfterAll + static void after() { + //container.stop(); + summary.forEach(System.out::println); + } + + private static void setupTrustedContext(String localDockerIp) { + + dockerSu("root", "useradd webuser"); + dockerSu("root", "useradd tenant1"); + dockerSu("root", "useradd tenant2"); + dockerSu("root", "echo \"webuser:webpass\" | chpasswd"); + dockerSu("root", "echo \"tenant1:pass1\" | chpasswd"); + dockerSu("root", "echo \"tenant2:pass2\" | chpasswd"); + + + dockerSu("admin", "db2 connect to unit;db2 drop trusted context webapptrust"); + dockerSu("admin", "db2 connect to unit;db2 drop table S1.test;db2 drop table S2.test"); + + // Setting up the trusted context + dockerSu("admin", "db2 connect to unit;db2 create trusted context webapptrust based upon connection using system authid webuser attributes \\(address \\'" + localDockerIp + "\\'\\) WITH USE FOR tenant1 WITHOUT AUTHENTICATION, tenant2 WITH AUTHENTICATION ENABLE"); + dockerSu("admin", "db2 connect to unit;db2 create table S1.test \\(id int\\)"); + dockerSu("admin", "db2 connect to unit;db2 insert into S1.test values \\(1\\)"); + dockerSu("admin", "db2 connect to unit;db2 create table S2.test \\(id int\\)"); + dockerSu("admin", "db2 connect to unit;db2 insert into S2.test values \\(2\\)"); + dockerSu("admin", "db2 connect to unit;db2 grant connect on database to user webuser"); + dockerSu("admin", "db2 connect to unit;db2 grant connect on database to user tenant1"); + dockerSu("admin", "db2 connect to unit;db2 grant connect on database to user tenant2"); + + dockerSu("admin", "db2 connect to unit;db2 grant all on schema S1 to user tenant1"); + dockerSu("admin", "db2 connect to unit;db2 grant all on schema S2 to user tenant2"); + } + + + private AtomicInteger successCount = new AtomicInteger(); + private AtomicInteger queryCount = new AtomicInteger(); + private volatile boolean running = true; + + + enum LoadProfile { + /** + * try to do as much as work you can + */ + MAX_LOAD, + /** + * Perform 100 queries and hold connection for 10ms + */ + HOLD, + /** + * After a random delay, acquire 10 connections at once and hold them in total for 1s + */ + BURST; + } + + void doSomeWork(DataSourcePool pool, int tenant, LoadProfile loadProfile) { + listener.setContext("tenant" + tenant, "pass" + tenant, "S" + tenant); + try { + switch (loadProfile) { + case MAX_LOAD: + while (running) { + assertThat(executeQuery(pool, "select * from test")).isEqualTo(tenant); + queryCount.incrementAndGet(); + } + break; + case HOLD: + for (int i = 0; i < 100; i++) { + try (Connection conn = pool.getConnection()) { + Thread.sleep(10); + conn.rollback(); + } + queryCount.incrementAndGet(); + } + break; + case BURST: + int wait = rand.nextInt(1000); + Thread.sleep(wait); + List connections = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + connections.add(pool.getConnection()); + queryCount.incrementAndGet(); + } + Thread.sleep(1000 - wait); + for (Connection connection : connections) { + connection.rollback(); + connection.close(); + } + break; + } + successCount.incrementAndGet(); + } catch (Exception e) { + e.printStackTrace(); + } + } + + Thread createWorkerThreas(DataSourcePool pool, int tenant, LoadProfile loadProfile) { + Thread thread = new Thread(() -> { + doSomeWork(pool, tenant, loadProfile); + }); + return thread; + } + + int executeQuery(DataSourcePool pool, String query) throws Exception { + try (Connection conn = pool.getConnection()) { + try (PreparedStatement pstmt = conn.prepareStatement(query)) { + ResultSet rs = pstmt.executeQuery(); + assertThat(rs.next()).isTrue(); + return rs.getInt(1); + } finally { + conn.rollback(); + } + } + } + + @ParameterizedTest + @MethodSource("testKeys") + void testSwitchWithTrustedContext(TestKey testKey) throws Exception { + + DataSourcePool pool = getPool(testKey.poolSize); + try { + // set tenant of this thread to tenant1 + listener.setContext("tenant1", "pass1", "S1"); + // TestDDL + pool.status(true); + assertThat(executeQuery(pool, "select * from test")).isEqualTo(1); // each tenant must read its own data! + assertThat(executeQuery(pool, "select * from test")).isEqualTo(1); // check cache hit + assertThat(pool.status(false).hitCount()).isEqualTo(2); + + testDdl(pool, 1); + assertThat(executeQuery(pool, "select * from test2")).isEqualTo(1); + + assertThatThrownBy(() -> executeQuery(pool, "select * from S2.test")) + .isInstanceOf(SQLException.class) + .hasMessageContaining("SQLCODE=-551, SQLSTATE=42501, SQLERRMC=TENANT1;SELECT;S2.TEST"); + + listener.setContext("tenant2", "pass2", "S2"); // try again. Same query with + assertThat(executeQuery(pool, "select * from S2.test")).isEqualTo(2); + + testDdl(pool, 2); + assertThat(executeQuery(pool, "select * from test")).isEqualTo(2); + assertThat(executeQuery(pool, "select * from test2")).isEqualTo(2); + + try { + long qps = checkThroughput(pool, pool, testKey.threads, testKey.loadProfile); + summary.add(testKey + ",\t" + qps+",\tswitch"); + } catch (Throwable t) { + summary.add(testKey + ",\tFAIL,\tswitch"); + throw t; + } + // Query per seconds + // Threads | maxConn 5 | maxConn 10 | maxConn 20 + // 1 | 3837 | 3885 | 3904 + // 2 | 5401 | 3900 | 5649 + // 5 | 8991 | 9441 | 8029 + // 10 | 1407 | 12438 | 12187 + // 20 | 1739 | 1825 | 13845 + // 200 | | | 2127 + // on high contention, the switching pool drops massive in performance + + // Query per seconds (with holdConnections) + // Threads | maxConn 5 | maxConn 10 | maxConn 20 + // 1 | 90 | 87 | 84 + // 2 | 179 | 163 | 176 + // 5 | 399 | 450 | 445 + // 10 | 397 | 873 | 773 + // 20 | 386 | 747 | 1673 + // 200 | 373 HBFail | 810 | 1337 + } finally { + pool.shutdown(); + } + } + + + @ParameterizedTest + @MethodSource("testKeys") + void testTwoPools(TestKey testKey) throws Exception { + + DataSourcePool pool1 = getPool1(testKey.poolSize / 2); + DataSourcePool pool2 = getPool2(testKey.poolSize - testKey.poolSize / 2); + try { + // set tenant of this thread to tenant1 + pool1.status(true); + assertThat(executeQuery(pool1, "select * from test")).isEqualTo(1); // each tenant must read its own data! + assertThat(executeQuery(pool1, "select * from test")).isEqualTo(1); // check cache hit + assertThat(pool1.status(false).hitCount()).isEqualTo(2); + + assertThatThrownBy(() -> executeQuery(pool1, "select * from S2.test")) + .isInstanceOf(SQLException.class) + .hasMessageContaining("SQLCODE=-551, SQLSTATE=42501, SQLERRMC=TENANT1;SELECT;S2.TEST"); + + testDdl(pool1, 1); + assertThat(executeQuery(pool1, "select * from test2")).isEqualTo(1); + + assertThat(executeQuery(pool2, "select * from S2.test")).isEqualTo(2); + + testDdl(pool2, 2); + assertThat(executeQuery(pool2, "select * from test")).isEqualTo(2); + assertThat(executeQuery(pool2, "select * from test2")).isEqualTo(2); + + try { + long qps = checkThroughput(pool1, pool2, testKey.threads, testKey.loadProfile); + summary.add(testKey + ",\t" + qps+",\ttwoPools"); + } catch (Throwable t) { + summary.add(testKey + ",\tFAIL,\ttwoPools"); + throw t; + } + // Query per seconds + // Threads | maxConn 2+3 | maxConn 5+5 | maxConn 10+10 + // 1 | 3878 | 3675 | 3899 + // 2 | 6533 | 6601 | 6498 + // 5 | 8883 | 11665 | 11145 + // 10 | 9820 | 18292 | 17891 + // 20 | 10937 | 17742 | 28214 + // 200 | | | 23486 + // even on high contention, dedicated pools provide best performance + + // Query per seconds (with holdConnections) + // Threads | maxConn 2+3 | maxConn 5+5 | maxConn 10+10 + // 1 | 88 | 88 | 87 + // 2 | 180 | 178 | 181 + // 5 | 154 | 452 | 446 + // 10 | 365 | 895 | 891 + // 20 | 362 | 901 | 1796 + // 200 | 366 | 912 | 1832 + } finally { + pool1.shutdown(); + pool2.shutdown(); + } + } + + private static void testDdl(DataSourcePool pool, int value) throws SQLException { + try (Connection conn = pool.getConnection()) { + try (Statement stmt = conn.createStatement()) { + try { + stmt.execute("drop table test2"); + conn.commit(); + } catch (SQLException e) { + // Table did not exist + } + stmt.execute("create table test2 (id int)"); + try (PreparedStatement pstmt = conn.prepareStatement("insert into test2 values (?)")) { + pstmt.setInt(1, value); + pstmt.executeUpdate(); + } + } finally { + conn.commit(); + } + } + } + + + private long checkThroughput(DataSourcePool pool1, DataSourcePool pool2, int threadCount, LoadProfile loadProfile) throws InterruptedException { + successCount.set(0); + long time = System.currentTimeMillis(); + List threads = new ArrayList<>(); + for (int i = 0; i < threadCount; i++) { + int tenant = i % 2 + 1; + threads.add(createWorkerThreas(tenant == 1 ? pool1 : pool2, tenant, loadProfile)); + } + running = true; + for (Thread thread : threads) { + thread.start(); + } + if (loadProfile == LoadProfile.MAX_LOAD) { + Thread.sleep(5000); + } + running = false; + for (Thread thread : threads) { + thread.join(); + } + time = System.currentTimeMillis() - time; + long qps = queryCount.get() * 1000L / time; + System.out.println("Success: " + successCount.get() + ", QPS: " + qps); + System.out.println(pool1.status(false)); + if (pool1 != pool2) { + System.out.println(pool2.status(false)); + } + assertThat(successCount.get()).isEqualTo(threadCount); + return qps; + } + + static List testKeys() { + List keys = new ArrayList<>(); + int[] threadsList = {1, 2, 5, 10, 20, 50, 100}; + int[] poolSizeList = {5, 10, 20, 50}; + for (int threads : threadsList) { + for (int pools : poolSizeList) { + for (LoadProfile profile : LoadProfile.values()) { + if (profile == LoadProfile.BURST && pools < 10) { + continue; // does not make sense + } + keys.add(new TestKey(pools, threads, profile)); + } + } + } + return keys; + } + + static class TestKey { + final int poolSize; + final int threads; + final LoadProfile loadProfile; + + TestKey(int poolSize, int threads, LoadProfile loadProfile) { + this.poolSize = poolSize; + this.threads = threads; + this.loadProfile = loadProfile; + } + + @Override + public String toString() { + return poolSize + ",\t" + threads + ",\t" + loadProfile; + } + } + + private static DataSourcePool getPool(int size) { + return DataSourceBuilder.create() + .url(container.jdbcUrl().replace(":db2:", ":db2trusted:")) + .username("webuser") + .password("webpass") + .maxConnections(size) + .listener(listener) + .build(); + } + + private static DataSourcePool getPool1(int size) { + return DataSourceBuilder.create() + .url(container.jdbcUrl() + ":currentSchema=S1;") + .username("tenant1") + .password("pass1") + .maxConnections(size) + .build(); + } + + private static DataSourcePool getPool2(int size) { + return DataSourceBuilder.create() + .url(container.jdbcUrl() + ":currentSchema=S2;") + .username("tenant2") + .password("pass2") + .maxConnections(size) + .build(); + } +}