Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions Plugins/PostgreSQLDriverPlugin/CockroachPluginDriver.swift
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,11 @@ final class CockroachPluginDriver: LibPQBackedDriver, @unchecked Sendable {
}

func createDatabase(_ request: PluginCreateDatabaseRequest) async throws {
let quotedName = request.name.replacingOccurrences(of: "\"", with: "\"\"")
_ = try await execute(query: "CREATE DATABASE \"\(quotedName)\"")
_ = try await execute(query: "CREATE DATABASE \(quoteIdentifier(request.name))")
}

func dropDatabase(name: String) async throws {
let quotedName = name.replacingOccurrences(of: "\"", with: "\"\"")
_ = try await execute(query: "DROP DATABASE \"\(quotedName)\"")
_ = try await execute(query: "DROP DATABASE \(quoteIdentifier(name))")
}

// MARK: - Query Helpers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import TableProPluginKit

extension PostgreSQLPluginDriver {
func fetchColumns(table: String, schema: String?) async throws -> [PluginColumnInfo] {
let safeSchema = escapeLiteralForColumns(currentSchema ?? "public")
let safeTable = escapeLiteralForColumns(table)
let safeSchema = escapeStringLiteral(currentSchema ?? "public")
let safeTable = escapeStringLiteral(table)
let enumMap = try await fetchEnumLabelMap(schema: safeSchema)
let caps = versionedCapabilities
let identityProjection = caps.hasIdentityColumns ? "a.attidentity" : "NULL::text"
Expand Down Expand Up @@ -60,7 +60,7 @@ extension PostgreSQLPluginDriver {
}

func fetchAllColumns(schema: String?) async throws -> [String: [PluginColumnInfo]] {
let safeSchema = escapeLiteralForColumns(currentSchema ?? "public")
let safeSchema = escapeStringLiteral(currentSchema ?? "public")
let enumMap = try await fetchEnumLabelMap(schema: safeSchema)
let caps = versionedCapabilities
let identityProjection = caps.hasIdentityColumns ? "a.attidentity" : "NULL::text"
Expand Down Expand Up @@ -134,10 +134,6 @@ extension PostgreSQLPluginDriver {
return map
}

fileprivate func escapeLiteralForColumns(_ str: String) -> String {
str.replacingOccurrences(of: "'", with: "''")
}

fileprivate func mapPgColumnRow(
_ row: [PluginCellValue],
tableNameOffset: Int,
Expand Down
17 changes: 8 additions & 9 deletions Plugins/PostgreSQLDriverPlugin/PostgreSQLPluginDriver.swift
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ final class PostgreSQLPluginDriver: LibPQBackedDriver, @unchecked Sendable {

func fetchTableDDL(table: String, schema: String?) async throws -> String {
let safeTable = escapeLiteral(table)
let quotedTable = "\"\(table.replacingOccurrences(of: "\"", with: "\"\""))\""
let quotedTable = quoteIdentifier(table)
let caps = versionedCapabilities

let identityClause: String = caps.hasIdentityColumns ? """
Expand Down Expand Up @@ -431,7 +431,7 @@ final class PostgreSQLPluginDriver: LibPQBackedDriver, @unchecked Sendable {
var parts = columnDefs
parts.append(contentsOf: constraints)

let quotedSchema = "\"\(core.currentSchema.replacingOccurrences(of: "\"", with: "\"\""))\""
let quotedSchema = quoteIdentifier(core.currentSchema)
let ddl = "CREATE TABLE \(quotedSchema).\(quotedTable) (\n " +
parts.joined(separator: ",\n ") +
"\n);"
Expand Down Expand Up @@ -599,9 +599,9 @@ final class PostgreSQLPluginDriver: LibPQBackedDriver, @unchecked Sendable {
let incrementBy = row[4].asText ?? "1"
let cycle = row[5].asText == "t" ? " CYCLE" : ""
let lastValue = row.count > 6 ? row[6].asText : nil
let quotedSeqName = "\"\(seqName.replacingOccurrences(of: "\"", with: "\"\""))\""
let escapedSchemaForLiteral = schemaName.replacingOccurrences(of: "'", with: "''")
let escapedSeqForLiteral = seqName.replacingOccurrences(of: "'", with: "''")
let quotedSeqName = quoteIdentifier(seqName)
let escapedSchemaForLiteral = escapeStringLiteral(schemaName)
let escapedSeqForLiteral = escapeStringLiteral(seqName)
var ddl = "CREATE SEQUENCE \(quotedSeqName) INCREMENT BY \(incrementBy)"
+ " MINVALUE \(minVal) MAXVALUE \(maxVal)"
+ " START WITH \(startVal)\(cycle);"
Expand Down Expand Up @@ -692,7 +692,7 @@ final class PostgreSQLPluginDriver: LibPQBackedDriver, @unchecked Sendable {
}

func createDatabase(_ request: PluginCreateDatabaseRequest) async throws {
let quotedName = request.name.replacingOccurrences(of: "\"", with: "\"\"")
let quotedName = quoteIdentifier(request.name)

guard let encoding = request.values["encoding"] else {
throw LibPQPluginError(
Expand All @@ -709,7 +709,7 @@ final class PostgreSQLPluginDriver: LibPQBackedDriver, @unchecked Sendable {
)
}

var sql = "CREATE DATABASE \"\(quotedName)\" ENCODING '\(encoding)'"
var sql = "CREATE DATABASE \(quotedName) ENCODING '\(encoding)'"

let supportsProvider = versionedCapabilities.hasDatabaseICULocale
let provider = supportsProvider ? (request.values["provider"] ?? "libc") : "libc"
Expand Down Expand Up @@ -789,8 +789,7 @@ final class PostgreSQLPluginDriver: LibPQBackedDriver, @unchecked Sendable {
}

func dropDatabase(name: String) async throws {
let escapedName = name.replacingOccurrences(of: "\"", with: "\"\"")
_ = try await execute(query: "DROP DATABASE \"\(escapedName)\"")
_ = try await execute(query: "DROP DATABASE \(quoteIdentifier(name))")
}

private struct Template1Defaults {
Expand Down
7 changes: 5 additions & 2 deletions Plugins/PostgreSQLDriverPlugin/PostgreSQLSchemaQueries.swift
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ enum PostgreSQLSchemaQueries {
}

static func setSearchPath(toSchema schema: String) -> String {
let quotedIdentifier = schema.replacingOccurrences(of: "\"", with: "\"\"")
return "SET search_path TO \"\(quotedIdentifier)\", public"
let quotedIdentifier = "\"\(schema.replacingOccurrences(of: "\"", with: "\"\""))\""
guard schema != "public" else {
return "SET search_path TO \(quotedIdentifier)"
}
return "SET search_path TO \(quotedIdentifier), public"
}
}
11 changes: 4 additions & 7 deletions Plugins/PostgreSQLDriverPlugin/RedshiftPluginDriver.swift
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,8 @@ final class RedshiftPluginDriver: LibPQBackedDriver, @unchecked Sendable {

func fetchTableDDL(table: String, schema: String?) async throws -> String {
let safeTable = escapeLiteral(table)
let quotedTable = "\"\(table.replacingOccurrences(of: "\"", with: "\"\""))\""
let quotedSchema = "\"\(core.currentSchema.replacingOccurrences(of: "\"", with: "\"\""))\""
let quotedTable = quoteIdentifier(table)
let quotedSchema = quoteIdentifier(core.currentSchema)

do {
let showResult = try await execute(query: "SHOW TABLE \(quotedSchema).\(quotedTable)")
Expand Down Expand Up @@ -508,14 +508,12 @@ final class RedshiftPluginDriver: LibPQBackedDriver, @unchecked Sendable {
)
}

let quotedName = request.name.replacingOccurrences(of: "\"", with: "\"\"")
let sql = "CREATE DATABASE \"\(quotedName)\" COLLATE \(collate)"
let sql = "CREATE DATABASE \(quoteIdentifier(request.name)) COLLATE \(collate)"
_ = try await execute(query: sql)
}

func dropDatabase(name: String) async throws {
let escapedName = name.replacingOccurrences(of: "\"", with: "\"\"")
_ = try await execute(query: "DROP DATABASE \"\(escapedName)\"")
_ = try await execute(query: "DROP DATABASE \(quoteIdentifier(name))")
}

// MARK: - All Tables Metadata
Expand All @@ -537,5 +535,4 @@ final class RedshiftPluginDriver: LibPQBackedDriver, @unchecked Sendable {
ORDER BY "table"
"""
}

}
8 changes: 8 additions & 0 deletions TableProTests/Plugins/PostgreSQLSearchPathTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ struct PostgreSQLSearchPathTests {
)
}

@Test("omits the redundant public fallback when public is the selected schema")
func publicSchema() {
#expect(
PostgreSQLSchemaQueries.setSearchPath(toSchema: "public")
== "SET search_path TO \"public\""
)
}

@Test("preserves mixed-case schema names with identifier quoting")
func mixedCaseSchema() {
#expect(
Expand Down
Loading