diff --git a/README.md b/README.md index 20e3127..5eb159c 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,8 @@ npm install pgstrap --save-dev - `npm run db:migrate` - Run pending migrations - `npm run db:reset` - Drop and recreate the database, then run all migrations -- `npm run db:generate` - Generate types and structure dumps. Use `pgstrap generate --pglite` to run migrations against an in-memory PGlite instance. +- `npm run db:generate` - Generate types and structure dumps with an in-memory PGlite database, so Postgres does not need to be running in the background. +- `pgstrap generate` - Generate types and structure dumps against your configured Postgres database. - `npm run db:create-migration` - Create a new migration file ### Configuration diff --git a/src/generate.ts b/src/generate.ts index f337094..19c4478 100644 --- a/src/generate.ts +++ b/src/generate.ts @@ -8,6 +8,14 @@ import { dumpTree } from "pg-schema-dump" import path from "path" import { migrate } from "./migrate" +const closeServer = (server: import("node:net").Server) => + new Promise((resolve, reject) => { + server.close((error?: Error) => { + if (error) reject(error) + else resolve() + }) + }) + export const generate = async ({ schemas, defaultDatabase, @@ -27,65 +35,98 @@ export const generate = async ({ const net = await import("node:net") const db = new PGlite() + const prevDbUrl = process.env.DATABASE_URL + let server: import("node:net").Server | undefined + let shouldRestoreDbUrl = false + let generateError: unknown - await migrate({ - client: db as any, - migrationsDir, - defaultDatabase, - cwd: process.cwd(), - schemas, - }) + try { + await migrate({ + client: db as any, + migrationsDir, + defaultDatabase, + cwd: process.cwd(), + schemas, + }) - const server = net.createServer(async (socket) => { - const connection = await fromNodeSocket(socket, { - serverVersion: "16.3 (PGlite)", - auth: { - method: "password", - validateCredentials: ({ username, password }: any) => - username === "postgres" && password === "postgres", - getClearTextPassword: () => "postgres", - }, - async onStartup() { - await (db as any).waitReady - }, - async onMessage(data: Uint8Array, { isAuthenticated }: any) { - if (!isAuthenticated) return - try { - const { data: responseData } = await (db as any).execProtocol(data) - return responseData - } catch { - return undefined - } + const gatewayServer = net.createServer(async (socket) => { + await fromNodeSocket(socket, { + serverVersion: "16.3 (PGlite)", + auth: { + method: "password", + validateCredentials: ({ username, password }: any) => + username === "postgres" && password === "postgres", + getClearTextPassword: () => "postgres", + }, + async onStartup() { + await (db as any).waitReady + }, + async onMessage(data: Uint8Array, { isAuthenticated }: any) { + if (!isAuthenticated) return + try { + const { data: responseData } = await (db as any).execProtocol( + data, + ) + return responseData + } catch { + return undefined + } + }, + }) + }) + server = gatewayServer + + await new Promise((resolve) => gatewayServer.listen(0, resolve)) + const port = (gatewayServer.address() as any).port + const connectionString = `postgres://postgres:postgres@127.0.0.1:${port}/postgres` + + process.env.DATABASE_URL = connectionString + shouldRestoreDbUrl = true + + await zg.generate({ + db: { + connectionString, }, + schemas: Object.fromEntries( + schemas.map((s) => [s, { include: "*", exclude: [] }]), + ), + outDir: dbDir, }) - }) - await new Promise((resolve) => server.listen(0, resolve)) - const port = (server.address() as any).port - const connectionString = `postgres://postgres:postgres@127.0.0.1:${port}/postgres` + await dumpTree({ + targetDir: path.join(dbDir, "structure"), + defaultDatabase: "postgres", + schemas, + }) + } catch (error) { + generateError = error + throw error + } finally { + if (shouldRestoreDbUrl) { + if (prevDbUrl === undefined) delete process.env.DATABASE_URL + else process.env.DATABASE_URL = prevDbUrl + } - const prevDbUrl = process.env.DATABASE_URL - process.env.DATABASE_URL = connectionString - - await zg.generate({ - db: { - connectionString, - }, - schemas: Object.fromEntries( - schemas.map((s) => [s, { include: "*", exclude: [] }]), - ), - outDir: dbDir, - }) + const closePglite = (db as { close?: () => void | Promise }).close + let cleanupError: unknown - await dumpTree({ - targetDir: path.join(dbDir, "structure"), - defaultDatabase: "postgres", - schemas, - }) + try { + if (server?.listening) await closeServer(server) + } catch (error) { + cleanupError = error + } + + try { + if (typeof closePglite === "function") await closePglite.call(db) + } catch (error) { + cleanupError ??= error + } + + if (generateError === undefined && cleanupError !== undefined) { + throw cleanupError + } + } - server.close() - if (prevDbUrl === undefined) delete process.env.DATABASE_URL - else process.env.DATABASE_URL = prevDbUrl return } diff --git a/src/init.ts b/src/init.ts index b84f9cd..e396f28 100644 --- a/src/init.ts +++ b/src/init.ts @@ -16,7 +16,7 @@ export const initPgstrap = async (ctx: Pick) => { pkg.scripts["db:migrate"] = "pgstrap migrate" pkg.scripts["db:reset"] = "pgstrap reset" - pkg.scripts["db:generate"] = "pgstrap generate" + pkg.scripts["db:generate"] = "pgstrap generate --pglite" pkg.scripts["db:create-migration"] = "pgstrap create-migration" if (!pkg.devDependencies) pkg.devDependencies = {} diff --git a/tests/generate.pglite.test.ts b/tests/generate.pglite.test.ts index 56dcd53..5cb5d1b 100644 --- a/tests/generate.pglite.test.ts +++ b/tests/generate.pglite.test.ts @@ -45,3 +45,42 @@ test("generate with pglite runs migrations and dumps structure", async () => { fs.rmSync(tmp, { recursive: true, force: true }) }) + +test("generate with pglite restores DATABASE_URL after generation failure", async () => { + const tmp = fs.mkdtempSync(path.join(os.tmpdir(), "pgstrap-generate-")) + const migrationsDir = path.join(tmp, "migrations") + const dbDir = path.join(tmp, "db") + const prevDbUrl = process.env.DATABASE_URL + fs.mkdirSync(migrationsDir, { recursive: true }) + fs.writeFileSync( + path.join(migrationsDir, "001_create_table.js"), + migrationFile, + ) + fs.writeFileSync(dbDir, "not a directory") + + process.env.DATABASE_URL = "postgres://existing:secret@localhost:5432/app" + + try { + let error: unknown + try { + await generate({ + schemas: ["public"], + defaultDatabase: "postgres", + dbDir, + migrationsDir, + pglite: true, + }) + } catch (cause) { + error = cause + } + + expect(error).toBeDefined() + expect(process.env.DATABASE_URL).toBe( + "postgres://existing:secret@localhost:5432/app", + ) + } finally { + if (prevDbUrl === undefined) delete process.env.DATABASE_URL + else process.env.DATABASE_URL = prevDbUrl + fs.rmSync(tmp, { recursive: true, force: true }) + } +}) diff --git a/tests/init.test.ts b/tests/init.test.ts index cd4ec4b..aabc7f7 100644 --- a/tests/init.test.ts +++ b/tests/init.test.ts @@ -25,6 +25,6 @@ test("initPgstrap writes scripts to package.json", async () => { ) expect(pkg.scripts["db:migrate"]).toBe("pgstrap migrate") expect(pkg.scripts["db:reset"]).toBe("pgstrap reset") - expect(pkg.scripts["db:generate"]).toBe("pgstrap generate") + expect(pkg.scripts["db:generate"]).toBe("pgstrap generate --pglite") expect(pkg.scripts["db:create-migration"]).toBe("pgstrap create-migration") })