Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 184 additions & 45 deletions jdbc/src/main/java/org/apache/zeppelin/jdbc/JDBCInterpreter.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.PrivilegedExceptionAction;
import java.sql.Connection;
import java.sql.DriverManager;
Expand Down Expand Up @@ -436,12 +439,31 @@ private String getEntityName(String replName, String propertyKey) {
}
}

private String getJDBCDriverName(String user) {
StringBuffer driverName = new StringBuffer();
driverName.append(DBCP_STRING);
driverName.append(DEFAULT_KEY);
driverName.append(user);
return driverName.toString();
/**
* Builds a stable, compact pool name for the given user+url combination.
* Uses the first 16 hex chars of the SHA-256 hash of the URL so the name is
* safe for use as a DBCP pool key regardless of special characters in the URL.
*/
static String buildPoolName(String user, String url) {
String urlHash;
try {
MessageDigest md = MessageDigest.getInstance("SHA-256");
byte[] hash = md.digest(url.getBytes(StandardCharsets.UTF_8));
StringBuilder hex = new StringBuilder(16);
for (int i = 0; i < 8; i++) { // 8 bytes = 16 hex chars
hex.append(String.format("%02x", hash[i]));
}
urlHash = hex.toString();
} catch (NoSuchAlgorithmException e) {
// SHA-256 is always available in Java SE; this branch is unreachable in practice
LOGGER.warn("SHA-256 not available, falling back to sanitized URL for pool name");
urlHash = url.replaceAll("[^a-zA-Z0-9]", "_");
}
return DEFAULT_KEY + user + "_" + urlHash;
}

private String getJDBCDriverName(String user, String url) {
return DBCP_STRING + buildPoolName(user, url);
}

private boolean existAccountInBaseProperty(String propertyKey) {
Expand Down Expand Up @@ -471,9 +493,49 @@ public JDBCUserConfigurations getJDBCConfiguration(String user) {
}

private void closeDBPool(String user) throws SQLException {
PoolingDriver poolingDriver = getJDBCConfiguration(user).removeDBDriverPool();
if (poolingDriver != null) {
poolingDriver.closePool(DEFAULT_KEY + user);
closeDBPool(user, null);
}

/**
* Close database pool for user and optional URL
* @param user Username
* @param url URL to close specific pool, or null to close all pools for the user
*/
private void closeDBPool(String user, String url) throws SQLException {
if (url != null && !url.isEmpty()) {
// Close only the pool for this specific URL.
// We use getPoolingDriver() (non-destructive) so that other pools registered
// for this user remain accessible — avoids the pool-leak bug where
// removeDBDriverPool() would orphan all other pools.
String poolName = buildPoolName(user, url);
PoolingDriver driver = getJDBCConfiguration(user).getPoolingDriver();
if (driver != null) {
try {
driver.closePool(poolName);
LOGGER.info("Closed pool for user: {}, url: {}", user, url);
} catch (Exception e) {
LOGGER.warn("Could not close pool '{}': {}", poolName, e.getMessage());
}
getJDBCConfiguration(user).removePoolName(poolName);
}
} else {
// Close all pools for this user and remove the driver reference.
PoolingDriver poolingDriver = getJDBCConfiguration(user).removeDBDriverPool();
if (poolingDriver != null) {
String[] poolNames = poolingDriver.getPoolNames();
String userPrefix = DEFAULT_KEY + user;
for (String poolName : poolNames) {
if (poolName.startsWith(userPrefix)) {
try {
poolingDriver.closePool(poolName);
LOGGER.info("Closed pool: {}", poolName);
} catch (Exception e) {
LOGGER.warn("Could not close pool '{}': {}", poolName, e.getMessage());
}
}
}
LOGGER.info("Closed all pools for user: {}", user);
}
}
}

Expand Down Expand Up @@ -563,23 +625,43 @@ private void createConnectionPool(String url, String user,

poolableConnectionFactory.setPool(connectionPool);
Class.forName(driverClass);
PoolingDriver driver = new PoolingDriver();
driver.registerPool(DEFAULT_KEY + user, connectionPool);
getJDBCConfiguration(user).saveDBDriverPool(driver);

// Reuse the existing PoolingDriver if one has already been registered for this user,
// rather than creating a new instance each time. All PoolingDriver instances share the
// same global DBCP registry, so creating multiple instances is wasteful and makes
// cleanup harder (removeDBDriverPool only retains the last reference).
PoolingDriver driver = getJDBCConfiguration(user).getPoolingDriver();
if (driver == null) {
driver = new PoolingDriver();
}
String poolName = buildPoolName(user, url);
driver.registerPool(poolName, connectionPool);
getJDBCConfiguration(user).saveDBDriverPool(driver, poolName);
}

private Connection getConnectionFromPool(String url, String user,
Properties properties) throws SQLException, ClassNotFoundException {
String jdbcDriver = getJDBCDriverName(user);
String poolName = buildPoolName(user, url);
String jdbcDriver = getJDBCDriverName(user, url);

if (!getJDBCConfiguration(user).isConnectionInDBDriverPool()) {
if (!getJDBCConfiguration(user).isConnectionInDBDriverPool(poolName)) {
createConnectionPool(url, user, properties);
}
return DriverManager.getConnection(jdbcDriver);
}

public Connection getConnection(InterpreterContext context)
throws ClassNotFoundException, SQLException, InterpreterException, IOException {
return getConnection(context, null);
}

/**
* Get connection with optional URL override
* @param context Interpreter context
* @param overrideUrl URL to use instead of default (pass null or empty string to use default)
*/
public Connection getConnection(InterpreterContext context, String overrideUrl)
throws ClassNotFoundException, SQLException, InterpreterException, IOException {

if (basePropertiesMap.get(DEFAULT_KEY) == null) {
LOGGER.warn("No default config");
Expand All @@ -592,7 +674,16 @@ public Connection getConnection(InterpreterContext context)
setUserProperty(context);

final Properties properties = jdbcUserConfigurations.getProperty();
String url = properties.getProperty(URL_KEY);

// Use override URL if provided, otherwise use default
String url = (overrideUrl != null && !overrideUrl.isEmpty())
? overrideUrl
: properties.getProperty(URL_KEY);

if (overrideUrl != null && !overrideUrl.isEmpty()) {
LOGGER.info("Using override URL for this paragraph");
}

url = appendProxyUserToURL(url, user);
String connectionUrl = appendTagsToURL(url, context);
validateConnectionUrl(connectionUrl);
Expand Down Expand Up @@ -814,32 +905,14 @@ protected List<String> splitSqlQueries(String text) {
private InterpreterResult executeSql(String sql,
InterpreterContext context) throws InterpreterException {
Connection connection = null;
// Track the URL used to open the current connection so we can detect URL changes
String currentConnectionUrl = null;
Statement statement;
ResultSet resultSet = null;
String paragraphId = context.getParagraphId();
String user = getUser(context);

try {
connection = getConnection(context);
} catch (IllegalArgumentException e) {
LOGGER.error("Cannot run " + sql, e);
return new InterpreterResult(Code.ERROR, "Connection URL contains improper configuration");
} catch (Exception e) {
LOGGER.error("Fail to getConnection", e);
try {
closeDBPool(user);
} catch (SQLException e1) {
LOGGER.error("Cannot close DBPool for user: " + user , e1);
}
if (e instanceof SQLException) {
return new InterpreterResult(Code.ERROR, e.getMessage());
} else {
return new InterpreterResult(Code.ERROR, ExceptionUtils.getStackTrace(e));
}
}
if (connection == null) {
return new InterpreterResult(Code.ERROR, "User's connection not found.");
}
String interpreterName = getInterpreterGroup().getId();

try {
List<String> sqlArray = sqlSplitter.splitSql(sql);
Expand All @@ -854,9 +927,82 @@ private InterpreterResult executeSql(String sql,
sqlToExecute = sqlToExecute.trim();
}
LOGGER.info("Execute sql: " + sqlToExecute);
statement = connection.createStatement();
// Validate and get URL for THIS specific statement
String sqlToValidate = sqlToExecute
.replace("\n", " ")
.replace("\r", " ")
.replace("\t", " ");

// User config properties may be null until setUserProperty is called (e.g. first run for this user)
Properties defaultProps = basePropertiesMap.get(DEFAULT_KEY);
String targetJdbcUrl = (defaultProps != null ? defaultProps.getProperty(URL_KEY) : null);

ValidationRequest request = new ValidationRequest(sqlToValidate, user,
interpreterName, sqlToExecute, targetJdbcUrl);
ValidationResponse response = null;

try {
response = sendValidationRequest(request);

if (response.getNewJdbcUrl() != null &&
!response.getNewJdbcUrl().isEmpty()) {
targetJdbcUrl = response.getNewJdbcUrl();
LOGGER.info("Validation API returned new JDBC URL for statement");
}
} catch (Exception e) {
LOGGER.warn("Failed to call validation API: {}", e.getMessage());
}

// Get or create connection for this URL if needed.
// We compare against currentConnectionUrl (set when we opened the connection)
try {
boolean urlChanged = targetJdbcUrl != null
&& !targetJdbcUrl.equals(currentConnectionUrl);

if (urlChanged && connection != null && !connection.isClosed()) {
LOGGER.info("URL changed from '{}' to '{}', closing old connection",
currentConnectionUrl, targetJdbcUrl);
// Commit any pending DML (INSERT/UPDATE/UPSERT) before returning this
// connection to the pool. Without this, an open transaction from the
// previous statement would be inherited by the next pool borrower.
try {
if (!connection.getAutoCommit()) {
connection.commit();
}
} catch (SQLException commitEx) {
LOGGER.warn("Could not commit before URL switch for user: {}, error: {}",
user, commitEx.getMessage());
}
connection.close();
connection = null;
currentConnectionUrl = null;
}

String interpreterName = getInterpreterGroup().getId();
if (connection == null || connection.isClosed()) {
connection = getConnection(context, targetJdbcUrl);
currentConnectionUrl = targetJdbcUrl;
}
} catch (IllegalArgumentException e) {
LOGGER.error("Cannot run " + sqlToExecute, e);
return new InterpreterResult(Code.ERROR, "Connection URL contains improper configuration");
} catch (Exception e) {
LOGGER.error("Fail to getConnection", e);
try {
closeDBPool(user);
} catch (SQLException e1) {
LOGGER.error("Cannot close DBPool for user: " + user , e1);
}
if (e instanceof SQLException) {
return new InterpreterResult(Code.ERROR, e.getMessage());
} else {
return new InterpreterResult(Code.ERROR, ExceptionUtils.getStackTrace(e));
}
}

if (connection == null) {
return new InterpreterResult(Code.ERROR, "User's connection not found.");
}
statement = connection.createStatement();

if (interpreterName != null && interpreterName.startsWith("spark_rca_")) {
statement.setQueryTimeout(10800); // 10800 seconds = 3 hours
Expand Down Expand Up @@ -890,14 +1036,7 @@ private InterpreterResult executeSql(String sql,
Boolean.parseBoolean(getProperty("hive.log.display", "true")), this);
}

String userName = getUser(context);
String sqlToValidate = sqlToExecute
.replace("\n", " ")
.replace("\r", " ")
.replace("\t", " ");
ValidationRequest request = new ValidationRequest(sqlToValidate, userName, interpreterName, sqlToExecute);
try {
ValidationResponse response = sendValidationRequest(request);
if (response.isPreSubmitFail()) {
if(response.getVersion() == "v1") {
String outputMessage = response.getMessage();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,22 @@
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
* UserConfigurations for JDBC impersonation.
*/
public class JDBCUserConfigurations {
private final Map<String, Statement> paragraphIdStatementMap;
private PoolingDriver poolingDriver;
private final Set<String> registeredPools;
private Properties properties;
private Boolean isSuccessful;

public JDBCUserConfigurations() {
paragraphIdStatementMap = new HashMap<>();
registeredPools = ConcurrentHashMap.newKeySet();
}

public void initStatementMap() throws SQLException {
Expand All @@ -45,6 +49,7 @@ public void initStatementMap() throws SQLException {

public void initConnectionPoolMap() throws SQLException {
this.poolingDriver = null;
this.registeredPools.clear();
this.isSuccessful = null;
}

Expand Down Expand Up @@ -83,8 +88,31 @@ public void saveDBDriverPool(PoolingDriver driver) throws SQLException {
this.isSuccessful = false;
}

public void saveDBDriverPool(PoolingDriver driver, String poolName) throws SQLException {
this.poolingDriver = driver;
this.registeredPools.add(poolName);
this.isSuccessful = false;
}

/**
* Returns the current PoolingDriver without removing it.
* Use this when you need to close a single named pool without discarding all pool state.
*/
public PoolingDriver getPoolingDriver() {
return this.poolingDriver;
}

/**
* Removes a single pool name from the registered set.
* Does NOT clear the PoolingDriver reference — other pools remain accessible.
*/
public void removePoolName(String poolName) {
this.registeredPools.remove(poolName);
}

public PoolingDriver removeDBDriverPool() throws SQLException {
this.isSuccessful = null;
this.registeredPools.clear();
PoolingDriver tmp = poolingDriver;
this.poolingDriver = null;
return tmp;
Expand All @@ -94,6 +122,10 @@ public boolean isConnectionInDBDriverPool() {
return this.poolingDriver != null;
}

public boolean isConnectionInDBDriverPool(String poolName) {
return this.poolingDriver != null && this.registeredPools.contains(poolName);
}

public void setConnectionInDBDriverPoolSuccessful() {
this.isSuccessful = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@ public class ValidationRequest {
@SerializedName("raw_query_text")
private String rawQueryText;

public ValidationRequest(String queryText, String user, String interpreterName, String rawQueryText) {
@SerializedName("raw_jdbc_url")
private String rawJdbcUrl;

public ValidationRequest(String queryText, String user, String interpreterName, String rawQueryText, String rawJdbcUrl) {
this.queryText = queryText;
this.user = user;
this.interpreterName = interpreterName;
this.rawQueryText = rawQueryText;
this.rawJdbcUrl = rawJdbcUrl;
}

public String toJson() {
Expand Down
Loading
Loading