diff --git a/CHANGELOG.md b/CHANGELOG.md index edb6895a..9f059a81 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## Unreleased +### Changed + +- mTLS connections no longer disconnect when the certificate refresh command exits with a non-zero code + ## 0.8.5 - 2026-02-03 ### Added diff --git a/gradle.properties b/gradle.properties index 37aba2ea..eb319e2e 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,3 +1,3 @@ -version=0.8.5 +version=0.8.6 group=com.coder.toolbox name=coder-toolbox \ No newline at end of file diff --git a/src/main/kotlin/com/coder/toolbox/sdk/CoderRestClient.kt b/src/main/kotlin/com/coder/toolbox/sdk/CoderRestClient.kt index d96e82ad..31255d99 100644 --- a/src/main/kotlin/com/coder/toolbox/sdk/CoderRestClient.kt +++ b/src/main/kotlin/com/coder/toolbox/sdk/CoderRestClient.kt @@ -368,16 +368,12 @@ open class CoderRestClient( return@withContext try { val result = ProcessExecutor() .command(command.split(" ").toList()) - .exitValueNormal() + .exitValueAny() .readOutput(true) .execute() - - if (result.exitValue == 0) { + if (tlsContext.reload()) { context.logger.info("Certificate refresh successful. Reloading TLS and evicting pool.") - tlsContext.reload() - - // This is the "Magic Fix": - // It forces OkHttp to close the broken HTTP/2 connection. + // forces OkHttp to close the broken HTTP/2 connection. httpClient.connectionPool.evictAll() return@withContext true } else { diff --git a/src/main/kotlin/com/coder/toolbox/util/Hash.kt b/src/main/kotlin/com/coder/toolbox/util/Hash.kt index e23a11d7..ed5715f3 100644 --- a/src/main/kotlin/com/coder/toolbox/util/Hash.kt +++ b/src/main/kotlin/com/coder/toolbox/util/Hash.kt @@ -5,17 +5,19 @@ import java.io.InputStream import java.security.DigestInputStream import java.security.MessageDigest +private const val BUFFER_SIZE = 8192 + fun ByteArray.toHex() = joinToString(separator = "") { byte -> "%02x".format(byte) } /** * Return the SHA-1 for the provided stream. */ -@Suppress("ControlFlowWithEmptyBody") fun sha1(stream: InputStream): String { val md = MessageDigest.getInstance("SHA-1") - val dis = DigestInputStream(BufferedInputStream(stream), md) - stream.use { - while (dis.read() != -1) { + DigestInputStream(BufferedInputStream(stream), md).use { dis -> + val buffer = ByteArray(BUFFER_SIZE) + while (dis.read(buffer) != -1) { + // Read until EOF } } return md.digest().toHex() diff --git a/src/main/kotlin/com/coder/toolbox/util/TLS.kt b/src/main/kotlin/com/coder/toolbox/util/TLS.kt index 101370d2..bca20619 100644 --- a/src/main/kotlin/com/coder/toolbox/util/TLS.kt +++ b/src/main/kotlin/com/coder/toolbox/util/TLS.kt @@ -284,16 +284,29 @@ class MergedSystemTrustManger(private val otherTrustManager: X509TrustManager) : class ReloadableX509TrustManager( private val caPath: String?, ) : X509TrustManager { + private var lastHash: String? = null + @Volatile private var delegate: X509TrustManager = loadTrustManager() private fun loadTrustManager(): X509TrustManager { + if (!caPath.isNullOrBlank()) { + lastHash = sha1(FileInputStream(expand(caPath))) + } val trustManagers = coderTrustManagers(caPath) return trustManagers.first { it is X509TrustManager } as X509TrustManager } - fun reload() { - delegate = loadTrustManager() + fun reload(): Boolean { + if (caPath.isNullOrBlank()) { + return false + } + val newHash = sha1(FileInputStream(expand(caPath))) + if (lastHash != newHash) { + delegate = loadTrustManager() + return true + } + return false } override fun checkClientTrusted(chain: Array?, authType: String?) { @@ -312,15 +325,31 @@ class ReloadableX509TrustManager( class ReloadableSSLSocketFactory( private val settings: ReadOnlyTLSSettings, ) : SSLSocketFactory() { + private var lastCertHash: String? = null + private var lastKeyHash: String? = null + @Volatile private var delegate: SSLSocketFactory = loadSocketFactory() private fun loadSocketFactory(): SSLSocketFactory { + if (!settings.certPath.isNullOrBlank() && !settings.keyPath.isNullOrBlank()) { + lastCertHash = sha1(FileInputStream(expand(settings.certPath!!))) + lastKeyHash = sha1(FileInputStream(expand(settings.keyPath!!))) + } return coderSocketFactory(settings) } - fun reload() { - delegate = loadSocketFactory() + fun reload(): Boolean { + if (settings.certPath.isNullOrBlank() || settings.keyPath.isNullOrBlank()) { + return false + } + val newCertHash = sha1(FileInputStream(expand(settings.certPath!!))) + val newKeyHash = sha1(FileInputStream(expand(settings.keyPath!!))) + if (lastCertHash != newCertHash || lastKeyHash != newKeyHash) { + delegate = loadSocketFactory() + return true + } + return false } override fun getDefaultCipherSuites(): Array = delegate.defaultCipherSuites @@ -349,8 +378,9 @@ class ReloadableTlsContext( val sslSocketFactory = ReloadableSSLSocketFactory(settings) val trustManager = ReloadableX509TrustManager(settings.caPath) - fun reload() { - sslSocketFactory.reload() - trustManager.reload() + fun reload(): Boolean { + val socketFactoryReloaded = sslSocketFactory.reload() + val trustManagerReloaded = trustManager.reload() + return socketFactoryReloaded || trustManagerReloaded } }