diff --git a/dist/index.d.ts b/dist/index.d.ts index 4f1588d..ba45cd2 100644 --- a/dist/index.d.ts +++ b/dist/index.d.ts @@ -1,5 +1,5 @@ -import * as pg from "pg"; import TypedEventEmitter from "typed-emitter"; +import pg = require("pg"); export interface PgParsedNotification { processId: number; channel: string; diff --git a/src/index.ts b/src/index.ts index c1a91d3..601cd90 100644 --- a/src/index.ts +++ b/src/index.ts @@ -8,6 +8,7 @@ import pg = require("pg") const connectionLogger = createDebugLogger("pg-listen:connection") const notificationLogger = createDebugLogger("pg-listen:notification") +const paranoidLogger = createDebugLogger("pg-listen:paranoid") const subscriptionLogger = createDebugLogger("pg-listen:subscription") const delay = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)) @@ -84,7 +85,14 @@ function connect (connectionConfig: pg.ClientConfig | undefined, options: Option try { const newClient = new Client(effectiveConnectionConfig) + const connecting = new Promise((resolve, reject) => { + newClient.once("connect", resolve) + newClient.once("end", () => reject(Error("Connection ended."))) + newClient.once("error", reject) + }) await newClient.connect() + await connecting + connectionLogger("PostgreSQL reconnection succeeded") return newClient } catch (error) { connectionLogger("PostgreSQL reconnection attempt failed:", error) @@ -133,8 +141,10 @@ function forwardDBNotificationEvents (dbClient: pg.Client, emitter: TypedEventEm function scheduleParanoidChecking (dbClient: pg.Client, intervalTime: number, reconnect: () => Promise<void>) { const scheduledCheck = async () => { try { - await dbClient.query("SELECT 1") + await dbClient.query("SELECT pg_backend_pid()") + paranoidLogger("Paranoid connection check ok") } catch (error) { + paranoidLogger("Paranoid connection check failed") connectionLogger("Paranoid connection check failed:", error) await reconnect() } @@ -179,22 +189,25 @@ function createPostgresSubscriber (connectionConfig?: pg.ClientConfig, options: let closing = false let dbClient = initialDBClient + let reinitializingRightNow = false let subscribedChannels: string[] = [] let cancelEventForwarding: () => void = () => undefined let cancelParanoidChecking: () => void = () => undefined - const initialize = async (client: pg.Client) => { + const initialize = (client: pg.Client) => { // Wire the DB client events to our exposed emitter's events cancelEventForwarding = forwardDBNotificationEvents(client, emitter) dbClient.on("error", (error: any) => { - connectionLogger("DB Client error:", error) - reinitialize() + if (!reinitializingRightNow) { + connectionLogger("DB Client error:", error) + reinitialize() + } }) dbClient.on("end", () => { - connectionLogger("DB Client connection ended") - if (!closing) { + if (!reinitializingRightNow) { + connectionLogger("DB Client connection ended") reinitialize() } }) @@ -206,20 +219,32 @@ function createPostgresSubscriber (connectionConfig?: pg.ClientConfig, options: // No need to handle errors when calling `reinitialize()`, it handles its errors itself const reinitialize = async () => { + if (reinitializingRightNow || closing) { + return + } + reinitializingRightNow = true + try { cancelParanoidChecking() cancelEventForwarding() + dbClient.removeAllListeners() + dbClient.once("error", error => connectionLogger(`Previous DB client errored after reconnecting already:`, error)) + dbClient.end() + dbClient = await reconnect(attempt => emitter.emit("reconnect", attempt)) - await initialize(dbClient) + initialize(dbClient) + subscriptionLogger(`Re-subscribing to channels: ${subscribedChannels.join(", ")}`) await Promise.all(subscribedChannels.map( - channelName => `LISTEN ${format.ident(channelName)}` + channelName => dbClient.query(`LISTEN ${format.ident(channelName)}`) )) } catch (error) { error.message = `Re-initializing the PostgreSQL notification client after connection loss failed: ${error.message}` connectionLogger(error.stack || error) emitter.emit("error", error) + } finally { + reinitializingRightNow = false } } diff --git a/test/integration.test.ts b/test/integration.test.ts index 0e8c8cd..096f686 100644 --- a/test/integration.test.ts +++ b/test/integration.test.ts @@ -1,6 +1,11 @@ import test from "ava" +import DebugLogger from "debug" import createPostgresSubscriber, { PgParsedNotification } from "../src/index" +// Need to require `pg` like this to avoid ugly error message +import pg = require("pg") + +const debug = DebugLogger("pg-listen:test") const delay = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)) test("can connect", async t => { @@ -41,3 +46,56 @@ test("can listen and notify", async t => { await hub.close() } }) + +test("getting notified after connection is terminated", async t => { + const notifications: PgParsedNotification[] = [] + const receivedPayloads: any[] = [] + let reconnects = 0 + + const connectionString = "postgres://postgres:postgres@localhost:5432/postgres" + let client = new pg.Client({ connectionString }) + await client.connect() + + const hub = createPostgresSubscriber( + { connectionString: connectionString + "?ApplicationName=pg-listen-termination-test" }, + { paranoidChecking: 1000 } + ) + await hub.connect() + + try { + await hub.listenTo("test") + hub.events.on("notification", (notification: PgParsedNotification) => notifications.push(notification)) + hub.events.on("reconnect", () => reconnects++) + hub.notifications.on("test", (payload: any) => receivedPayloads.push(payload)) + + await delay(1000) + debug("Terminating database backend") + + // Don't await as we kill some other connection, so the promise won't resolve (I think) + client.query("SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND usename = current_user") + await delay(2000) + + client = new pg.Client({ connectionString }) + await client.connect() + + debug("Sending notification...") + await client.query(`NOTIFY test, '{"hello": "world"}';`) + await delay(500) + + t.deepEqual(hub.getSubscribedChannels(), ["test"]) + t.deepEqual(notifications, [ + { + channel: "test", + payload: { hello: "world" }, + processId: notifications[0] ? notifications[0].processId : 0 + } + ]) + t.deepEqual(receivedPayloads, [ + { hello: "world" } + ]) + t.is(reconnects, 1) + } finally { + debug("Closing the subscriber") + await hub.close() + } +})