diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/bookkeeper/Bookkeeper.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/bookkeeper/Bookkeeper.scala index 79c08636..cbcbf95f 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/bookkeeper/Bookkeeper.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/bookkeeper/Bookkeeper.scala @@ -81,7 +81,6 @@ object Bookkeeper { val dbOpt = if (hasBookkeepingJdbc) { val jdbcConfig = bookkeepingConfig.bookkeepingJdbcConfig.get val syncDb = PramenDb(jdbcConfig) - syncDb.setupDatabase() Option(syncDb) } else None diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/bookkeeper/BookkeeperJdbc.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/bookkeeper/BookkeeperJdbc.scala index 48a91019..cada2d40 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/bookkeeper/BookkeeperJdbc.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/bookkeeper/BookkeeperJdbc.scala @@ -146,6 +146,7 @@ class BookkeeperJdbc(db: Database, profile: JdbcProfile, batchId: Long) extends val record = BookkeepingRecord(table, dateStr, dateStr, dateStr, inputRecordCount, outputRecordCount, recordsAppended, jobStarted, jobFinished, Option(batchId)) try { + SlickUtils.ensureDbConnected(db) db.run( BookkeepingRecords.records += record ).execute() @@ -230,6 +231,7 @@ class BookkeeperJdbc(db: Database, profile: JdbcProfile, batchId: Long) extends val infoDateStr = infoDate.toString try { + SlickUtils.ensureDbConnected(db) db.run( SchemaRecords.records.filter(t => t.pramenTableName === table && t.infoDate === infoDateStr).delete ).execute() diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/bookkeeper/OffsetManagerJdbc.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/bookkeeper/OffsetManagerJdbc.scala index 4cbbace4..685b72f4 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/bookkeeper/OffsetManagerJdbc.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/bookkeeper/OffsetManagerJdbc.scala @@ -73,6 +73,7 @@ class OffsetManagerJdbc(db: Database, batchId: Long) extends OffsetManager { val record = OffsetRecord(table, infoDate.toString, offsetType.dataTypeString, "", "", batchId, createdAt.toEpochMilli, None) + SlickUtils.ensureDbConnected(db) db.run( OffsetRecords.records += record ).execute() @@ -83,6 +84,7 @@ class OffsetManagerJdbc(db: Database, batchId: Long) extends OffsetManager { override def commitOffsets(request: DataOffsetRequest, minOffset: OffsetValue, maxOffset: OffsetValue): Unit = { val committedAt = Instant.now().toEpochMilli + SlickUtils.ensureDbConnected(db) db.run( OffsetRecords.records .filter(r => r.pramenTableName === request.tableName && r.infoDate === request.infoDate.toString && r.createdAt === request.createdAt.toEpochMilli) @@ -98,6 +100,7 @@ class OffsetManagerJdbc(db: Database, batchId: Long) extends OffsetManager { val committedAt = Instant.now().toEpochMilli + SlickUtils.ensureDbConnected(db) db.run( OffsetRecords.records .filter(r => r.pramenTableName === request.tableName && r.infoDate === request.infoDate.toString && r.createdAt === request.createdAt.toEpochMilli) @@ -121,6 +124,7 @@ class OffsetManagerJdbc(db: Database, batchId: Long) extends OffsetManager { OffsetRecord(req.table, req.infoDate.toString, req.minOffset.dataType.dataTypeString, req.minOffset.valueString, req.maxOffset.valueString, batchId, req.createdAt.toEpochMilli, Some(committedAtMilli)) } + SlickUtils.ensureDbConnected(db) db.run( OffsetRecords.records ++= records ).execute() @@ -137,6 +141,7 @@ class OffsetManagerJdbc(db: Database, batchId: Long) extends OffsetManager { } override def rollbackOffsets(request: DataOffsetRequest): Unit = { + SlickUtils.ensureDbConnected(db) db.run( OffsetRecords.records .filter(r => r.pramenTableName === request.tableName && r.infoDate === request.infoDate.toString && r.createdAt === request.createdAt.toEpochMilli) diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/journal/JournalJdbc.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/journal/JournalJdbc.scala index 80048d70..ea0b562c 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/journal/JournalJdbc.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/journal/JournalJdbc.scala @@ -60,6 +60,7 @@ class JournalJdbc(db: Database) extends Journal { Option(entry.batchId)) try { + SlickUtils.ensureDbConnected(db) db.run( JournalTasks.journalTasks += journalTask ).execute() diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/lock/TokenLockJdbc.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/lock/TokenLockJdbc.scala index 5ec3b7d9..e125993d 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/lock/TokenLockJdbc.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/lock/TokenLockJdbc.scala @@ -36,6 +36,8 @@ class TokenLockJdbc(token: String, db: Database) extends TokenLockBase(token) { /** Invoked from a synchronized block. */ override def tryAcquireGuardLock(retries: Int = 3, thisTry: Int = 0): Boolean = { + SlickUtils.ensureDbConnected(db) + def tryAcquireExistingTicket(): Boolean = { val ticket = getTicket diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/metadata/MetadataManagerJdbc.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/metadata/MetadataManagerJdbc.scala index 340e6890..a7dc90c2 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/metadata/MetadataManagerJdbc.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/metadata/MetadataManagerJdbc.scala @@ -57,6 +57,7 @@ class MetadataManagerJdbc(db: Database) extends MetadataManagerBase(true) { val record = MetadataRecord(tableName, infoDate.toString, key, metadata.value, metadata.lastUpdated.getEpochSecond) try { + SlickUtils.ensureDbConnected(db) db.run( MetadataRecords.records .filter(r => r.pramenTableName === tableName && r.infoDate === infoDate.toString && r.key === key) @@ -76,6 +77,7 @@ class MetadataManagerJdbc(db: Database) extends MetadataManagerBase(true) { .filter(r => r.pramenTableName === tableName && r.infoDate === infoDate.toString && r.key === key) try { + SlickUtils.ensureDbConnected(db) db.run(query.delete).execute() } catch { case NonFatal(ex) => throw new RuntimeException(s"Unable to delete from the metadata table.", ex) @@ -87,6 +89,7 @@ class MetadataManagerJdbc(db: Database) extends MetadataManagerBase(true) { .filter(r => r.pramenTableName === tableName && r.infoDate === infoDate.toString) try { + SlickUtils.ensureDbConnected(db) db.run(query.delete).execute() } catch { case NonFatal(ex) => throw new RuntimeException(s"Unable to delete from the metadata table.", ex) diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/rdb/PramenDb.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/rdb/PramenDb.scala index ffdd4ede..5f6a2857 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/rdb/PramenDb.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/rdb/PramenDb.scala @@ -18,7 +18,7 @@ package za.co.absa.pramen.core.rdb import org.slf4j.LoggerFactory import slick.jdbc.JdbcBackend.Database -import slick.jdbc.JdbcProfile +import slick.jdbc.{JdbcBackend, JdbcProfile} import slick.util.AsyncExecutor import za.co.absa.pramen.core.bookkeeper.model.{BookkeepingRecords, MetadataRecords, OffsetRecords, SchemaRecords} import za.co.absa.pramen.core.journal.model.JournalTasks @@ -26,6 +26,7 @@ import za.co.absa.pramen.core.lock.model.LockTickets import za.co.absa.pramen.core.rdb.PramenDb.MODEL_VERSION import za.co.absa.pramen.core.reader.JdbcUrlSelector import za.co.absa.pramen.core.reader.model.JdbcConfig +import za.co.absa.pramen.core.utils.{AlgorithmUtils, UsingUtils} import java.sql.Connection import scala.util.Try @@ -33,7 +34,6 @@ import scala.util.control.NonFatal class PramenDb(val jdbcConfig: JdbcConfig, val activeUrl: String, - val jdbcConnection: Connection, val slickDb: Database, val profile: JdbcProfile) extends AutoCloseable { def db: Database = slickDb @@ -43,22 +43,22 @@ class PramenDb(val jdbcConfig: JdbcConfig, private val log = LoggerFactory.getLogger(this.getClass) - val rdb: Rdb = new RdbJdbc(jdbcConnection) - - def setupDatabase(): Unit = { + private def setupDatabase(jdbcConnection: Connection): Unit = { // Explicitly set auto-commit to true, overriding any user JDBC settings or PostgreSQL defaults Try(jdbcConnection.setAutoCommit(true)).recover { case NonFatal(e) => log.warn(s"Unable to set autoCommit=true for the bookkeeping database that uses the driver: ${jdbcConfig.driver}.") } - val dbVersion = rdb.getVersion() - if (dbVersion < MODEL_VERSION) { - initDatabase(dbVersion) - rdb.setVersion(MODEL_VERSION) + UsingUtils.using(new RdbJdbc(jdbcConnection)) { rdb => + val dbVersion = rdb.getVersion() + if (dbVersion < MODEL_VERSION) { + initDatabase(dbVersion) + rdb.setVersion(MODEL_VERSION) + } } } - def initDatabase(dbVersion: Int): Unit = { + private def initDatabase(dbVersion: Int): Unit = { log.warn(s"Initializing new database at $activeUrl") if (dbVersion < 1) { initTable(LockTickets.lockTickets.schema) @@ -103,7 +103,7 @@ class PramenDb(val jdbcConfig: JdbcConfig, } } - def initTable(schema: profile.SchemaDescription): Unit = { + private def initTable(schema: profile.SchemaDescription): Unit = { try { db.run(DBIO.seq( schema.createIfNotExists @@ -115,7 +115,7 @@ class PramenDb(val jdbcConfig: JdbcConfig, } } - def addColumn(table: String, columnName: String, columnType: String): Unit = { + private def addColumn(table: String, columnName: String, columnType: String): Unit = { try { val quotedTable = s""""$table"""" val quotedColumnName = s""""$columnName"""" @@ -130,19 +130,33 @@ class PramenDb(val jdbcConfig: JdbcConfig, override def close(): Unit = { - jdbcConnection.close() - slickDb.close() + try { + slickDb.close() + } catch { + case NonFatal(ex) => + log.warn("Error closing the Pramen RDB database connection.", ex) + } + } } object PramenDb { + private val log = LoggerFactory.getLogger(this.getClass) + val MODEL_VERSION = 9 val DEFAULT_RETRIES = 3 + val BACKOFF_MIN_MS = 10000 + val BACKOFF_MAX_MS = 60000 def apply(jdbcConfig: JdbcConfig): PramenDb = { - val (url, conn, database, profile) = openDb(jdbcConfig) + val (url, connection) = getConnection(jdbcConfig) - new PramenDb(jdbcConfig, url, conn, database, profile) + UsingUtils.using(connection) { conn => + val (database, profile) = openDb(jdbcConfig, url) + val pramenDb = new PramenDb(jdbcConfig, url, database, profile) + pramenDb.setupDatabase(conn) + pramenDb + } } def getProfile(driver: String): JdbcProfile = { @@ -159,20 +173,30 @@ object PramenDb { } } - private def openDb(jdbcConfig: JdbcConfig): (String, Connection, Database, JdbcProfile) = { + def getConnection(jdbcConfig: JdbcConfig): (String, Connection) = { + val numberOfAttempts = jdbcConfig.retries.getOrElse(DEFAULT_RETRIES) + val selector = JdbcUrlSelector(jdbcConfig) + val (conn, url) = selector.getWorkingConnection(numberOfAttempts) + + (url, conn) + } + + def openDb(jdbcConfig: JdbcConfig, workingUrl: String): (Database, JdbcProfile) = { + val numberOfAttempts = jdbcConfig.retries.getOrElse(DEFAULT_RETRIES) val selector = JdbcUrlSelector(jdbcConfig) - val (conn, url) = selector.getWorkingConnection(DEFAULT_RETRIES) val prop = selector.getProperties val slickProfile = getProfile(jdbcConfig.driver) - val database = jdbcConfig.user match { - case Some(user) => Database.forURL(url = url, driver = jdbcConfig.driver, user = user, password = jdbcConfig.password.getOrElse(""), prop = prop, executor = AsyncExecutor("Rdb", 2, 10)) - case None => Database.forURL(url = url, driver = jdbcConfig.driver, prop = prop, executor = AsyncExecutor("Rdb", 2, 10)) + var database: JdbcBackend.DatabaseDef = null + AlgorithmUtils.actionWithRetry(numberOfAttempts, log, BACKOFF_MIN_MS, BACKOFF_MAX_MS) { + database = jdbcConfig.user match { + case Some(user) => Database.forURL(url = workingUrl, driver = jdbcConfig.driver, user = user, password = jdbcConfig.password.getOrElse(""), prop = prop, executor = AsyncExecutor("Rdb", 2, 10)) + case None => Database.forURL(url = workingUrl, driver = jdbcConfig.driver, prop = prop, executor = AsyncExecutor("Rdb", 2, 10)) + } } - (url, conn, database, slickProfile) + (database, slickProfile) } - } diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/rdb/RdbJdbc.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/rdb/RdbJdbc.scala index cc8ee676..6d0aa55e 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/rdb/RdbJdbc.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/rdb/RdbJdbc.scala @@ -17,16 +17,16 @@ package za.co.absa.pramen.core.rdb import org.slf4j.LoggerFactory +import za.co.absa.pramen.core.rdb.PramenDb.DEFAULT_RETRIES import za.co.absa.pramen.core.rdb.RdbJdbc.dbVersionTableName +import za.co.absa.pramen.core.reader.JdbcUrlSelector +import za.co.absa.pramen.core.reader.model.JdbcConfig +import za.co.absa.pramen.core.utils.UsingUtils import java.sql.{Connection, SQLException} import scala.util.control.NonFatal -object RdbJdbc { - val dbVersionTableName = "db_version" -} - -class RdbJdbc(connection: Connection) extends Rdb{ +class RdbJdbc(val connection: Connection) extends AutoCloseable with Rdb{ private val log = LoggerFactory.getLogger(this.getClass) override def getVersion(): Int = { @@ -61,9 +61,9 @@ class RdbJdbc(connection: Connection) extends Rdb{ } override def executeDDL(ddl: String): Unit = { - val statement = connection.createStatement() - statement.execute(ddl) - statement.close() + UsingUtils.using(connection.createStatement()) { statement => + statement.execute(ddl) + } } private def getDbVersion(): Int = { @@ -80,4 +80,17 @@ class RdbJdbc(connection: Connection) extends Rdb{ dbVersion } + override def close(): Unit = if (!connection.isClosed) connection.close() +} + +object RdbJdbc { + val dbVersionTableName = "db_version" + + def apply(jdbcConfig: JdbcConfig): RdbJdbc = { + val numberOfAttempts = jdbcConfig.retries.getOrElse(DEFAULT_RETRIES) + val selector = JdbcUrlSelector(jdbcConfig) + val (conn, _) = selector.getWorkingConnection(numberOfAttempts) + + new RdbJdbc(conn) + } } diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/reader/JdbcUrlSelectorImpl.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/reader/JdbcUrlSelectorImpl.scala index 5ac01aa2..b4b6445c 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/reader/JdbcUrlSelectorImpl.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/reader/JdbcUrlSelectorImpl.scala @@ -27,6 +27,9 @@ import scala.util.{Failure, Random, Success, Try} class JdbcUrlSelectorImpl(val jdbcConfig: JdbcConfig) extends JdbcUrlSelector{ private val log = LoggerFactory.getLogger(this.getClass) + + private val BACKOFF_MIN_S = 1 + private val BACKOFF_MAX_S = 10 private val allUrls = (jdbcConfig.primaryUrl ++ jdbcConfig.fallbackUrls).toSeq private val numberOfUrls = allUrls.size private var urlPool = allUrls @@ -102,7 +105,6 @@ class JdbcUrlSelectorImpl(val jdbcConfig: JdbcConfig) extends JdbcUrlSelector{ @throws[SQLException] def getWorkingConnection(retriesLeft: Int): (Connection, String) = { val currentUrl = getUrl - Try { JdbcNativeUtils.getJdbcConnection(jdbcConfig, currentUrl) } match { @@ -110,7 +112,9 @@ class JdbcUrlSelectorImpl(val jdbcConfig: JdbcConfig) extends JdbcUrlSelector{ case Failure(ex) => if (retriesLeft > 1) { val newUrl = getNextUrl - log.error(s"JDBC connection error for $currentUrl. Retries left: ${retriesLeft - 1}. Retrying...", ex) + val backoffS = Random.nextInt(BACKOFF_MAX_S - BACKOFF_MIN_S) + BACKOFF_MIN_S + log.error(s"JDBC connection error for $currentUrl. Retries left: ${retriesLeft - 1}. Retrying... in $backoffS seconds", ex) + Thread.sleep(backoffS * 1000) log.info(s"Trying URL: $newUrl") getWorkingConnection(retriesLeft - 1) } else { diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/AlgorithmUtils.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/AlgorithmUtils.scala index cf2ea6be..984fe8d4 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/AlgorithmUtils.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/AlgorithmUtils.scala @@ -21,6 +21,7 @@ import org.slf4j.Logger import java.time.{Duration, Instant} import scala.annotation.tailrec import scala.collection.mutable +import scala.util.Random object AlgorithmUtils { /** Finds which strings are encountered multiple times (case insensitive). */ @@ -70,7 +71,7 @@ object AlgorithmUtils { } @tailrec - final def actionWithRetry(attempts: Int, log: Logger)(action: => Unit): Unit = { + final def actionWithRetry(attempts: Int, log: Logger, backoffMinMs: Int = 0, backoffMaxMs: Int = 0)(action: => Unit): Unit = { def getErrorMessage(ex: Throwable): String = { if (ex.getCause == null) { ex.getMessage @@ -88,8 +89,16 @@ object AlgorithmUtils { if (attemptsLeft < 1) { throw ex } else { - log.warn(s"Attempt failed: ${getErrorMessage(ex)}. Attempts left: $attemptsLeft. Retrying...") - actionWithRetry(attemptsLeft, log)(action) + if (backoffMaxMs > backoffMinMs && backoffMinMs > 0) { + val backoffMs = Random.nextInt(backoffMaxMs - backoffMinMs) + backoffMinMs + val backoffS = backoffMs / 1000 + log.warn(s"Attempt failed: ${getErrorMessage(ex)}. Attempts left: $attemptsLeft. Retrying in $backoffS seconds...") + Thread.sleep(backoffMs) + } else { + log.warn(s"Attempt failed: ${getErrorMessage(ex)}. Attempts left: $attemptsLeft. Retrying...") + } + + actionWithRetry(attemptsLeft, log, backoffMinMs, backoffMaxMs)(action) } } } diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/FutureImplicits.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/FutureImplicits.scala index a74b1b65..baa9b14f 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/FutureImplicits.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/FutureImplicits.scala @@ -17,12 +17,11 @@ package za.co.absa.pramen.core.utils import java.util.concurrent.TimeUnit - -import scala.concurrent.duration.Duration +import scala.concurrent.duration.{Duration, FiniteDuration} import scala.concurrent.{Await, Future} object FutureImplicits { - private val executionTimeout = Duration(300, TimeUnit.SECONDS) + val executionTimeout: FiniteDuration = Duration(300, TimeUnit.SECONDS) implicit class FutureExecutor[T](future: Future[T]) { def execute(): T = Await.result(future, executionTimeout) diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/SlickUtils.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/SlickUtils.scala index 74e8701a..2f404bbf 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/SlickUtils.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/SlickUtils.scala @@ -20,6 +20,7 @@ import org.slf4j.LoggerFactory import slick.dbio.Effect import slick.jdbc.PostgresProfile.api._ import slick.sql.SqlAction +import za.co.absa.pramen.core.rdb.PramenDb import java.time.{Duration, Instant} import scala.util.control.NonFatal @@ -43,6 +44,8 @@ object SlickUtils { * @return The result of the query */ def executeQuery[E, U](db: Database, query: Query[E, U, Seq]): Seq[U] = { + ensureDbConnected(db) + val action = query.result val sql = action.statements.mkString("; ") @@ -75,10 +78,13 @@ object SlickUtils { * @return The result of the action */ def executeAction[R, E <: Effect](db: Database, action: SqlAction[R, NoStream, E]): R = { + ensureDbConnected(db) + val sql = action.statements.mkString("; ") try { val start = Instant.now + ensureDbConnected(db) val result = db.run(action).execute() val finish = Instant.now @@ -106,6 +112,8 @@ object SlickUtils { * @return The result of the query */ def executeCount(db: Database, rep: Rep[Int]): Int = { + ensureDbConnected(db) + val action = rep.result val sql = action.statements.mkString("; ") @@ -138,6 +146,8 @@ object SlickUtils { * @return The result of the query */ def executeMaxString(db: Database, rep: Rep[Option[String]]): Option[String] = { + ensureDbConnected(db) + val action = rep.result val sql = action.statements.mkString("; ") @@ -158,4 +168,23 @@ object SlickUtils { case NonFatal(ex) => throw new RuntimeException(s"Error executing an SQL query: $sql", ex) } } + + /** + * Ensures that the database connection is valid and ready for use. + * If the connection is not valid, an exception is thrown. + * The method retries the connection check according to the retry logic. + * + * @param db The database instance to verify the connection for. + */ + def ensureDbConnected(db: Database): Unit = { + val check = SimpleDBIO { ctx => + val conn = ctx.connection + if (!conn.isValid(FutureImplicits.executionTimeout.toSeconds.toInt)) + throw new RuntimeException("Connection not valid") + } + + AlgorithmUtils.actionWithRetry(PramenDb.DEFAULT_RETRIES, log, PramenDb.BACKOFF_MIN_MS, PramenDb.BACKOFF_MAX_MS) { + db.run(check).execute() + } + } } diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/UsingUtils.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/UsingUtils.scala new file mode 100644 index 00000000..09494e29 --- /dev/null +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/UsingUtils.scala @@ -0,0 +1,66 @@ +/* + * Copyright 2022 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.pramen.core.utils + +import scala.util.control.NonFatal + +object UsingUtils { + /** + * Executes the given action with a resource that implements the AutoCloseable interface, ensuring + * proper closure of the resource. Any exception that occurs during the action or resource closure + * is handled appropriately, with suppressed exceptions added where relevant. Null resources are not supported. + * + * @param resource a lazily evaluated resource that implements AutoCloseable + * @param action a function to be executed using the provided resource + * @tparam T the type of the resource, which must extend AutoCloseable + * @throws Throwable if either the action or resource closure fails. If both fail, the action's exception + * is thrown with the closure's exception added as suppressed + */ + def using[T <: AutoCloseable,U](resource: => T)(action: T => U): U = { + var thrownException: Option[Throwable] = None + var suppressedException: Option[Throwable] = None + val openedResource = resource + + if (openedResource == null) { + throw new IllegalArgumentException("Resource must not be null") + } + + val result = try { + Option(action(openedResource)) + } catch { + case NonFatal(ex) => + thrownException = Option(ex) + None + } finally + if (openedResource != null) { + try + openedResource.close() + catch { + case NonFatal(ex) => suppressedException = Option(ex) + } + } + + (thrownException, suppressedException) match { + case (Some(thrown), Some(suppressed)) => + thrown.addSuppressed(suppressed) + throw thrown + case (Some(thrown), None) => throw thrown + case (None, Some(suppressed)) => throw suppressed + case (None, None) => result.getOrElse(throw new IllegalArgumentException("Action returned null")) + } + } +} diff --git a/pramen/core/src/test/scala/za/co/absa/pramen/core/integration/IncrementalPipelineJdbcLongSuite.scala b/pramen/core/src/test/scala/za/co/absa/pramen/core/integration/IncrementalPipelineJdbcLongSuite.scala index 0435ddc6..ad63b5cb 100644 --- a/pramen/core/src/test/scala/za/co/absa/pramen/core/integration/IncrementalPipelineJdbcLongSuite.scala +++ b/pramen/core/src/test/scala/za/co/absa/pramen/core/integration/IncrementalPipelineJdbcLongSuite.scala @@ -23,12 +23,12 @@ import org.scalatest.wordspec.AnyWordSpec import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import za.co.absa.pramen.core.base.SparkTestBase import za.co.absa.pramen.core.fixtures.{RelationalDbFixture, TempDirFixture, TextComparisonFixture} -import za.co.absa.pramen.core.rdb.PramenDb +import za.co.absa.pramen.core.rdb.{PramenDb, RdbJdbc} import za.co.absa.pramen.core.reader.JdbcUrlSelectorImpl import za.co.absa.pramen.core.reader.model.JdbcConfig import za.co.absa.pramen.core.runner.AppRunner import za.co.absa.pramen.core.samples.RdbExampleTable -import za.co.absa.pramen.core.utils.{JdbcNativeUtils, ResourceUtils} +import za.co.absa.pramen.core.utils.{JdbcNativeUtils, ResourceUtils, UsingUtils} import java.sql.Date import java.time.LocalDate @@ -42,7 +42,7 @@ class IncrementalPipelineJdbcLongSuite extends AnyWordSpec with TextComparisonFixture { val jdbcConfig: JdbcConfig = JdbcConfig(driver, Some(url), Nil, None, Some(user), Some(password)) - lazy val pramenDb: PramenDb = PramenDb(jdbcConfig) + var pramenDb: PramenDb = _ private val infoDate = LocalDate.of(2021, 2, 18) @@ -50,13 +50,17 @@ class IncrementalPipelineJdbcLongSuite extends AnyWordSpec private val INFO_DATE_COLUMN = "pramen_info_date" before { - pramenDb.rdb.executeDDL("DROP SCHEMA PUBLIC CASCADE;") - pramenDb.setupDatabase() + if (pramenDb != null) pramenDb.close() + UsingUtils.using(RdbJdbc(jdbcConfig)) { rdb => + rdb.executeDDL("DROP SCHEMA PUBLIC CASCADE;") + } + pramenDb = PramenDb(jdbcConfig) + RdbExampleTable.IncrementalTable.initTable(getConnection) } override def afterAll(): Unit = { - pramenDb.close() + if (pramenDb != null) pramenDb.close() super.afterAll() } diff --git a/pramen/core/src/test/scala/za/co/absa/pramen/core/integration/IncrementalPipelineLongFixture.scala b/pramen/core/src/test/scala/za/co/absa/pramen/core/integration/IncrementalPipelineLongFixture.scala index 0c47d24b..a28dc58a 100644 --- a/pramen/core/src/test/scala/za/co/absa/pramen/core/integration/IncrementalPipelineLongFixture.scala +++ b/pramen/core/src/test/scala/za/co/absa/pramen/core/integration/IncrementalPipelineLongFixture.scala @@ -28,11 +28,11 @@ import za.co.absa.pramen.api.offset.OffsetType import za.co.absa.pramen.core.base.SparkTestBase import za.co.absa.pramen.core.bookkeeper.OffsetManagerJdbc import za.co.absa.pramen.core.fixtures.{RelationalDbFixture, TempDirFixture, TextComparisonFixture} -import za.co.absa.pramen.core.rdb.PramenDb +import za.co.absa.pramen.core.rdb.{PramenDb, RdbJdbc} import za.co.absa.pramen.core.reader.JdbcUrlSelectorImpl import za.co.absa.pramen.core.reader.model.JdbcConfig import za.co.absa.pramen.core.runner.AppRunner -import za.co.absa.pramen.core.utils.{FsUtils, JdbcNativeUtils, ResourceUtils} +import za.co.absa.pramen.core.utils.{FsUtils, JdbcNativeUtils, ResourceUtils, UsingUtils} import java.sql.Date import java.time.LocalDate @@ -46,15 +46,18 @@ class IncrementalPipelineLongFixture extends AnyWordSpec with TextComparisonFixture { val jdbcConfig: JdbcConfig = JdbcConfig(driver, Some(url), Nil, None, Some(user), Some(password)) - lazy val pramenDb: PramenDb = PramenDb(jdbcConfig) + var pramenDb: PramenDb = _ before { - pramenDb.rdb.executeDDL("DROP SCHEMA PUBLIC CASCADE;") - pramenDb.setupDatabase() + if (pramenDb != null) pramenDb.close() + UsingUtils.using(RdbJdbc(jdbcConfig)) { rdb => + rdb.executeDDL("DROP SCHEMA PUBLIC CASCADE;") + } + pramenDb = PramenDb(jdbcConfig) } override def afterAll(): Unit = { - pramenDb.close() + if (pramenDb != null) pramenDb.close() super.afterAll() } diff --git a/pramen/core/src/test/scala/za/co/absa/pramen/core/metadata/MetadataManagerJdbcSuite.scala b/pramen/core/src/test/scala/za/co/absa/pramen/core/metadata/MetadataManagerJdbcSuite.scala index 213273ce..fb475bf2 100644 --- a/pramen/core/src/test/scala/za/co/absa/pramen/core/metadata/MetadataManagerJdbcSuite.scala +++ b/pramen/core/src/test/scala/za/co/absa/pramen/core/metadata/MetadataManagerJdbcSuite.scala @@ -20,24 +20,28 @@ import org.scalatest.wordspec.AnyWordSpec import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import za.co.absa.pramen.api.MetadataValue import za.co.absa.pramen.core.fixtures.RelationalDbFixture -import za.co.absa.pramen.core.rdb.PramenDb +import za.co.absa.pramen.core.rdb.{PramenDb, RdbJdbc} import za.co.absa.pramen.core.reader.model.JdbcConfig +import za.co.absa.pramen.core.utils.UsingUtils import java.time.{LocalDate, ZoneOffset} class MetadataManagerJdbcSuite extends AnyWordSpec with RelationalDbFixture with BeforeAndAfter with BeforeAndAfterAll { val jdbcConfig: JdbcConfig = JdbcConfig(driver, Some(url), Nil, None, Some(user), Some(password)) - lazy val pramenDb: PramenDb = PramenDb(jdbcConfig) + var pramenDb: PramenDb = _ private val infoDate = LocalDate.of(2021, 2, 18) private val exampleInstant = infoDate.atStartOfDay().toInstant(ZoneOffset.UTC) before { - pramenDb.rdb.executeDDL("DROP SCHEMA PUBLIC CASCADE;") - pramenDb.setupDatabase() + if (pramenDb != null) pramenDb.close() + UsingUtils.using(RdbJdbc(jdbcConfig)) { rdb => + rdb.executeDDL("DROP SCHEMA PUBLIC CASCADE;") + } + pramenDb = PramenDb(jdbcConfig) } override def afterAll(): Unit = { - pramenDb.close() + if (pramenDb != null) pramenDb.close() super.afterAll() } @@ -164,7 +168,7 @@ class MetadataManagerJdbcSuite extends AnyWordSpec with RelationalDbFixture with assert(ex.getMessage.contains("Unable to delete from the metadata table.")) } - "throw an exception on connection errors when deleting metadata from a partision" in { + "throw an exception on connection errors when deleting metadata from a partition" in { val metadata = new MetadataManagerJdbc(null) val ex = intercept[RuntimeException] { diff --git a/pramen/core/src/test/scala/za/co/absa/pramen/core/mocks/AutoCloseableSpy.scala b/pramen/core/src/test/scala/za/co/absa/pramen/core/mocks/AutoCloseableSpy.scala new file mode 100644 index 00000000..57e03a6a --- /dev/null +++ b/pramen/core/src/test/scala/za/co/absa/pramen/core/mocks/AutoCloseableSpy.scala @@ -0,0 +1,40 @@ +/* + * Copyright 2022 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.pramen.core.mocks + +class AutoCloseableSpy(failCreate: Boolean = false, failAction: Boolean = false, failClose: Boolean = false) extends AutoCloseable { + var actionCallCount: Int = 0 + var closeCallCount: Int = 0 + + if (failCreate) { + throw new RuntimeException("Failed to create resource") + } + + def dummyAction(): Unit = { + actionCallCount += 1 + if (failAction) { + throw new RuntimeException("Failed during action") + } + } + + override def close(): Unit = { + closeCallCount += 1 + if (failClose) { + throw new RuntimeException("Failed to close resource") + } + } +} diff --git a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/bookkeeper/BookkeeperJdbcSuite.scala b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/bookkeeper/BookkeeperJdbcSuite.scala index f06672d9..b0ae67c8 100644 --- a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/bookkeeper/BookkeeperJdbcSuite.scala +++ b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/bookkeeper/BookkeeperJdbcSuite.scala @@ -19,21 +19,25 @@ package za.co.absa.pramen.core.tests.bookkeeper import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import za.co.absa.pramen.core.bookkeeper.{Bookkeeper, BookkeeperJdbc} import za.co.absa.pramen.core.fixtures.RelationalDbFixture -import za.co.absa.pramen.core.rdb.PramenDb +import za.co.absa.pramen.core.rdb.{PramenDb, RdbJdbc} import za.co.absa.pramen.core.reader.model.JdbcConfig +import za.co.absa.pramen.core.utils.UsingUtils class BookkeeperJdbcSuite extends BookkeeperCommonSuite with RelationalDbFixture with BeforeAndAfter with BeforeAndAfterAll { val jdbcConfig: JdbcConfig = JdbcConfig(driver, Some(url), Nil, None, Some(user), Some(password)) - lazy val pramenDb: PramenDb = PramenDb(jdbcConfig) + var pramenDb: PramenDb = _ before { - pramenDb.rdb.executeDDL("DROP SCHEMA PUBLIC CASCADE;") - pramenDb.setupDatabase() + if (pramenDb != null) pramenDb.close() + UsingUtils.using(RdbJdbc(jdbcConfig)) { rdb => + rdb.executeDDL("DROP SCHEMA PUBLIC CASCADE;") + } + pramenDb = PramenDb(jdbcConfig) } override def afterAll(): Unit = { - pramenDb.close() + if (pramenDb != null) pramenDb.close() super.afterAll() } diff --git a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/bookkeeper/BookkeeperSuite.scala b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/bookkeeper/BookkeeperSuite.scala index 41adb348..e9c2fa31 100644 --- a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/bookkeeper/BookkeeperSuite.scala +++ b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/bookkeeper/BookkeeperSuite.scala @@ -26,8 +26,9 @@ import za.co.absa.pramen.core.fixtures.{MongoDbFixture, RelationalDbFixture, Tem import za.co.absa.pramen.core.journal._ import za.co.absa.pramen.core.lock.{TokenLockFactoryAllow, TokenLockFactoryHadoopPath, TokenLockFactoryJdbc, TokenLockFactoryMongoDb} import za.co.absa.pramen.core.metadata.{MetadataManagerJdbc, MetadataManagerNull} -import za.co.absa.pramen.core.rdb.PramenDb +import za.co.absa.pramen.core.rdb.{PramenDb, RdbJdbc} import za.co.absa.pramen.core.reader.model.JdbcConfig +import za.co.absa.pramen.core.utils.UsingUtils import za.co.absa.pramen.core.{BookkeepingConfigFactory, RuntimeConfigFactory} import java.nio.file.Paths @@ -43,11 +44,14 @@ class BookkeeperSuite extends AnyWordSpec import za.co.absa.pramen.core.bookkeeper.BookkeeperMongoDb._ val jdbcConfig: JdbcConfig = JdbcConfig(driver, Some(url), Nil, None, Option(user), Option(password)) - lazy val pramenDb: PramenDb = PramenDb(jdbcConfig) + var pramenDb: PramenDb = _ before { - pramenDb.rdb.executeDDL("DROP SCHEMA PUBLIC CASCADE;") - pramenDb.setupDatabase() + if (pramenDb != null) pramenDb.close() + UsingUtils.using(RdbJdbc(jdbcConfig)) { rdb => + rdb.executeDDL("DROP SCHEMA PUBLIC CASCADE;") + } + pramenDb = PramenDb(jdbcConfig) if (db != null) { if (db.doesCollectionExists(collectionName)) { @@ -59,6 +63,10 @@ class BookkeeperSuite extends AnyWordSpec } } + override def afterAll(): Unit = { + if (pramenDb != null) pramenDb.close() + super.afterAll() + } val runtimeConfig: RuntimeConfig = RuntimeConfigFactory.getDummyRuntimeConfig( useLocks = true diff --git a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/bookkeeper/OffsetManagerJdbcSuite.scala b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/bookkeeper/OffsetManagerJdbcSuite.scala index 0ec05cca..7dc0e0c8 100644 --- a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/bookkeeper/OffsetManagerJdbcSuite.scala +++ b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/bookkeeper/OffsetManagerJdbcSuite.scala @@ -22,24 +22,28 @@ import za.co.absa.pramen.api.offset.DataOffset.{CommittedOffset, UncommittedOffs import za.co.absa.pramen.api.offset.{OffsetType, OffsetValue} import za.co.absa.pramen.core.bookkeeper.{OffsetManager, OffsetManagerCached, OffsetManagerJdbc} import za.co.absa.pramen.core.fixtures.RelationalDbFixture -import za.co.absa.pramen.core.rdb.PramenDb +import za.co.absa.pramen.core.rdb.{PramenDb, RdbJdbc} import za.co.absa.pramen.core.reader.model.JdbcConfig +import za.co.absa.pramen.core.utils.UsingUtils import java.time.{Instant, LocalDate} class OffsetManagerJdbcSuite extends AnyWordSpec with RelationalDbFixture with BeforeAndAfter with BeforeAndAfterAll { val jdbcConfig: JdbcConfig = JdbcConfig(driver, Some(url), Nil, None, Some(user), Some(password)) - lazy val pramenDb: PramenDb = PramenDb(jdbcConfig) + var pramenDb: PramenDb = _ private val infoDate = LocalDate.of(2023, 8, 25) before { - pramenDb.rdb.executeDDL("DROP SCHEMA PUBLIC CASCADE;") - pramenDb.setupDatabase() + if (pramenDb != null) pramenDb.close() + UsingUtils.using(RdbJdbc(jdbcConfig)) { rdb => + rdb.executeDDL("DROP SCHEMA PUBLIC CASCADE;") + } + pramenDb = PramenDb(jdbcConfig) } override def afterAll(): Unit = { - pramenDb.close() + if (pramenDb != null) pramenDb.close() super.afterAll() } diff --git a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/journal/JournalJdbcSuite.scala b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/journal/JournalJdbcSuite.scala index a2333b5d..a74510c5 100644 --- a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/journal/JournalJdbcSuite.scala +++ b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/journal/JournalJdbcSuite.scala @@ -21,22 +21,26 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import za.co.absa.pramen.core.base.SparkTestBase import za.co.absa.pramen.core.fixtures.RelationalDbFixture import za.co.absa.pramen.core.journal.{Journal, JournalJdbc} -import za.co.absa.pramen.core.rdb.PramenDb +import za.co.absa.pramen.core.rdb.{PramenDb, RdbJdbc} import za.co.absa.pramen.core.reader.model.JdbcConfig +import za.co.absa.pramen.core.utils.UsingUtils class JournalJdbcSuite extends AnyWordSpec with SparkTestBase with BeforeAndAfter with BeforeAndAfterAll with RelationalDbFixture { import TestCases._ val jdbcConfig: JdbcConfig = JdbcConfig(driver, Some(url), Nil, None, Option(user), Option(password)) - lazy val pramenDb: PramenDb = PramenDb(jdbcConfig) + var pramenDb: PramenDb = _ before { - pramenDb.rdb.executeDDL("DROP SCHEMA PUBLIC CASCADE;") - pramenDb.setupDatabase() + if (pramenDb != null) pramenDb.close() + UsingUtils.using(RdbJdbc(jdbcConfig)) { rdb => + rdb.executeDDL("DROP SCHEMA PUBLIC CASCADE;") + } + pramenDb = PramenDb(jdbcConfig) } override def afterAll(): Unit = { - pramenDb.close() + if (pramenDb != null) pramenDb.close() super.afterAll() } diff --git a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/lock/TokenLockJdbcSuite.scala b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/lock/TokenLockJdbcSuite.scala index 62deb8c7..a23903e9 100644 --- a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/lock/TokenLockJdbcSuite.scala +++ b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/lock/TokenLockJdbcSuite.scala @@ -22,22 +22,26 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import za.co.absa.pramen.api.lock.TokenLock import za.co.absa.pramen.core.fixtures.RelationalDbFixture import za.co.absa.pramen.core.lock.TokenLockJdbc -import za.co.absa.pramen.core.rdb.PramenDb +import za.co.absa.pramen.core.rdb.{PramenDb, RdbJdbc} import za.co.absa.pramen.core.reader.model.JdbcConfig +import za.co.absa.pramen.core.utils.UsingUtils import scala.concurrent.duration._ class TokenLockJdbcSuite extends AnyWordSpec with RelationalDbFixture with BeforeAndAfter with BeforeAndAfterAll { val jdbcConfig: JdbcConfig = JdbcConfig(driver, Some(url), Nil, None, Option(user), Option(password)) - lazy val pramenDb: PramenDb = PramenDb(jdbcConfig) + var pramenDb: PramenDb = _ before { - pramenDb.rdb.executeDDL("DROP SCHEMA PUBLIC CASCADE;") - pramenDb.setupDatabase() + if (pramenDb != null) pramenDb.close() + UsingUtils.using(RdbJdbc(jdbcConfig)) { rdb => + rdb.executeDDL("DROP SCHEMA PUBLIC CASCADE;") + } + pramenDb = PramenDb(jdbcConfig) } override def afterAll(): Unit = { - pramenDb.close() + if (pramenDb != null) pramenDb.close() super.afterAll() } diff --git a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/UsingUtilsSuite.scala b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/UsingUtilsSuite.scala new file mode 100644 index 00000000..ce536bbc --- /dev/null +++ b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/UsingUtilsSuite.scala @@ -0,0 +1,245 @@ +/* + * Copyright 2022 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.pramen.core.tests.utils + +import org.scalatest.wordspec.AnyWordSpec +import za.co.absa.pramen.core.mocks.AutoCloseableSpy +import za.co.absa.pramen.core.utils.UsingUtils + +class UsingUtilsSuite extends AnyWordSpec { + "using with a single resource" should { + "properly close the resource" in { + var resource: AutoCloseableSpy = null + + UsingUtils.using(new AutoCloseableSpy()) { res => + resource = res + res.dummyAction() + } + + assert(resource.actionCallCount == 1) + assert(resource.closeCallCount == 1) + } + + "close resource even if exception occurs" in { + var resource: AutoCloseableSpy = null + var exceptionThrown = false + + try { + UsingUtils.using(new AutoCloseableSpy(failAction = true)) { res => + resource = res + res.dummyAction() + } + } catch { + case ex: Throwable => + assert(ex.getMessage.contains("Failed during action")) + exceptionThrown = true + } + + assert(exceptionThrown) + assert(resource.actionCallCount == 1) + assert(resource.closeCallCount == 1) + } + + "throw an exception on null resource" in { + var exceptionThrown = false + + try { + UsingUtils.using(null: AutoCloseableSpy) { res => + res.dummyAction() + } + } catch { + case ex: Throwable => + assert(ex.getMessage.contains("Resource must not be null")) + exceptionThrown = true + } + + assert(exceptionThrown) + } + + "handle exceptions when a resource is created" in { + var exceptionThrown = false + var resource: AutoCloseableSpy = null + + try { + UsingUtils.using(new AutoCloseableSpy(failCreate = true)) { res => + resource = res + res.dummyAction() + } + } catch { + case ex: Throwable => + exceptionThrown = true + assert(ex.getMessage.contains("Failed to create resource")) + } + + assert(exceptionThrown) + assert(resource == null) + } + + "handle exceptions when a resource is closed" in { + var resource: AutoCloseableSpy = null + var exceptionThrown = false + + try { + UsingUtils.using(new AutoCloseableSpy(failClose = true)) { res => + resource = res + res.dummyAction() + } + } catch { + case ex: Throwable => + exceptionThrown = true + assert(ex.getMessage.contains("Failed to close resource")) + } + + assert(exceptionThrown) + assert(resource.actionCallCount == 1) + assert(resource.closeCallCount == 1) + } + + "handle exceptions on both action and close" in { + var resource: AutoCloseableSpy = null + var exceptionThrown = false + + try { + UsingUtils.using(new AutoCloseableSpy(failClose = true)) { res => + resource = res + res.dummyAction() + throw new RuntimeException("Failed during action") + } + } catch { + case ex: Throwable => + exceptionThrown = true + assert(ex.getMessage.contains("Failed during action")) + val suppressed = ex.getSuppressed + assert(suppressed.length == 1) + assert(suppressed(0).getMessage.contains("Failed to close resource")) + } + + assert(exceptionThrown) + assert(resource.actionCallCount == 1) + assert(resource.closeCallCount == 1) + } + } + + "using with two resources" should { + "properly close both resources" in { + var resource1: AutoCloseableSpy = null + var resource2: AutoCloseableSpy = null + + val result = UsingUtils.using(new AutoCloseableSpy()) { res1 => + resource1 = res1 + UsingUtils.using(new AutoCloseableSpy()) { res2 => + resource2 = res2 + res1.dummyAction() + res2.dummyAction() + 100 + } + } + + assert(result == 100) + assert(resource1.actionCallCount == 1) + assert(resource1.closeCallCount == 1) + assert(resource2.actionCallCount == 1) + assert(resource2.closeCallCount == 1) + } + + "properly close both resources when an inner one throws an exception during action and close" in { + var resource1: AutoCloseableSpy = null + var resource2: AutoCloseableSpy = null + var exceptionThrown = false + + try { + UsingUtils.using(new AutoCloseableSpy()) { res1 => + resource1 = res1 + UsingUtils.using(new AutoCloseableSpy(failAction = true, failClose = true)) { res2 => + resource2 = res2 + res1.dummyAction() + res2.dummyAction() + } + } + } catch { + case ex: Throwable => + exceptionThrown = true + assert(ex.getMessage.contains("Failed during action")) + val suppressed = ex.getSuppressed + assert(suppressed.length == 1) + assert(suppressed(0).getMessage.contains("Failed to close resource")) + } + + assert(exceptionThrown) + assert(resource1.actionCallCount == 1) + assert(resource1.closeCallCount == 1) + assert(resource2.actionCallCount == 1) + assert(resource2.closeCallCount == 1) + } + + "properly close both resources when an outer one throws an exception during action and close" in { + var resource1: AutoCloseableSpy = null + var resource2: AutoCloseableSpy = null + var exceptionThrown = false + + try { + UsingUtils.using(new AutoCloseableSpy(failAction = true, failClose = true)) { res1 => + resource1 = res1 + UsingUtils.using(new AutoCloseableSpy()) { res2 => + resource2 = res2 + res1.dummyAction() + res2.dummyAction() + } + } + } catch { + case ex: Throwable => + exceptionThrown = true + assert(ex.getMessage.contains("Failed during action")) + val suppressed = ex.getSuppressed + assert(suppressed.length == 1) + assert(suppressed(0).getMessage.contains("Failed to close resource")) + } + + assert(exceptionThrown) + assert(resource1.actionCallCount == 1) + assert(resource1.closeCallCount == 1) + assert(resource2.actionCallCount == 0) + assert(resource2.closeCallCount == 1) + } + + "properly close the outer resource when the inner one fails on create" in { + var resource1: AutoCloseableSpy = null + var resource2: AutoCloseableSpy = null + var exceptionThrown = false + + try { + UsingUtils.using(new AutoCloseableSpy()) { res1 => + resource1 = res1 + UsingUtils.using(new AutoCloseableSpy(failCreate = true)) { res2 => + resource2 = res2 + res1.dummyAction() + res2.dummyAction() + } + } + } catch { + case ex: Throwable => + exceptionThrown = true + assert(ex.getMessage.contains("Failed to create resource")) + } + + assert(exceptionThrown) + assert(resource1.actionCallCount == 0) + assert(resource1.closeCallCount == 1) + assert(resource2 == null) + } + } +}