diff --git a/supabase/functions/_backend/triggers/queue_consumer.ts b/supabase/functions/_backend/triggers/queue_consumer.ts index 0adaca355d..bbf2a28c4f 100644 --- a/supabase/functions/_backend/triggers/queue_consumer.ts +++ b/supabase/functions/_backend/triggers/queue_consumer.ts @@ -18,6 +18,7 @@ const MANIFEST_QUEUE_BATCH_SIZE = 100 const DEFAULT_QUEUE_HTTP_CONCURRENCY = 25 const MANIFEST_QUEUE_HTTP_CONCURRENCY = 10 const QUEUE_HTTP_TIMEOUT_MS = 15_000 +const HEALTHCHECK_HTTP_TIMEOUT_MS = 8_000 export const MAX_QUEUE_READS = 5 const DISCORD_IGNORED_ERROR_CODES = new Set(['version_not_found', 'no_channel']) @@ -77,6 +78,16 @@ interface ProcessedQueueMessage extends Message { targetUrl: string | null } +interface QueueProcessResult { + archivedCount: number + failedCount: number + processedCount: number + readSucceeded: boolean + skippedCount: number + success: boolean + successCount: number +} + function extractMessageBody(message: Message): Record { if (message.message?.payload !== undefined) return (message.message.payload ?? {}) as Record @@ -401,12 +412,24 @@ async function mapWithConcurrency( return results } -async function processQueue(c: Context, db: ReturnType, queueName: string, batchSize: number = DEFAULT_BATCH_SIZE) { +async function processQueue(c: Context, db: ReturnType, queueName: string, batchSize: number = DEFAULT_BATCH_SIZE): Promise { const messages = await readQueue(c, db, queueName, batchSize) - if (!messages) { - cloudlog(`[${queueName}] No messages found in queue or an error occurred.`) - return + if (messages === null) { + cloudlog({ + requestId: c.get('requestId'), + message: `[${queueName}] Queue read failed.`, + queueName, + }) + return { + archivedCount: 0, + failedCount: 0, + processedCount: 0, + readSucceeded: false, + skippedCount: 0, + success: false, + successCount: 0, + } } const [messagesToProcess, messagesToSkip] = messages.reduce((acc, message) => { @@ -585,6 +608,16 @@ async function processQueue(c: Context, db: ReturnType, queu else { cloudlog({ requestId: c.get('requestId'), message: `[${queueName}] All messages were processed successfully.` }) } + + return { + archivedCount: messagesToSkip.length, + failedCount: messagesFailed.length, + processedCount: messagesToProcess.length, + readSucceeded: true, + skippedCount: messagesToSkip.length, + success: messagesToSkip.length === 0 && messagesFailed.length === 0, + successCount: successMessages.length, + } } async function extractErrorDetails(response: Response): Promise<{ @@ -643,7 +676,7 @@ async function extractErrorDetails(response: Response): Promise<{ } // Reads messages from the queue and logs them -async function readQueue(c: Context, db: ReturnType, queueName: string, batchSize: number = DEFAULT_BATCH_SIZE): Promise { +async function readQueue(c: Context, db: ReturnType, queueName: string, batchSize: number = DEFAULT_BATCH_SIZE): Promise { const queueKey = 'readQueue' const startTime = Date.now() let messages: Message[] = [] @@ -684,7 +717,7 @@ async function readQueue(c: Context, db: ReturnType, queueNa finally { cloudlog({ requestId: c.get('requestId'), message: `[${queueKey}] Finished reading queue messages in ${Date.now() - startTime}ms.` }) } - return messages + return null } // The main HTTP POST helper function @@ -738,6 +771,55 @@ export async function http_post_helper( } } +async function pingCronHealthcheck( + healthcheckUrl: string, + fetchImpl: typeof fetch, +): Promise { + const controller = new AbortController() + const timeoutId = setTimeout(() => controller.abort(), HEALTHCHECK_HTTP_TIMEOUT_MS) + + try { + const response = await fetchImpl(healthcheckUrl, { + method: 'GET', + signal: controller.signal, + }) + await response.body?.cancel() + return response.ok + } + catch { + return false + } + finally { + clearTimeout(timeoutId) + } +} + +async function maybePingCronHealthcheck( + db: ReturnType, + queueName: string, + processResult: QueueProcessResult, + healthcheckUrl: string | null, + fetchImpl: typeof fetch = fetch, +): Promise { + if (!healthcheckUrl || !processResult.success) + return false + + try { + const metrics = await db.query<{ queue_length: string }>( + 'SELECT queue_length::text FROM pgmq.metrics($1)', + [queueName], + ) + const queueLength = Number(metrics.rows[0]?.queue_length ?? 0) + if (queueLength !== 0) + return false + + return pingCronHealthcheck(healthcheckUrl, fetchImpl) + } + catch { + return false + } +} + // Helper function to delete multiple messages from the queue in a single batch async function delete_queue_message_batch(c: Context, db: ReturnType, queueName: string, msgIds: number[]) { try { @@ -828,9 +910,10 @@ app.post('/sync', async (c) => { cloudlog({ requestId: c.get('requestId'), message: `[Sync Request] Received trigger to process queue.` }) // Require JSON body with queue_name and optional batch_size - const body = await parseBody<{ queue_name: string, batch_size?: number }>(c) + const body = await parseBody<{ queue_name: string, batch_size?: number, healthcheck_url?: string | null }>(c) const queueName = body?.queue_name const batchSize = body?.batch_size + const healthcheckUrl = typeof body?.healthcheck_url === 'string' ? body.healthcheck_url : null if (!queueName || typeof queueName !== 'string') { throw simpleError('missing_or_invalid_queue_name', 'Missing or invalid queue_name in body', { body }) @@ -860,8 +943,15 @@ app.post('/sync', async (c) => { let db: ReturnType | null = null try { db = getPgClient(c) - await processQueue(c, db, queueName, finalBatchSize) - cloudlog({ requestId: c.get('requestId'), message: `[Background Queue Sync] Background execution finished successfully.` }) + const result = await processQueue(c, db, queueName, finalBatchSize) + await maybePingCronHealthcheck(db, queueName, result, healthcheckUrl) + cloudlog({ + requestId: c.get('requestId'), + message: result.success + ? `[Background Queue Sync] Background execution finished successfully.` + : `[Background Queue Sync] Background execution finished with queue failures.`, + result, + }) } finally { if (db) @@ -880,6 +970,7 @@ export const __queueConsumerTestUtils__ = { getQueueBatchSize, getQueueHttpConcurrency, httpExceptionToQueueResponse, + maybePingCronHealthcheck, queueFailureResponse, resolveFunctionUrl, sanitizeDiscordResponseBody, diff --git a/supabase/migrations/20260521210531_cron_hyperping_healthchecks.sql b/supabase/migrations/20260521210531_cron_hyperping_healthchecks.sql new file mode 100644 index 0000000000..804592f12e --- /dev/null +++ b/supabase/migrations/20260521210531_cron_hyperping_healthchecks.sql @@ -0,0 +1,216 @@ +ALTER TABLE public.cron_tasks +ADD COLUMN IF NOT EXISTS healthcheck_url text; + +CREATE OR REPLACE FUNCTION public.process_queue_with_healthcheck( + queue_names text [], + batch_size integer, + healthcheck_url text +) +RETURNS void +LANGUAGE plpgsql +SET search_path = '' +AS $$ +DECLARE + calls_needed int; + headers jsonb; + queue_name text; + queue_size bigint; + url text; +BEGIN + IF batch_size IS NULL OR batch_size <= 0 THEN + RAISE EXCEPTION 'batch_size must be positive'; + END IF; + + headers := pg_catalog.jsonb_build_object( + 'Content-Type', 'application/json', + 'apisecret', public.get_apikey() + ); + url := public.get_db_url() || '/functions/v1/triggers/queue_consumer/sync'; + + FOREACH queue_name IN ARRAY queue_names LOOP + BEGIN + EXECUTE pg_catalog.format('SELECT count(*) FROM pgmq.%I', 'q_' || queue_name) + INTO queue_size; + + IF queue_size > 0 THEN + calls_needed := LEAST( + pg_catalog.ceil(queue_size / batch_size::double precision)::int, + 10 + ); + ELSE + calls_needed := 1; + END IF; + + FOR i IN 1..calls_needed LOOP + PERFORM net.http_post( + url := url, + headers := headers, + body := pg_catalog.jsonb_strip_nulls(pg_catalog.jsonb_build_object( + 'queue_name', queue_name, + 'batch_size', batch_size, + 'healthcheck_url', healthcheck_url + )), + timeout_milliseconds := 8000 + ); + END LOOP; + EXCEPTION WHEN OTHERS THEN + RAISE WARNING 'process_queue_with_healthcheck failed for queue "%": %', queue_name, SQLERRM; + END; + END LOOP; +END; +$$; + +ALTER FUNCTION public.process_queue_with_healthcheck( + text [], integer, text +) OWNER TO postgres; + +REVOKE ALL ON FUNCTION public.process_queue_with_healthcheck( + text [], integer, text +) FROM public; +REVOKE ALL ON FUNCTION public.process_queue_with_healthcheck( + text [], integer, text +) FROM anon; +REVOKE ALL ON FUNCTION public.process_queue_with_healthcheck( + text [], integer, text +) FROM authenticated; +REVOKE ALL ON FUNCTION public.process_queue_with_healthcheck( + text [], integer, text +) FROM service_role; +GRANT EXECUTE ON FUNCTION public.process_queue_with_healthcheck( + text [], integer, text +) TO service_role; + +CREATE OR REPLACE FUNCTION public.process_all_cron_tasks() +RETURNS void +LANGUAGE plpgsql +SET search_path = '' +AS $$ +DECLARE + current_hour int; + current_minute int; + current_second int; + current_dow int; + current_day int; + task RECORD; + queue_names text[]; + should_run boolean; + lock_acquired boolean; +BEGIN + -- Try to acquire an advisory lock (non-blocking) + -- Lock ID 1 is reserved for process_all_cron_tasks + -- pg_try_advisory_lock returns true if lock acquired, false if already held + lock_acquired := pg_try_advisory_lock(1); + + IF NOT lock_acquired THEN + -- Another instance is already running, skip this execution + RAISE NOTICE 'process_all_cron_tasks: skipped, another instance is already running'; + RETURN; + END IF; + + -- Wrap everything in a block so we can ensure the lock is released + BEGIN + -- Get current time components in UTC + current_hour := EXTRACT(HOUR FROM NOW()); + current_minute := EXTRACT(MINUTE FROM NOW()); + current_second := EXTRACT(SECOND FROM NOW()); + current_dow := EXTRACT(DOW FROM NOW()); + current_day := EXTRACT(DAY FROM NOW()); + + -- Loop through all enabled tasks + FOR task IN SELECT * FROM public.cron_tasks WHERE enabled = true LOOP + should_run := false; + + -- Check if task should run based on its schedule + IF task.second_interval IS NOT NULL THEN + -- Run every N seconds + -- Since pg_cron interval is not clock-aligned, we run on every invocation + -- for second_interval tasks (the cron job itself runs every 10 seconds) + should_run := true; + ELSIF task.minute_interval IS NOT NULL THEN + -- Run every N minutes + -- Use current_second < 10 to catch first run of each minute (works with any cron offset) + should_run := (current_minute % task.minute_interval = 0) + AND (current_second < 10); + ELSIF task.hour_interval IS NOT NULL THEN + -- Run every N hours at specific minute + -- Use current_second < 10 to catch first run + should_run := (current_hour % task.hour_interval = 0) + AND (current_minute = COALESCE(task.run_at_minute, 0)) + AND (current_second < 10); + ELSIF task.run_at_hour IS NOT NULL THEN + -- Run at specific time + -- Use current_second < 10 to catch first run + should_run := (current_hour = task.run_at_hour) + AND (current_minute = COALESCE(task.run_at_minute, 0)) + AND (current_second < 10); + + -- Check day of week constraint + IF should_run AND task.run_on_dow IS NOT NULL THEN + should_run := (current_dow = task.run_on_dow); + END IF; + + -- Check day of month constraint + IF should_run AND task.run_on_day IS NOT NULL THEN + should_run := (current_day = task.run_on_day); + END IF; + END IF; + + -- Execute the task if it should run + IF should_run THEN + BEGIN + CASE task.task_type + WHEN 'function' THEN + EXECUTE 'SELECT ' || task.target; + + WHEN 'queue' THEN + PERFORM pgmq.send( + task.target, + COALESCE(task.payload, jsonb_build_object('function_name', task.target)) + ); + + WHEN 'function_queue' THEN + -- Parse JSON array of queue names + SELECT array_agg(value::text) INTO queue_names + FROM jsonb_array_elements_text(task.target::jsonb); + + IF task.healthcheck_url IS NOT NULL THEN + PERFORM public.process_queue_with_healthcheck( + COALESCE(queue_names, ARRAY[]::text[]), + COALESCE(task.batch_size, 950), + task.healthcheck_url + ); + ELSIF task.batch_size IS NOT NULL THEN + PERFORM public.process_function_queue(queue_names, task.batch_size); + ELSE + PERFORM public.process_function_queue(queue_names); + END IF; + END CASE; + EXCEPTION WHEN OTHERS THEN + RAISE WARNING 'cron task "%" failed: %', task.name, SQLERRM; + END; + END IF; + END LOOP; + + EXCEPTION WHEN OTHERS THEN + -- Release the lock even if an error occurred + PERFORM pg_advisory_unlock(1); + RAISE; + END; + + -- Release the advisory lock + PERFORM pg_advisory_unlock(1); +END; +$$; + +ALTER FUNCTION public.process_all_cron_tasks() OWNER TO postgres; + +REVOKE ALL ON FUNCTION public.process_all_cron_tasks() FROM public; +REVOKE ALL ON FUNCTION public.process_all_cron_tasks() FROM anon; +REVOKE ALL ON FUNCTION public.process_all_cron_tasks() FROM authenticated; +REVOKE ALL ON FUNCTION public.process_all_cron_tasks() FROM service_role; +GRANT EXECUTE ON FUNCTION public.process_all_cron_tasks() TO service_role; + +COMMENT ON FUNCTION public.process_all_cron_tasks() IS +$$Consolidated cron task processor that runs every 10 seconds. Uses advisory +lock (ID=1) to prevent concurrent execution - if a previous run is still +executing, the new invocation will skip.$$; diff --git a/tests/cron-healthchecks.test.ts b/tests/cron-healthchecks.test.ts new file mode 100644 index 0000000000..7d6dbcc35b --- /dev/null +++ b/tests/cron-healthchecks.test.ts @@ -0,0 +1,69 @@ +import type { Pool, PoolClient } from 'pg' +import { randomUUID } from 'node:crypto' +import { Pool as PgPool } from 'pg' +import { afterAll, beforeAll, describe, expect, it } from 'vitest' +import { POSTGRES_URL } from './test-utils.ts' + +async function rollbackAndRelease(client: PoolClient) { + try { + await client.query('ROLLBACK') + } + finally { + client.release() + } +} + +describe('cron healthchecks', () => { + let pool: Pool + + beforeAll(() => { + pool = new PgPool({ + connectionString: POSTGRES_URL, + max: 2, + idleTimeoutMillis: 2000, + }) + }) + + afterAll(async () => { + await pool.end() + }) + + it.concurrent('stores any healthcheck URL on cron tasks', async () => { + const client = await pool.connect() + + try { + await client.query('BEGIN') + const taskName = `test_healthcheck_url_${randomUUID()}` + + const result = await client.query<{ healthcheck_url: string }>( + ` + INSERT INTO public.cron_tasks ( + name, + description, + task_type, + target, + second_interval, + healthcheck_url, + enabled + ) + VALUES ( + $1, + 'Healthcheck URL test', + 'function'::public.cron_task_type, + 'pg_catalog.pg_sleep(0)', + 10, + 'https://example.com/healthcheck', + true + ) + RETURNING healthcheck_url + `, + [taskName], + ) + + expect(result.rows).toEqual([{ healthcheck_url: 'https://example.com/healthcheck' }]) + } + finally { + await rollbackAndRelease(client) + } + }) +}) diff --git a/tests/queue-consumer-message-shape.unit.test.ts b/tests/queue-consumer-message-shape.unit.test.ts index 79cdc13ff2..352990b8b4 100644 --- a/tests/queue-consumer-message-shape.unit.test.ts +++ b/tests/queue-consumer-message-shape.unit.test.ts @@ -1,8 +1,18 @@ import { HTTPException } from 'hono/http-exception' -import { describe, expect, it } from 'vitest' +import { describe, expect, it, vi } from 'vitest' import { __queueConsumerTestUtils__, MAX_QUEUE_READS, messagesArraySchema } from '../supabase/functions/_backend/triggers/queue_consumer.ts' import { parseSchema } from '../supabase/functions/_backend/utils/ark_validation.ts' +function createHealthcheckDb(queueLength: number) { + return { + db: { + query: vi.fn(async >(): Promise<{ rows: T[] }> => ({ + rows: [{ queue_length: String(queueLength) } as T], + })), + }, + } +} + describe('queue_consumer legacy message compatibility', () => { it.concurrent('uses the payload envelope when it is present', () => { const [message] = parseSchema(messagesArraySchema, [ @@ -203,4 +213,103 @@ describe('queue_consumer legacy message compatibility', () => { expect(details.bodyPreview).toContain('"id":123') expect(details.bodyPreview).toContain('"queueName":"on_manifest_create"') }) + + it.concurrent('calls the healthcheck URL when the worker succeeds and the queue is empty', async () => { + const { db } = createHealthcheckDb(0) + const fetchImpl = vi.fn(async () => new Response(null, { status: 200 })) as unknown as typeof fetch + + const reported = await __queueConsumerTestUtils__.maybePingCronHealthcheck( + db as never, + 'cron_email', + { + archivedCount: 0, + failedCount: 0, + processedCount: 1, + readSucceeded: true, + skippedCount: 0, + success: true, + successCount: 1, + }, + 'https://example.com/healthcheck', + fetchImpl, + ) + + expect(reported).toBe(true) + expect(fetchImpl).toHaveBeenCalledTimes(1) + expect(fetchImpl).toHaveBeenCalledWith('https://example.com/healthcheck', expect.objectContaining({ + method: 'GET', + })) + }) + + it.concurrent('does not call the healthcheck URL when queue work remains', async () => { + const { db } = createHealthcheckDb(2) + const fetchImpl = vi.fn(async () => new Response(null, { status: 200 })) as unknown as typeof fetch + + const reported = await __queueConsumerTestUtils__.maybePingCronHealthcheck( + db as never, + 'cron_email', + { + archivedCount: 0, + failedCount: 0, + processedCount: 1, + readSucceeded: true, + skippedCount: 0, + success: true, + successCount: 1, + }, + 'https://example.com/healthcheck', + fetchImpl, + ) + + expect(reported).toBe(false) + expect(fetchImpl).not.toHaveBeenCalled() + }) + + it.concurrent('returns false when the healthcheck URL responds with an error', async () => { + const { db } = createHealthcheckDb(0) + const fetchImpl = vi.fn(async () => new Response(null, { status: 500 })) as unknown as typeof fetch + + const reported = await __queueConsumerTestUtils__.maybePingCronHealthcheck( + db as never, + 'cron_email', + { + archivedCount: 0, + failedCount: 0, + processedCount: 1, + readSucceeded: true, + skippedCount: 0, + success: true, + successCount: 1, + }, + 'https://example.com/healthcheck', + fetchImpl, + ) + + expect(reported).toBe(false) + expect(fetchImpl).toHaveBeenCalledTimes(1) + }) + + it.concurrent('does not call the healthcheck URL when the worker failed', async () => { + const { db } = createHealthcheckDb(0) + const fetchImpl = vi.fn(async () => new Response(null, { status: 200 })) as unknown as typeof fetch + + const reported = await __queueConsumerTestUtils__.maybePingCronHealthcheck( + db as never, + 'cron_email', + { + archivedCount: 0, + failedCount: 1, + processedCount: 1, + readSucceeded: true, + skippedCount: 0, + success: false, + successCount: 0, + }, + 'https://example.com/healthcheck', + fetchImpl, + ) + + expect(reported).toBe(false) + expect(fetchImpl).not.toHaveBeenCalled() + }) }) diff --git a/tests/security-definer-execute-hardening.test.ts b/tests/security-definer-execute-hardening.test.ts index 513decfd5a..a81e34f795 100644 --- a/tests/security-definer-execute-hardening.test.ts +++ b/tests/security-definer-execute-hardening.test.ts @@ -31,6 +31,8 @@ const SERVICE_ONLY_PROCS = [ 'public.noupdate()', 'public.prevent_last_super_admin_binding_delete()', 'public.prevent_last_super_admin_binding_update()', + 'public.process_all_cron_tasks()', + 'public.process_queue_with_healthcheck(text[], integer, text)', 'public.reassign_webhook_created_by_before_user_delete()', 'public.resync_org_user_role_bindings(uuid, uuid)', 'public.sanitize_apps_text_fields()',