diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index f3d4ffd..3c06c40 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -40,24 +40,5 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - - name: Generate certificates - run: make test-cert - - name: Install MySQL client - run: sudo apt-get update && sudo apt-get install -y mariadb-client - - name: Start MariaDB Docker Container - run: docker compose up -d --build - - name: Wait For MariaDB Docker Container - run: | - for i in {1..30}; do - if docker compose exec -T mariadb mariadb-admin ping -u mariadb -pmariadb --silent; then - exit 0 - fi - sleep 1 - done - docker compose logs mariadb - exit 1 - - name: Run Tests - run: swift test --parallel --enable-code-coverage - - name: Stop MariaDB Docker Container - if: always() - run: docker compose down -v + - name: Run Docker Tests + run: make docker-test diff --git a/Makefile b/Makefile index dbd22a5..8bb4620 100644 --- a/Makefile +++ b/Makefile @@ -41,11 +41,11 @@ fix-headers: curl -s $(baseUrl)/check-swift-headers.sh | bash -s -- --fix -test-cert: +test-certs: rm -rf docker/mariadb/certificates && mkdir -p docker/mariadb/certificates && cd docker/mariadb/certificates && ../scripts/generate-certificates.sh test: swift test --parallel docker-test: - docker build -t feather-database-mysql-tests . -f ./docker/tests/Dockerfile && docker run --rm feather-database-mysql-tests + make test-certs && docker compose up --build --abort-on-container-exit --exit-code-from test test diff --git a/Package.resolved b/Package.resolved index 620e233..92f6067 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,13 +1,13 @@ { - "originHash" : "30e859335fb8af313bcef3aca20a88563f50a778387730a6a7bbd7400afa6838", + "originHash" : "c4fde9af4e1b9d476c7fa52079d42ab958d4df808a7b796426d2466c6661f027", "pins" : [ { "identity" : "feather-database", "kind" : "remoteSourceControl", "location" : "https://github.com/feather-framework/feather-database", "state" : { - "revision" : "55b08d7bd028c7eddbd231efaaaa0d38e78c0b9b", - "version" : "1.0.0-beta.5" + "revision" : "0b862ac3ee31859e2c8a54b06390ee92d1a7c0b4", + "version" : "1.0.0-rc.1" } }, { diff --git a/Package.swift b/Package.swift index 68b4676..8a7c636 100644 --- a/Package.swift +++ b/Package.swift @@ -32,21 +32,38 @@ let package = Package( .visionOS(.v2), ], products: [ + .library(name: "MySQLNIOExtras", targets: ["MySQLNIOExtras"]), .library(name: "FeatherDatabaseMySQL", targets: ["FeatherDatabaseMySQL"]), ], dependencies: [ .package(url: "https://github.com/apple/swift-log", from: "1.6.0"), .package(url: "https://github.com/vapor/mysql-nio", from: "1.8.0"), - .package(url: "https://github.com/feather-framework/feather-database", exact: "1.0.0-beta.5"), + .package(url: "https://github.com/feather-framework/feather-database", exact: "1.0.0-rc.1"), // [docc-plugin-placeholder] ], targets: [ + .target( + name: "MySQLNIOExtras", + dependencies: [ + .product(name: "Logging", package: "swift-log"), + .product(name: "MySQLNIO", package: "mysql-nio"), + ], + swiftSettings: defaultSwiftSettings + ), .target( name: "FeatherDatabaseMySQL", dependencies: [ .product(name: "Logging", package: "swift-log"), .product(name: "MySQLNIO", package: "mysql-nio"), .product(name: "FeatherDatabase", package: "feather-database"), + .target(name: "MySQLNIOExtras"), + ], + swiftSettings: defaultSwiftSettings + ), + .testTarget( + name: "MySQLNIOExtrasTests", + dependencies: [ + .target(name: "MySQLNIOExtras"), ], swiftSettings: defaultSwiftSettings ), diff --git a/README.md b/README.md index 487fdd7..781b505 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,9 @@ MySQL/MariaDB driver implementation for the abstract [Feather Database](https://github.com/feather-framework/feather-database) Swift API package. [ - ![Release: 1.0.0-beta.5](https://img.shields.io/badge/Release-1%2E0%2E0--beta%2E5-F05138) + ![Release: 1.0.0-rc.1](https://img.shields.io/badge/Release-1%2E0%2E0--rc%2E1-F05138) ]( - https://github.com/feather-framework/feather-database-mysql/releases/tag/1.0.0-beta.5 + https://github.com/feather-framework/feather-database-mysql/releases/tag/1.0.0-rc.1 ) ## Features @@ -36,7 +36,7 @@ MySQL/MariaDB driver implementation for the abstract [Feather Database](https:// Add the dependency to your `Package.swift`: ```swift -.package(url: "https://github.com/feather-framework/feather-database-mysql", exact: "1.0.0-beta.5"), +.package(url: "https://github.com/feather-framework/feather-database-mysql", exact: "1.0.0-rc.1"), ``` Then add `FeatherDatabaseMySQL` to your target dependencies: @@ -122,9 +122,6 @@ catch { } ``` -> [!WARNING] -> This repository is a work in progress, things can break until it reaches v1.0.0. - ## Other database drivers The following database client implementations are also available for use: @@ -144,4 +141,3 @@ The following database client implementations are also available for use: ## Contributing [Pull requests](https://github.com/feather-framework/feather-database-mysql/pulls) are welcome. Please keep changes focused and include tests for new logic. 🙏 - diff --git a/Sources/FeatherDatabaseMySQL/DatabaseClientMySQL.swift b/Sources/FeatherDatabaseMySQL/DatabaseClientMySQL.swift index 63d7c79..06a8a2e 100644 --- a/Sources/FeatherDatabaseMySQL/DatabaseClientMySQL.swift +++ b/Sources/FeatherDatabaseMySQL/DatabaseClientMySQL.swift @@ -2,12 +2,13 @@ // DatabaseClientMySQL.swift // feather-database-mysql // -// Created by Tibor Bödecs on 2026. 01. 10.. +// Created by Tibor Bödecs on 2026. 01. 10. // import FeatherDatabase import Logging import MySQLNIO +import MySQLNIOExtras /// A MySQL-backed database client. /// @@ -16,8 +17,13 @@ public struct DatabaseClientMySQL: DatabaseClient { public typealias Connection = DatabaseConnectionMySQL - var connection: DatabaseConnectionMySQL - var logger: Logger + private enum Storage { + case connection(DatabaseConnectionMySQL) + case client(MySQLClient) + } + + private let storage: Storage + private let logger: Logger /// Create a MySQL database client. /// @@ -29,16 +35,31 @@ public struct DatabaseClientMySQL: DatabaseClient { connection: MySQLConnection, logger: Logger ) { - self.connection = .init( - connection: connection, - logger: logger + self.storage = .connection( + .init( + connection: connection, + logger: logger + ) ) self.logger = logger } + /// Create a MySQL database client backed by a connection pool. + /// + /// - Parameters: + /// - client: The pooled MySQL client to use. + /// - logger: The logger for database operations. + public init( + client: MySQLClient, + logger: Logger + ) { + self.storage = .client(client) + self.logger = logger + } + // MARK: - database api - /// Execute work using the stored connection. + /// Execute work using the stored connection or pooled client. /// /// The closure is executed with the current connection. /// - Parameter closure: A closure that receives the MySQL connection. @@ -48,14 +69,27 @@ public struct DatabaseClientMySQL: DatabaseClient { public func withConnection( _ closure: (Connection) async throws -> T ) async throws(DatabaseError) -> T { - do { - return try await closure(connection) - } - catch let error as DatabaseError { - throw error - } - catch { - throw .connection(error) + switch storage { + case .connection(let connection): + return try await withConnection(connection, closure) + case .client(let client): + do { + return try await client.withConnection { connection in + try await withConnection( + DatabaseConnectionMySQL( + connection: connection, + logger: logger + ), + closure + ) + } + } + catch let error as DatabaseError { + throw error + } + catch { + throw .connection(error) + } } } @@ -69,7 +103,49 @@ public struct DatabaseClientMySQL: DatabaseClient { public func withTransaction( _ closure: (Connection) async throws -> T ) async throws(DatabaseError) -> T { + switch storage { + case .connection(let connection): + return try await withTransaction(connection, closure) + case .client(let client): + do { + return try await client.withConnection { connection in + try await withTransaction( + DatabaseConnectionMySQL( + connection: connection, + logger: logger + ), + closure + ) + } + } + catch let error as DatabaseError { + throw error + } + catch { + throw .connection(error) + } + } + } + private func withConnection( + _ connection: DatabaseConnectionMySQL, + _ closure: (Connection) async throws -> T + ) async throws(DatabaseError) -> T { + do { + return try await closure(connection) + } + catch let error as DatabaseError { + throw error + } + catch { + throw .connection(error) + } + } + + private func withTransaction( + _ connection: DatabaseConnectionMySQL, + _ closure: (Connection) async throws -> T + ) async throws(DatabaseError) -> T { do { try await connection.run(query: "START TRANSACTION;") { _ in } } @@ -118,5 +194,4 @@ public struct DatabaseClientMySQL: DatabaseClient { throw DatabaseError.transaction(txError) } } - } diff --git a/Sources/MySQLNIOExtras/MySQLClient.swift b/Sources/MySQLNIOExtras/MySQLClient.swift new file mode 100644 index 0000000..af2922b --- /dev/null +++ b/Sources/MySQLNIOExtras/MySQLClient.swift @@ -0,0 +1,141 @@ +// +// MySQLClient.swift +// feather-database-mysql +// +// Created by Binary Birds on 2026. 06. 04. +// + +import Logging +import MySQLNIO +import NIOCore +import NIOPosix +import NIOSSL + +/// A MySQL client backed by a lightweight connection pool. +/// +/// Use this client to execute queries and transactions concurrently. +public final class MySQLClient: Sendable { + + /// Configuration values for a pooled MySQL client. + public struct Configuration: Sendable { + /// The MySQL host name. + public let host: String + /// The MySQL port. + public let port: Int + /// The MySQL username. + public let username: String + /// The MySQL database name. + public let database: String + /// The MySQL password. + public let password: String? + /// TLS configuration used when opening connections. + public let tlsConfiguration: TLSConfiguration? + /// The server host name used for TLS verification. + public let serverHostname: String? + /// Minimum number of pooled connections to keep open. + public let minimumConnections: Int + /// Maximum number of pooled connections to allow. + public let maximumConnections: Int + /// Logger used for pool operations. + public let logger: Logger + /// Number of threads used by the pool's event loop group. + public let eventLoopThreads: Int + + /// Create a pooled MySQL client configuration. + /// - Parameters: + /// - host: The MySQL host name. + /// - port: The MySQL port. + /// - username: The username used when connecting. + /// - database: The database used when connecting. + /// - password: The password used when connecting. + /// - tlsConfiguration: The TLS configuration to use. + /// - serverHostname: The server host name used for TLS verification. + /// - logger: The logger for pool operations. + /// - minimumConnections: The minimum number of pooled connections. + /// - maximumConnections: The maximum number of pooled connections. + /// - eventLoopThreads: The number of event loop threads to use. + public init( + host: String, + port: Int, + username: String, + database: String, + password: String? = nil, + tlsConfiguration: TLSConfiguration? = .makeClientConfiguration(), + serverHostname: String? = nil, + logger: Logger, + minimumConnections: Int = 1, + maximumConnections: Int = System.coreCount, + eventLoopThreads: Int = 1 + ) { + precondition(port > 0) + precondition(minimumConnections >= 0) + precondition(maximumConnections >= 1) + precondition(minimumConnections <= maximumConnections) + precondition(eventLoopThreads >= 1) + + self.host = host + self.port = port + self.username = username + self.database = database + self.password = password + self.tlsConfiguration = tlsConfiguration + self.serverHostname = serverHostname + self.minimumConnections = minimumConnections + self.maximumConnections = maximumConnections + self.logger = logger + self.eventLoopThreads = eventLoopThreads + } + } + + private let pool: MySQLConnectionPool + + /// Create a MySQL client with a lightweight connection pool. + /// - Parameter configuration: The client configuration. + public init(configuration: Configuration) { + self.pool = MySQLConnectionPool(configuration: configuration) + } + + // MARK: - pool service + + /// Pre-open the minimum number of connections. + public func run() async throws { + try await pool.warmup() + } + + /// Close all pooled connections and refuse new leases. + public func shutdown() async { + await pool.shutdown() + } + + // MARK: - database api + + /// Execute work using a leased connection. + /// + /// The connection is returned to the pool when the closure completes. + /// - Parameter closure: A closure that receives a MySQL connection. + /// - Throws: A connection error if leasing or execution fails. + /// - Returns: The result produced by the closure. + @discardableResult + public func withConnection( + _ closure: (MySQLConnection) async throws -> T + ) async throws -> T { + let connection = try await leaseConnection() + do { + let result = try await closure(connection) + await pool.releaseConnection(connection) + return result + } + catch { + await pool.releaseConnection(connection) + throw error + } + } + + func connectionCount() async -> Int { + await pool.connectionCount() + } + + private func leaseConnection() async throws -> MySQLConnection { + try await pool.leaseConnection() + } +} diff --git a/Sources/MySQLNIOExtras/MySQLConnectionPool.swift b/Sources/MySQLNIOExtras/MySQLConnectionPool.swift new file mode 100644 index 0000000..52f2439 --- /dev/null +++ b/Sources/MySQLNIOExtras/MySQLConnectionPool.swift @@ -0,0 +1,259 @@ +// +// MySQLConnectionPool.swift +// feather-database-mysql +// +// Created by Binary Birds on 2026. 06. 04. +// + +import Logging +import MySQLNIO +import NIOCore +import NIOPosix +import NIOSSL + +actor MySQLConnectionPool { + + private struct Waiter { + let id: Int + let continuation: CheckedContinuation + } + + private let configuration: MySQLClient.Configuration + private let eventLoopGroup: MultiThreadedEventLoopGroup + private var availableConnections: [MySQLConnection] = [] + private var waiters: [Waiter] = [] + private var totalConnections = 0 + private var nextWaiterID = 0 + private var isShutdown = false + private var didShutdownEventLoopGroup = false + + init(configuration: MySQLClient.Configuration) { + self.configuration = configuration + self.eventLoopGroup = MultiThreadedEventLoopGroup( + numberOfThreads: configuration.eventLoopThreads + ) + } + + func warmup() async throws { + guard !isShutdown else { return } + let target = configuration.minimumConnections + guard totalConnections < target else { return } + let newConnections = target - totalConnections + + for _ in 0.. MySQLConnection { + guard !isShutdown else { + throw MySQLConnectionPoolError.shutdown + } + + while let connection = availableConnections.popLast() { + if await validateConnection(connection) { + return connection + } + + totalConnections = max(0, totalConnections - 1) + await closeConnection(connection) + } + + while totalConnections < configuration.maximumConnections { + totalConnections += 1 + do { + let connection = try await makeConnection() + if await validateConnection(connection) { + return connection + } + + totalConnections = max(0, totalConnections - 1) + await closeConnection(connection) + } + catch { + totalConnections -= 1 + throw error + } + } + + let waiterID = nextWaiterID + nextWaiterID += 1 + + return try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { continuation in + waiters.append( + Waiter(id: waiterID, continuation: continuation) + ) + } + } onCancel: { + Task { await self.cancelWaiter(id: waiterID) } + } + } + + func releaseConnection( + _ connection: MySQLConnection + ) async { + if isShutdown { + await closeConnection(connection) + totalConnections = max(0, totalConnections - 1) + await shutdownEventLoopGroupIfNeeded() + return + } + + if connection.isClosed { + totalConnections = max(0, totalConnections - 1) + await closeConnection(connection) + await replaceConnectionForWaitingCaller() + return + } + + if await validateConnection(connection) == false { + totalConnections = max(0, totalConnections - 1) + await closeConnection(connection) + await replaceConnectionForWaitingCaller() + return + } + + if waiters.isEmpty { + availableConnections.append(connection) + return + } + + let waiter = waiters.removeFirst() + waiter.continuation.resume(returning: connection) + } + + func shutdown() async { + guard !isShutdown else { return } + isShutdown = true + + let connections = availableConnections + availableConnections.removeAll(keepingCapacity: false) + + for connection in connections { + await closeConnection(connection) + totalConnections = max(0, totalConnections - 1) + } + + for waiter in waiters { + waiter.continuation.resume( + throwing: MySQLConnectionPoolError.shutdown + ) + } + waiters.removeAll(keepingCapacity: false) + + await shutdownEventLoopGroupIfNeeded() + } + + func connectionCount() -> Int { + totalConnections + } + + private func cancelWaiter( + id: Int + ) { + guard let index = waiters.firstIndex(where: { $0.id == id }) else { + return + } + let waiter = waiters.remove(at: index) + waiter.continuation.resume(throwing: CancellationError()) + } + + private func makeConnection() async throws -> MySQLConnection { + let connection = + try await MySQLConnection.connect( + to: try SocketAddress.makeAddressResolvingHost( + configuration.host, + port: configuration.port + ), + username: configuration.username, + database: configuration.database, + password: configuration.password, + tlsConfiguration: configuration.tlsConfiguration, + serverHostname: configuration.serverHostname, + logger: configuration.logger, + on: eventLoopGroup.next() + ) + .get() + + return connection + } + + private func validateConnection( + _ connection: MySQLConnection + ) async -> Bool { + guard !connection.isClosed else { + return false + } + + do { + _ = try await connection.query("SELECT 1;", []).get() + return true + } + catch { + return false + } + } + + private func replaceConnectionForWaitingCaller() async { + guard !waiters.isEmpty else { + return + } + + let waiter = waiters.removeFirst() + + do { + let replacement = try await makeConnection() + if await validateConnection(replacement) { + totalConnections += 1 + waiter.continuation.resume(returning: replacement) + return + } + + totalConnections = max(0, totalConnections - 1) + await closeConnection(replacement) + await replaceConnectionForWaitingCaller() + } + catch { + waiter.continuation.resume(throwing: error) + } + } + + private func closeConnection( + _ connection: MySQLConnection + ) async { + do { + try await connection.close().get() + } + catch { + configuration.logger.warning( + "Failed to close MySQL connection", + metadata: [ + "error": "\(error)" + ] + ) + } + } + + private func shutdownEventLoopGroupIfNeeded() async { + guard isShutdown, totalConnections == 0, !didShutdownEventLoopGroup + else { + return + } + + do { + try await eventLoopGroup.shutdownGracefully() + } + catch { + configuration.logger.warning( + "Failed to shutdown MySQL event loop group", + metadata: [ + "error": "\(error)" + ] + ) + } + didShutdownEventLoopGroup = true + } +} diff --git a/Sources/MySQLNIOExtras/MySQLConnectionPoolError.swift b/Sources/MySQLNIOExtras/MySQLConnectionPoolError.swift new file mode 100644 index 0000000..933ec1e --- /dev/null +++ b/Sources/MySQLNIOExtras/MySQLConnectionPoolError.swift @@ -0,0 +1,10 @@ +// +// MySQLConnectionPoolError.swift +// feather-database-mysql +// +// Created by Binary Birds on 2026. 06. 04. +// + +enum MySQLConnectionPoolError: Error { + case shutdown +} diff --git a/Tests/FeatherDatabaseMySQLTests/FeatherDatabaseMySQLTestSuite.swift b/Tests/FeatherDatabaseMySQLTests/FeatherDatabaseMySQLTestSuite.swift index e7e034a..1fe8d4e 100644 --- a/Tests/FeatherDatabaseMySQLTests/FeatherDatabaseMySQLTestSuite.swift +++ b/Tests/FeatherDatabaseMySQLTests/FeatherDatabaseMySQLTestSuite.swift @@ -2,14 +2,14 @@ // FeatherDatabaseMySQLTestSuite.swift // feather-database-mysql // -// Created by Tibor Bödecs on 2026. 01. 10.. +// Created by Tibor Bödecs on 2026. 01. 10. // import FeatherDatabase import Logging import MySQLNIO +import MySQLNIOExtras import NIOCore -import NIOPosix import NIOSSL import Testing @@ -23,26 +23,19 @@ import Foundation @Suite struct FeatherDatabaseMySQLTestSuite { - - func randomTableSuffix() -> String { - let characters = Array("abcdefghijklmnopqrstuvwxyz0123456789") - var suffix = "" - suffix.reserveCapacity(16) - for _ in 0..<16 { - suffix.append(characters.randomElement() ?? "a") - } - return suffix - } - - func runUsingTestDatabaseClient( - _ closure: ((DatabaseClientMySQL) async throws -> Void) - ) async throws { + static let sharedLogger: Logger = { var logger = Logger(label: "test") logger.logLevel = .info - - let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - - let finalCertPath = URL(fileURLWithPath: #filePath) + return logger + }() + + static let sharedPoolClient: MySQLClient = { + let environment = ProcessInfo.processInfo.environment + let finalCertPath = + environment["MYSQL_CA_CERT_PATH"] + ?? URL( + fileURLWithPath: #filePath + ) .deletingLastPathComponent() .deletingLastPathComponent() .deletingLastPathComponent() @@ -52,38 +45,71 @@ struct FeatherDatabaseMySQLTestSuite { .appendingPathComponent("ca.pem") .path() + let host = environment["MYSQL_HOST"] ?? "localhost" + let port = environment["MYSQL_PORT"].flatMap(Int.init) ?? 3306 + let password = environment["MYSQL_PASSWORD"] ?? "mariadb" + let database = environment["MYSQL_DATABASE"] ?? "mariadb" + var tlsConfig = TLSConfiguration.makeClientConfiguration() - let rootCert = try NIOSSLCertificate.fromPEMFile(finalCertPath) - tlsConfig.trustRoots = .certificates(rootCert) + let rootCertificates = loadRootCertificates(at: finalCertPath) + tlsConfig.trustRoots = .certificates(rootCertificates) tlsConfig.certificateVerification = .fullVerification - let connection = - try await MySQLConnection.connect( - to: try SocketAddress(ipAddress: "127.0.0.1", port: 3306), - username: "mariadb", - database: "mariadb", - password: "mariadb", + return MySQLClient( + configuration: .init( + host: host, + port: port, + username: "root", + database: database, + password: password, tlsConfiguration: tlsConfig, - logger: logger, - on: eventLoopGroup.next() + serverHostname: host, + logger: sharedLogger, + minimumConnections: 0, + maximumConnections: 4, + eventLoopThreads: 1 ) - .get() - - let database = DatabaseClientMySQL( - connection: connection, - logger: logger ) + }() + static func loadRootCertificates( + at path: String + ) -> [NIOSSLCertificate] { do { - try await closure(database) - - try await connection.close().get() - try await eventLoopGroup.shutdownGracefully() + return try NIOSSLCertificate.fromPEMFile(path) } catch { - try await connection.close().get() - try await eventLoopGroup.shutdownGracefully() + fatalError( + "Failed to load MySQL CA certificate at \(path): \(error)" + ) + } + } + func randomTableSuffix() -> String { + let characters = Array("abcdefghijklmnopqrstuvwxyz0123456789") + var suffix = "" + suffix.reserveCapacity(16) + for _ in 0..<16 { + suffix.append(characters.randomElement() ?? "a") + } + return suffix + } + + func runUsingTestDatabaseClient( + _ closure: ((DatabaseClientMySQL) async throws -> Void) + ) async throws { + let logger = Self.sharedLogger + let client = Self.sharedPoolClient + + do { + let database = DatabaseClientMySQL( + client: client, + logger: logger + ) + + try await closure(database) + } + catch { Issue.record(error) } } @@ -545,6 +571,342 @@ struct FeatherDatabaseMySQLTestSuite { } } + @Test + func boundOptionalInterpolationRoundTrip() async throws { + try await runUsingTestDatabaseClient { database in + try await database.withConnection { connection in + let boundString: String? = "alpha" + let missingString: String? = nil + let boundInt: Int? = 21 + let missingInt: Int? = nil + let boundFloat: Float? = 1.25 + let missingFloat: Float? = nil + let boundDouble: Double? = 3.75 + let missingDouble: Double? = nil + let boundBool: Bool? = true + let missingBool: Bool? = nil + + struct OptionalRow: Sendable { + let boundString: String? + let missingString: String? + let boundInt: Int? + let missingInt: Int? + let boundFloat: Double? + let missingFloat: Double? + let boundDouble: Double? + let missingDouble: Double? + let boundBool: Int? + let missingBool: Int? + + init(_ row: DatabaseRow) throws { + self.boundString = try row.decode( + column: "bound_string", + as: String?.self + ) + self.missingString = try row.decode( + column: "missing_string", + as: String?.self + ) + self.boundInt = try row.decode( + column: "bound_int", + as: Int?.self + ) + self.missingInt = try row.decode( + column: "missing_int", + as: Int?.self + ) + self.boundFloat = try row.decode( + column: "bound_float", + as: Double?.self + ) + self.missingFloat = try row.decode( + column: "missing_float", + as: Double?.self + ) + self.boundDouble = try row.decode( + column: "bound_double", + as: Double?.self + ) + self.missingDouble = try row.decode( + column: "missing_double", + as: Double?.self + ) + self.boundBool = try row.decode( + column: "bound_bool", + as: Int?.self + ) + self.missingBool = try row.decode( + column: "missing_bool", + as: Int?.self + ) + } + } + + let result = try await connection.run( + query: #""" + SELECT + \#(boundString) AS `bound_string`, + \#(missingString) AS `missing_string`, + \#(boundInt) AS `bound_int`, + \#(missingInt) AS `missing_int`, + \#(boundFloat) AS `bound_float`, + \#(missingFloat) AS `missing_float`, + \#(boundDouble) AS `bound_double`, + \#(missingDouble) AS `missing_double`, + \#(boundBool) AS `bound_bool`, + \#(missingBool) AS `missing_bool`; + """# + ) { try await $0.collect().map { try OptionalRow($0) } } + + #expect(result.count == 1) + #expect(result[0].boundString == "alpha") + #expect(result[0].missingString == nil) + #expect(result[0].boundInt == 21) + #expect(result[0].missingInt == nil) + #expect(result[0].boundFloat == 1.25) + #expect(result[0].missingFloat == nil) + #expect(result[0].boundDouble == 3.75) + #expect(result[0].missingDouble == nil) + #expect(result[0].boundBool == 1) + #expect(result[0].missingBool == nil) + } + } + } + + @Test + func unescapedOptionalInterpolationRoundTrip() async throws { + try await runUsingTestDatabaseClient { database in + try await database.withConnection { connection in + let rawString: String? = "beta" + let missingString: String? = nil + let rawInt: Int? = 7 + let missingInt: Int? = nil + let rawFloat: Float? = 2.5 + let missingFloat: Float? = nil + let rawDouble: Double? = 4.5 + let missingDouble: Double? = nil + let rawBool: Bool? = false + let missingBool: Bool? = nil + + struct RawOptionalRow: Sendable { + let rawString: String? + let missingString: String? + let rawInt: Int? + let missingInt: Int? + let rawFloat: Double? + let missingFloat: Double? + let rawDouble: Double? + let missingDouble: Double? + let rawBool: Int? + let missingBool: Int? + + init(_ row: DatabaseRow) throws { + self.rawString = try row.decode( + column: "raw_string", + as: String?.self + ) + self.missingString = try row.decode( + column: "missing_string", + as: String?.self + ) + self.rawInt = try row.decode( + column: "raw_int", + as: Int?.self + ) + self.missingInt = try row.decode( + column: "missing_int", + as: Int?.self + ) + self.rawFloat = try row.decode( + column: "raw_float", + as: Double?.self + ) + self.missingFloat = try row.decode( + column: "missing_float", + as: Double?.self + ) + self.rawDouble = try row.decode( + column: "raw_double", + as: Double?.self + ) + self.missingDouble = try row.decode( + column: "missing_double", + as: Double?.self + ) + self.rawBool = try row.decode( + column: "raw_bool", + as: Int?.self + ) + self.missingBool = try row.decode( + column: "missing_bool", + as: Int?.self + ) + } + } + + let result = try await connection.run( + query: #""" + SELECT + '\#(unescaped: rawString)' AS `raw_string`, + \#(unescaped: missingString) AS `missing_string`, + \#(unescaped: rawInt) AS `raw_int`, + \#(unescaped: missingInt) AS `missing_int`, + \#(unescaped: rawFloat) AS `raw_float`, + \#(unescaped: missingFloat) AS `missing_float`, + \#(unescaped: rawDouble) AS `raw_double`, + \#(unescaped: missingDouble) AS `missing_double`, + \#(unescaped: rawBool) AS `raw_bool`, + \#(unescaped: missingBool) AS `missing_bool`; + """# + ) { try await $0.collect().map { try RawOptionalRow($0) } } + + #expect(result.count == 1) + #expect(result[0].rawString == "beta") + #expect(result[0].missingString == nil) + #expect(result[0].rawInt == 7) + #expect(result[0].missingInt == nil) + #expect(result[0].rawFloat == 2.5) + #expect(result[0].missingFloat == nil) + #expect(result[0].rawDouble == 4.5) + #expect(result[0].missingDouble == nil) + #expect(result[0].rawBool == 0) + #expect(result[0].missingBool == nil) + } + } + } + + @Test + func arrayInterpolationRoundTrip() async throws { + try await runUsingTestDatabaseClient { database in + let suffix = randomTableSuffix() + let table = "array_samples_\(suffix)" + + try await database.withConnection { connection in + try await connection.run( + query: #""" + DROP TABLE IF EXISTS `\#(unescaped: table)`; + """# + ) + try await connection.run( + query: #""" + CREATE TABLE `\#(unescaped: table)` ( + `id` INTEGER NOT NULL PRIMARY KEY, + `name` TEXT NOT NULL, + `ratio` DOUBLE NOT NULL, + `score` DOUBLE NOT NULL, + `active` BOOLEAN NOT NULL + ); + """# + ) + + try await connection.run( + query: #""" + INSERT INTO `\#(unescaped: table)` + (`id`, `name`, `ratio`, `score`, `active`) + VALUES + (1, 'alpha', 1.5, 3.5, true), + (2, 'beta', 2.25, 4.75, false); + """# + ) + + let names = ["alpha", "omega"] + let ids = [1, 99] + let ratios: [Float] = [1.5, 9.5] + let scores = [3.5, 9.75] + let flags = [true, false] + + let result = try await connection.run( + query: #""" + SELECT + `id`, + `name` + FROM `\#(unescaped: table)` + WHERE + `name` IN (\#(names)) + AND `id` IN (\#(ids)) + AND `ratio` IN (\#(ratios)) + AND `score` IN (\#(scores)) + AND `active` IN (\#(flags)) + ORDER BY `id`; + """# + ) { try await $0.collect() } + + #expect(result.count == 1) + #expect(try result[0].decode(column: "id", as: Int.self) == 1) + #expect( + try result[0].decode(column: "name", as: String.self) + == "alpha" + ) + } + } + } + + @Test + func optionalArrayInterpolationRoundTrip() async throws { + try await runUsingTestDatabaseClient { database in + let suffix = randomTableSuffix() + let table = "optional_array_samples_\(suffix)" + + try await database.withConnection { connection in + try await connection.run( + query: #""" + DROP TABLE IF EXISTS `\#(unescaped: table)`; + """# + ) + try await connection.run( + query: #""" + CREATE TABLE `\#(unescaped: table)` ( + `id` INTEGER NOT NULL PRIMARY KEY, + `name` TEXT NOT NULL, + `ratio` DOUBLE NOT NULL, + `score` DOUBLE NOT NULL, + `active` BOOLEAN NOT NULL + ); + """# + ) + + try await connection.run( + query: #""" + INSERT INTO `\#(unescaped: table)` + (`id`, `name`, `ratio`, `score`, `active`) + VALUES + (1, 'alpha', 1.5, 3.5, true), + (2, 'beta', 2.25, 4.75, false); + """# + ) + + let names: [String?] = ["alpha", nil, "omega"] + let ids: [Int?] = [1, nil, 99] + let ratios: [Float?] = [1.5, nil, 9.5] + let scores: [Double?] = [3.5, nil, 9.75] + let flags: [Bool?] = [true, nil, false] + + let result = try await connection.run( + query: #""" + SELECT + `id`, + `name` + FROM `\#(unescaped: table)` + WHERE + `name` IN (\#(names)) + AND `id` IN (\#(ids)) + AND `ratio` IN (\#(ratios)) + AND `score` IN (\#(scores)) + AND `active` IN (\#(flags)) + ORDER BY `id`; + """# + ) { try await $0.collect() } + + #expect(result.count == 1) + #expect(try result[0].decode(column: "id", as: Int.self) == 1) + #expect( + try result[0].decode(column: "name", as: String.self) + == "alpha" + ) + } + } + } + @Test func booleanInterpolation() async throws { try await runUsingTestDatabaseClient { database in diff --git a/Tests/MySQLNIOExtrasTests/MySQLNIOExtrasTestSuite.swift b/Tests/MySQLNIOExtrasTests/MySQLNIOExtrasTestSuite.swift new file mode 100644 index 0000000..baad810 --- /dev/null +++ b/Tests/MySQLNIOExtrasTests/MySQLNIOExtrasTestSuite.swift @@ -0,0 +1,231 @@ +// +// MySQLNIOExtrasTestSuite.swift +// feather-database-mysql +// +// Created by Binary Birds on 2026. 06. 05. +// + +import Logging +import MySQLNIO +import NIOCore +import NIOSSL +import Testing + +@testable import MySQLNIOExtras + +#if canImport(FoundationEssentials) +import FoundationEssentials +#else +import Foundation +#endif + +@Suite +struct MySQLNIOExtrasTestSuite { + + actor Latch { + private var isSignaled = false + private var waiters: [CheckedContinuation] = [] + + func wait() async { + if isSignaled { + return + } + + await withCheckedContinuation { continuation in + if isSignaled { + continuation.resume() + return + } + + waiters.append(continuation) + } + } + + func signal() { + guard !isSignaled else { + return + } + + isSignaled = true + let pendingWaiters = waiters + waiters.removeAll(keepingCapacity: false) + for waiter in pendingWaiters { + waiter.resume() + } + } + } + + actor Flag { + private var value = false + + func set() { + value = true + } + + func get() -> Bool { + value + } + } + + static let sharedLogger: Logger = { + var logger = Logger(label: "test") + logger.logLevel = .info + return logger + }() + + static func loadRootCertificates(at path: String) -> [NIOSSLCertificate] { + do { + return try NIOSSLCertificate.fromPEMFile(path) + } + catch { + fatalError( + "Failed to load MySQL CA certificate at \(path): \(error)" + ) + } + } + + static func makeClient( + maximumConnections: Int + ) -> MySQLClient { + let environment = ProcessInfo.processInfo.environment + let finalCertPath = + environment["MYSQL_CA_CERT_PATH"] + ?? URL( + fileURLWithPath: #filePath + ) + .deletingLastPathComponent() + .deletingLastPathComponent() + .deletingLastPathComponent() + .appendingPathComponent("docker") + .appendingPathComponent("mariadb") + .appendingPathComponent("certificates") + .appendingPathComponent("ca.pem") + .path() + + let host = environment["MYSQL_HOST"] ?? "localhost" + let port = environment["MYSQL_PORT"].flatMap(Int.init) ?? 3306 + let password = environment["MYSQL_PASSWORD"] ?? "mariadb" + + var tlsConfig = TLSConfiguration.makeClientConfiguration() + tlsConfig.trustRoots = .certificates( + loadRootCertificates(at: finalCertPath) + ) + tlsConfig.certificateVerification = .fullVerification + + return MySQLClient( + configuration: .init( + host: host, + port: port, + username: "root", + database: environment["MYSQL_DATABASE"] ?? "mariadb", + password: password, + tlsConfiguration: tlsConfig, + serverHostname: host, + logger: sharedLogger, + minimumConnections: 0, + maximumConnections: maximumConnections, + eventLoopThreads: 1 + ) + ) + } + + func withClient( + maximumConnections: Int = 1, + _ closure: (MySQLClient) async throws -> T + ) async throws -> T { + let client = Self.makeClient(maximumConnections: maximumConnections) + try await client.run() + do { + let result = try await closure(client) + await client.shutdown() + return result + } + catch { + await client.shutdown() + throw error + } + } + + @Test + func closedConnectionIsReplacedForWaitingCaller() async throws { + try await withClient { client in + let firstAcquired = Latch() + let allowFirstToFinish = Latch() + + let first = Task { + try await client.withConnection { connection in + await firstAcquired.signal() + await allowFirstToFinish.wait() + try await connection.close().get() + } + } + + await firstAcquired.wait() + + let second = Task { + try await client.withConnection { connection in + _ = try await connection.query("SELECT 1;", []).get() + } + } + + await allowFirstToFinish.signal() + try await first.value + try await second.value + + #expect(await client.connectionCount() == 1) + } + } + + @Test + func cancelledWaiterDoesNotBlockQueue() async throws { + try await withClient { client in + let firstAcquired = Latch() + let allowFirstToFinish = Latch() + let secondStarted = Latch() + let didRunSecondClosure = Flag() + + let first = Task { + try await client.withConnection { connection in + await firstAcquired.signal() + await allowFirstToFinish.wait() + _ = try await connection.query("SELECT 1;", []).get() + } + } + + await firstAcquired.wait() + + let second = Task { + await secondStarted.signal() + return try await client.withConnection { _ in + await didRunSecondClosure.set() + } + } + + await secondStarted.wait() + try await Task.sleep(for: .milliseconds(50)) + second.cancel() + + await allowFirstToFinish.signal() + try await first.value + + do { + try await second.value + Issue.record("Expected the waiting task to be cancelled.") + } + catch is CancellationError { + // Expected. + } + catch { + Issue.record("Expected cancellation, got \(error).") + } + + #expect(await didRunSecondClosure.get() == false) + + try await client.withConnection { connection in + _ = try await connection.query("SELECT 1;", []).get() + } + + #expect(await client.connectionCount() == 1) + } + } +} diff --git a/docker-compose.yaml b/docker-compose.yaml index 746d43e..1131195 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -3,15 +3,44 @@ services: build: context: . dockerfile: docker/mariadb/Dockerfile + volumes: + - 'mariadb:/var/lib/mysql' + ports: + - '3306:3306' + restart: always environment: MYSQL_ROOT_PASSWORD: mariadb MYSQL_USER: mariadb MYSQL_PASSWORD: mariadb MYSQL_DATABASE: mariadb - ports: - - "3306:3306" + healthcheck: + test: + - CMD-SHELL + - mariadb-admin --protocol=tcp --host=127.0.0.1 --user=root --password=$${MYSQL_ROOT_PASSWORD} ping --silent + interval: 2s + timeout: 5s + retries: 30 + start_period: 20s + + test: + build: + context: . + dockerfile: docker/tests/Dockerfile + depends_on: + mariadb: + condition: service_healthy + environment: + MYSQL_HOST: mariadb + MYSQL_PORT: 3306 + MYSQL_ROOT_PASSWORD: mariadb + MYSQL_USER: mariadb + MYSQL_PASSWORD: mariadb + MYSQL_DATABASE: mariadb + MYSQL_CA_CERT_PATH: /app/docker/mariadb/certificates/ca.pem volumes: - - 'mariadb:/docker-entrypoint-initdb.d:ro' + - .:/app + working_dir: /app + command: ["swift", "test", "--parallel", "--enable-code-coverage"] volumes: mariadb: diff --git a/docker/mariadb/scripts/generate-certificates.sh b/docker/mariadb/scripts/generate-certificates.sh index 2fbb9af..8d43519 100755 --- a/docker/mariadb/scripts/generate-certificates.sh +++ b/docker/mariadb/scripts/generate-certificates.sh @@ -4,7 +4,7 @@ openssl genpkey -algorithm RSA -out ca.key # Generate CA certificate (PEM format) -openssl req -new -x509 -key ca.key -out ca.pem -days 365 -subj "/CN=PostgreSQL-CA" +openssl req -new -x509 -key ca.key -out ca.pem -days 365 -subj "/CN=MariaDB-CA" # Generate server private key openssl genpkey -algorithm RSA -out server.key @@ -12,8 +12,8 @@ openssl genpkey -algorithm RSA -out server.key # Create a CSR with correct CN (Common Name) & SAN (Subject Alternative Name) openssl req -new -key server.key -out server.csr -subj "/CN=localhost" -# SAN for localhost, db, and 127.0.0.1 -echo "subjectAltName=DNS:localhost,DNS:db,IP:127.0.0.1" > san.cnf +# SAN for localhost, mariadb, and 127.0.0.1 +echo "subjectAltName=DNS:localhost,DNS:mariadb,IP:127.0.0.1" > san.cnf # Sign the server certificate with CA & SAN openssl x509 -req -in server.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out server.pem -days 365 -extfile san.cnf