From b70cceadd2b16ffad025013fb57cf6d9d5053274 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= <31549762+mrl5@users.noreply.github.com> Date: Fri, 11 Oct 2024 11:58:16 +0200 Subject: [PATCH] feat: allow bgworkers to access JWT claims (#11) related to https://github.com/neondatabase/cloud/issues/16041 and this checkpoint from https://github.com/neondatabase/pg_session_jwt/issues/6 > Are the thread local variables supposed to be per-backend? Doesn't seem like they need to be thread local to me. I think this has been previously discussed. closes #6 --------- Co-authored-by: Conrad Ludgate --- .gitignore | 1 + CONTRIBUTING.md | 22 +- Cargo.lock | 118 +++- Cargo.toml | 19 +- README.md | 18 +- pg_session_jwt.control | 2 +- pgrx-tests/.gitignore | 7 + pgrx-tests/Cargo.toml | 72 +++ pgrx-tests/LICENSE | 26 + pgrx-tests/README.md | 9 + pgrx-tests/src/framework.rs | 886 +++++++++++++++++++++++++++ pgrx-tests/src/framework/shutdown.rs | 130 ++++ pgrx-tests/src/lib.rs | 11 + src/gucs.rs | 29 + src/lib.rs | 301 +++------ tests/pg_session_jwt.rs | 202 ++++++ 16 files changed, 1632 insertions(+), 221 deletions(-) create mode 100644 pgrx-tests/.gitignore create mode 100644 pgrx-tests/Cargo.toml create mode 100644 pgrx-tests/LICENSE create mode 100644 pgrx-tests/README.md create mode 100644 pgrx-tests/src/framework.rs create mode 100644 pgrx-tests/src/framework/shutdown.rs create mode 100644 pgrx-tests/src/lib.rs create mode 100644 src/gucs.rs create mode 100644 tests/pg_session_jwt.rs diff --git a/.gitignore b/.gitignore index 3ea57d4..c619c96 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ /target *.iml **/*.rs.bk +*.swp diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 81e578f..acf0b7a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -15,8 +15,15 @@ Let's initialize pgrx. cargo pgrx init ``` -It's time to run `pg_session_jwt` locally with +## How to run the extension locally + +It's time to run `pg_session_jwt` locally. Please note that `neon.auth.jwk` +parameter MUST be set when new connection is created (for more details please +refer to the README file). ```console +MY_JWK=... +export PGOPTIONS="-c neon.auth.jwk=$MY_JWK" + cargo pgrx run pg16 ``` @@ -35,3 +42,16 @@ If you introduce new function make sure to reload the extension with DROP EXTENSION pg_session_jwt; CREATE EXTENSION pg_session_jwt; ``` + +## Before sending a PR + +You can lint your code with +```console +rustfmt src/*.rs tests/*.rs +cargo clippy --fix --allow-staged +``` + +You can run test-suite +```console +cargo test +``` diff --git a/Cargo.lock b/Cargo.lock index 4dbc791..6a1c5ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,12 +26,55 @@ dependencies = [ "memchr", ] +[[package]] +name = "anstream" +version = "0.6.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + [[package]] name = "anstyle" version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" +[[package]] +name = "anstyle-parse" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" +dependencies = [ + "anstyle", + "windows-sys 0.52.0", +] + [[package]] name = "anyhow" version = "1.0.89" @@ -249,8 +292,10 @@ version = "4.5.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8cf2dd12af7a047ad9d6da2b6b249759a22a7abc0f474c1dae1777afa4b21a73" dependencies = [ + "anstream", "anstyle", "clap_lex", + "strsim", ] [[package]] @@ -271,6 +316,12 @@ version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" +[[package]] +name = "colorchoice" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" + [[package]] name = "const-oid" version = "0.9.6" @@ -481,6 +532,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "escape8259" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5692dd7b5a1978a5aeb0ce83b7655c58ca8efdcb79d21036ea249da95afec2c6" + [[package]] name = "eyre" version = "0.6.12" @@ -721,6 +778,12 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + [[package]] name = "itertools" version = "0.12.1" @@ -823,6 +886,18 @@ dependencies = [ "libc", ] +[[package]] +name = "libtest-mimic" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5297962ef19edda4ce33aaa484386e0a5b3d7f2f4e037cbeee00503ef6b29d33" +dependencies = [ + "anstream", + "anstyle", + "clap", + "escape8259", +] + [[package]] name = "linux-raw-sys" version = "0.4.14" @@ -1028,14 +1103,17 @@ dependencies = [ [[package]] name = "pg_session_jwt" -version = "0.0.1" +version = "0.1.0" dependencies = [ "base64ct", + "eyre", "heapless", "jose-jwk", + "libtest-mimic", "p256", "pgrx", "pgrx-tests", + "postgres", "rand", "serde", "serde_json", @@ -1138,8 +1216,6 @@ dependencies = [ [[package]] name = "pgrx-tests" version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab3abc01e2bb930b072bd660d04c8eaa69a29d4727d5b2a641f946c603c1605e" dependencies = [ "clap-cargo", "eyre", @@ -1157,6 +1233,7 @@ dependencies = [ "serde_json", "sysinfo", "thiserror", + "trybuild", ] [[package]] @@ -1717,6 +1794,12 @@ dependencies = [ "unicode-properties", ] +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "subtle" version = "2.6.1" @@ -1779,6 +1862,15 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + [[package]] name = "thiserror" version = "1.0.63" @@ -1902,6 +1994,20 @@ dependencies = [ "winnow", ] +[[package]] +name = "trybuild" +version = "1.0.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "207aa50d36c4be8d8c6ea829478be44a372c6a77669937bb39c698e52f1491e8" +dependencies = [ + "glob", + "serde", + "serde_derive", + "serde_json", + "termcolor", + "toml", +] + [[package]] name = "typenum" version = "1.17.0" @@ -1970,6 +2076,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "uuid" version = "1.10.0" diff --git a/Cargo.toml b/Cargo.toml index 102a80e..72040e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,9 @@ +[workspace] +members = ["pgrx-tests"] + [package] name = "pg_session_jwt" -version = "0.0.1" +version = "0.1.0" edition = "2021" [lib] @@ -11,7 +14,7 @@ default = ["pg16"] pg14 = ["pgrx/pg14", "pgrx-tests/pg14" ] pg15 = ["pgrx/pg15", "pgrx-tests/pg15" ] pg16 = ["pgrx/pg16", "pgrx-tests/pg16" ] -pg_test = ["dep:rand", "base64ct/alloc"] +pg_test = [] [dependencies] base64ct = { version = "1.6.0", features = ["std"] } @@ -22,10 +25,11 @@ pgrx = "=0.11.3" serde = { version = "1.0.203", features = ["derive"], default-features = false } serde_json = { version = "1.0.117", default-features = false } -rand = { version = "0.8", optional = true } - [dev-dependencies] -pgrx-tests = "=0.11.3" +eyre = "0.6.12" +libtest-mimic = "0.8.1" +pgrx-tests = { path = "./pgrx-tests" } +postgres = "0.19.9" rand = "0.8" [profile.dev] @@ -36,3 +40,8 @@ panic = "unwind" opt-level = 3 lto = "fat" codegen-units = 1 + +[[test]] +name = "tests" +harness = false +path = "tests/pg_session_jwt.rs" diff --git a/README.md b/README.md index 6d74a9c..c00c66e 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,9 @@ pg\_session\_jwt ================ -`pg_session_jwt` is a PostgreSQL extension designed to handle authenticated sessions through a JWT. This JWT is then verified against a JWK (JSON Web Key) to ensure its authenticity. Both the JWK and the JWT must be provided to the extension by a Postgres superuser. The extension then stores the JWT in the database for later retrieval, and exposes functions to retrieve the user ID (the `sub` subject field) and other parts of the payload. +`pg_session_jwt` is a PostgreSQL extension designed to handle authenticated sessions through a JWT. This JWT is then verified against a JWK (JSON Web Key) to ensure its authenticity. + +**JWK can only be set at postmaster startup, from the configuration file, or by client request in the connection startup packet** (e.g., from libpq's PGOPTIONS variable), whereas the JWT can be set anytime at runtime. The extension then stores the JWT in the database for later retrieval, and exposes functions to retrieve the user ID (the `sub` subject field) and other parts of the payload. The goal of this extension is to provide a secure and efficient way to manage authenticated sessions in a PostgreSQL database. The JWTs can be generated by third-party auth providers, and then developers can leverage the JWT for [Row Level Security](https://www.postgresql.org/docs/current/ddl-rowsecurity.html) (RLS) policies, or to retrieve the user ID for other purposes (column defaults, filters, etc.). @@ -20,15 +22,23 @@ Features Usage ----- +Before calling functions make sure that `neon.auth.jwk` parameter is properly initialized. [libpq connect options](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-OPTIONS) can be used for that. + +For example: +```console +MY_JWK=... +export PGOPTIONS="-c neon.auth.jwk=$MY_JWK" +``` + `pg_session_jwt` exposes four main functions: -### 1\. auth.init(kid bigint, jwk jsonb) → void +### 1\. auth.init() → void -Initializes a session with a given key identifier (KID) and JWK data in JSONB format. +Initializes a session using JWK stored in `neon.auth.jwk` [run-time parameter](https://www.postgresql.org/docs/current/sql-show.html). Please remember that this parameter is fixed for a given connection once it's started (but it can vary across different connections) ### 2\. auth.jwt\_session\_init(jwt text) → void -Initializes the JWT session with the provided `jwt` as a string. +Initializes the JWT session with the provided `jwt` as a string. JWT must be signed by the JWK that was initialized with `auth.init()` ### 3\. auth.session(s text) → jsonb diff --git a/pg_session_jwt.control b/pg_session_jwt.control index 3745faa..825e99c 100644 --- a/pg_session_jwt.control +++ b/pg_session_jwt.control @@ -2,4 +2,4 @@ comment = 'pg_session_jwt: manage authentication sessions using JWTs' default_version = '@CARGO_VERSION@' module_pathname = '$libdir/pg_session_jwt' relocatable = false -superuser = true +superuser = false diff --git a/pgrx-tests/.gitignore b/pgrx-tests/.gitignore new file mode 100644 index 0000000..ab3ae30 --- /dev/null +++ b/pgrx-tests/.gitignore @@ -0,0 +1,7 @@ +.DS_Store +.idea/ +target/ +*.iml +**/*.rs.bk +Cargo.lock +sql/pgrx_tests-1.0.sql diff --git a/pgrx-tests/Cargo.toml b/pgrx-tests/Cargo.toml new file mode 100644 index 0000000..78c7cd3 --- /dev/null +++ b/pgrx-tests/Cargo.toml @@ -0,0 +1,72 @@ +#LICENSE Portions Copyright 2019-2021 ZomboDB, LLC. +#LICENSE +#LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc. +#LICENSE +#LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. +#LICENSE +#LICENSE Portions Copyright 2024-2024 Neon, Inc. +#LICENSE +#LICENSE All rights reserved. +#LICENSE +#LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +[package] +name = "pgrx-tests" +version = "0.11.3" +authors = ["PgCentral Foundation, Inc. "] +license = "MIT" +description = "Test framework for 'pgrx'-based Postgres extensions" +homepage = "https://github.com/pgcentralfoundation/pgrx/" +repository = "https://github.com/pgcentralfoundation/pgrx/" +documentation = "https://docs.rs/pgrx-tests" +readme = "README.md" +edition = "2021" + +[lib] +crate-type = ["cdylib", "lib"] + +[features] +default = ["proptest"] +pg11 = ["pgrx/pg11"] +pg12 = ["pgrx/pg12"] +pg13 = ["pgrx/pg13"] +pg14 = ["pgrx/pg14"] +pg15 = ["pgrx/pg15"] +pg16 = ["pgrx/pg16"] +pg_test = [] +proptest = ["dep:proptest"] +cshim = ["pgrx/cshim"] +no-schema-generation = [ + "pgrx/no-schema-generation", + "pgrx-macros/no-schema-generation", +] + +[package.metadata.docs.rs] +features = ["pg14", "proptest"] +no-default-features = true +targets = ["x86_64-unknown-linux-gnu"] +# Enable `#[cfg(docsrs)]` (https://docs.rs/about/builds#cross-compiling) +rustc-args = ["--cfg", "docsrs"] +rustdoc-args = ["--cfg", "docsrs"] + +[dependencies] +clap-cargo = "0.11.0" +owo-colors = "3.5" +once_cell = "1.18.0" +libc = "0.2.149" +pgrx = "=0.11.3" +pgrx-macros = "=0.11.3" +pgrx-pg-config = "=0.11.3" +postgres = "0.19.7" +proptest = { version = "1", optional = true } +regex = "1.10.0" +serde = "1.0" +serde_json = "1.0" +sysinfo = "0.29.10" +eyre = "0.6.8" +thiserror = "1.0" +rand = "0.8.5" + +[dev-dependencies] +eyre = "0.6.8" # testing functions that return `eyre::Result` +trybuild = "1" diff --git a/pgrx-tests/LICENSE b/pgrx-tests/LICENSE new file mode 100644 index 0000000..03632a3 --- /dev/null +++ b/pgrx-tests/LICENSE @@ -0,0 +1,26 @@ +MIT License + +Portions Copyright 2019-2021 ZomboDB, LLC. +Portions Copyright 2021-2023 Technology Concepts & Design, Inc. +Portions Copyright 2023 PgCentral Foundation, Inc. +Portions Copyright 2024 Neon, Inc. + +All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/pgrx-tests/README.md b/pgrx-tests/README.md new file mode 100644 index 0000000..befa98b --- /dev/null +++ b/pgrx-tests/README.md @@ -0,0 +1,9 @@ +# pgrx-tests + +Test framework for [`pgrx`](https://crates.io/crates/pgrx/). + +Meant to be used as one of your `[dev-dependencies]` when using `pgrx`. + +Forked off of pgrx 0.11.3 by Conrad Ludgate for the purposes of adding support for +1. Providing options used for initialising GucContext::Backend +2. Running tests as non-superuser diff --git a/pgrx-tests/src/framework.rs b/pgrx-tests/src/framework.rs new file mode 100644 index 0000000..cfc1ad6 --- /dev/null +++ b/pgrx-tests/src/framework.rs @@ -0,0 +1,886 @@ +//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC. +//LICENSE +//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc. +//LICENSE +//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. +//LICENSE +//LICENSE Portions Copyright 2024-2024 Neon, Inc. +//LICENSE +//LICENSE All rights reserved. +//LICENSE +//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file. +use std::collections::HashSet; +use std::process::{Command, Stdio}; + +use eyre::{eyre, WrapErr}; +use once_cell::sync::Lazy; +use owo_colors::OwoColorize; +use pgrx::prelude::*; +use pgrx_pg_config::{ + cargo::PgrxManifestExt, createdb, get_c_locale_flags, get_target_dir, PgConfig, Pgrx, +}; +use postgres::error::DbError; +use std::collections::HashMap; +use std::io::{BufRead, BufReader, Write}; +use std::path::PathBuf; +use std::sync::{Arc, Mutex}; +use std::time::Duration; +use sysinfo::{Pid, ProcessExt, System, SystemExt}; + +mod shutdown; +use shutdown::add_shutdown_hook; + +type LogLines = Arc>>>; + +struct SetupState { + installed: bool, + loglines: LogLines, + system_session_id: String, +} + +static TEST_MUTEX: Lazy> = Lazy::new(|| { + Mutex::new(SetupState { + installed: false, + loglines: Arc::new(Mutex::new(HashMap::new())), + system_session_id: "NONE".to_string(), + }) +}); + +// The goal of this closure is to allow "wrapping" of anything that might issue +// an SQL simple_query or query using either a postgres::Client or +// postgres::Transaction and capture the output. The use of this wrapper is +// completely optional, but it might help narrow down some errors later on. +fn query_wrapper( + query: Option, + query_params: Option<&[&(dyn postgres::types::ToSql + Sync)]>, + mut f: F, +) -> eyre::Result +where + T: IntoIterator, + F: FnMut( + Option, + Option<&[&(dyn postgres::types::ToSql + Sync)]>, + ) -> Result, +{ + let result = f(query.clone(), query_params.clone()); + + match result { + Ok(result) => Ok(result), + Err(e) => { + if let Some(dberror) = e.as_db_error() { + let query = query.unwrap(); + let query_message = dberror.message(); + + let code = dberror.code().code(); + let severity = dberror.severity(); + + let mut message = format!("{} SQLSTATE[{}]", severity, code) + .bold() + .red() + .to_string(); + + message.push_str(format!(": {}", query_message.bold().white()).as_str()); + message.push_str(format!("\nquery: {}", query.bold().white()).as_str()); + message.push_str( + format!( + "\nparams: {}", + match query_params { + Some(params) => format!("{:?}", params), + None => "None".to_string(), + } + ) + .as_str(), + ); + + if let Ok(var) = std::env::var("RUST_BACKTRACE") { + if var.eq("1") { + let detail = dberror.detail().unwrap_or("None"); + let hint = dberror.hint().unwrap_or("None"); + let schema = dberror.hint().unwrap_or("None"); + let table = dberror.table().unwrap_or("None"); + let more_info = format!( + "\ndetail: {detail}\nhint: {hint}\nschema: {schema}\ntable: {table}" + ); + message.push_str(more_info.as_str()); + } + } + + Err(eyre!(message)) + } else { + return Err(e).wrap_err("non-DbError"); + } + } + } +} + +pub fn run_test( + options: Option<&str>, + expected_error: Option<&str>, + postgresql_conf: Vec<&'static str>, + queries: impl for<'a> FnOnce(&'a mut postgres::Client) -> Result<(), postgres::Error>, +) -> eyre::Result<()> { + if std::env::var_os("PGRX_TEST_SKIP").unwrap_or_default() != "" { + eprintln!("Skipping test because `PGRX_TEST_SKIP` is set in the environment",); + return Ok(()); + } + let (loglines, system_session_id) = initialize_test_framework(postgresql_conf)?; + + { + let (mut client, _) = client(None, &get_pg_user())?; + + let resp = client + .query_opt("SELECT rolname FROM pg_roles WHERE rolname = 'pgrx'", &[]) + .unwrap(); + + if resp.is_none() { + client + .execute("CREATE ROLE pgrx WITH NOSUPERUSER LOGIN", &[]) + .unwrap(); + } else { + client + .execute("ALTER ROLE pgrx WITH NOSUPERUSER LOGIN", &[]) + .unwrap(); + } + + client + .execute("GRANT USAGE ON SCHEMA auth TO pgrx", &[]) + .unwrap(); + } + + let (mut client, session_id) = client(options, "pgrx")?; + let result = queries(&mut client); + + if let Err(e) = result { + let error_as_string = format!("{e}"); + let cause = e.into_source(); + + let (pg_location, rust_location, message) = + if let Some(Some(dberror)) = cause.map(|e| e.downcast_ref::().cloned()) { + let received_error_message = dberror.message(); + + if Some(received_error_message) == expected_error { + // the error received is the one we expected, so just return if they match + return Ok(()); + } + + let pg_location = dberror.file().unwrap_or("").to_string(); + let rust_location = dberror.where_().unwrap_or("").to_string(); + + ( + pg_location, + rust_location, + received_error_message.to_string(), + ) + } else { + ( + "".to_string(), + "".to_string(), + format!("{error_as_string}"), + ) + }; + + // wait a second for Postgres to get log messages written to stderr + std::thread::sleep(std::time::Duration::from_millis(1000)); + + let system_loglines = format_loglines(&system_session_id, &loglines); + let session_loglines = format_loglines(&session_id, &loglines); + panic!( + "\n\nPostgres Messages:\n{system_loglines}\n\nTest Function Messages:\n{session_loglines}\n\nClient Error:\n{message}\npostgres location: {pg_location}\nrust location: {rust_location}\n\n", + system_loglines = system_loglines.dimmed().white(), + session_loglines = session_loglines.cyan(), + message = message.bold().red(), + pg_location = pg_location.dimmed().white(), + rust_location = rust_location.yellow() + ); + } else if let Some(message) = expected_error { + // we expected an ERROR, but didn't get one + return Err(eyre!("Expected error: {message}")); + } else { + Ok(()) + } +} + +fn format_loglines(session_id: &str, loglines: &LogLines) -> String { + let mut result = String::new(); + + for line in loglines + .lock() + .unwrap() + .entry(session_id.to_string()) + .or_default() + .iter() + { + result.push_str(line); + result.push('\n'); + } + + result +} + +fn initialize_test_framework( + postgresql_conf: Vec<&'static str>, +) -> eyre::Result<(LogLines, String)> { + let mut state = TEST_MUTEX.lock().unwrap_or_else(|_| { + // This used to immediately throw an std::process::exit(1), but it + // would consume both stdout and stderr, resulting in error messages + // not being displayed unless you were running tests with --nocapture. + panic!( + "Could not obtain test mutex. A previous test may have hard-aborted while holding it." + ); + }); + + if !state.installed { + shutdown::register_shutdown_hook(); + install_extension()?; + initdb(postgresql_conf)?; + + let system_session_id = start_pg(state.loglines.clone())?; + let pg_config = get_pg_config()?; + dropdb()?; + createdb(&pg_config, get_pg_dbname(), true, false)?; + create_extension()?; + state.installed = true; + state.system_session_id = system_session_id; + } + + Ok((state.loglines.clone(), state.system_session_id.clone())) +} + +fn get_pg_config() -> eyre::Result { + let pgrx = Pgrx::from_config().wrap_err("Unable to get PGRX from config")?; + + let pg_version = pg_sys::get_pg_major_version_num(); + + let pg_config = pgrx + .get(&format!("pg{}", pg_version)) + .wrap_err_with(|| { + format!( + "Error getting pg_config: {} is not a valid postgres version", + pg_version + ) + }) + .unwrap() + .clone(); + + Ok(pg_config) +} + +fn client(options: Option<&str>, user: &str) -> eyre::Result<(postgres::Client, String)> { + let pg_config = get_pg_config()?; + + let mut config = postgres::Config::new(); + + config + .host(pg_config.host()) + .port( + pg_config + .test_port() + .expect("unable to determine test port"), + ) + .user(user) + .dbname(&get_pg_dbname()); + + if let Some(options) = options { + config.options(options); + } + + let mut client = config + .connect(postgres::NoTls) + .wrap_err("Error connecting to Postgres")?; + + let sid_query_result = query_wrapper( + Some("SELECT to_hex(trunc(EXTRACT(EPOCH FROM backend_start))::integer) || '.' || to_hex(pid) AS sid FROM pg_stat_activity WHERE pid = pg_backend_pid();".to_string()), + Some(&[]), + |query, query_params| client.query(&query.unwrap(), query_params.unwrap()), + ) + .wrap_err("There was an issue attempting to get the session ID from Postgres")?; + + let session_id = match sid_query_result.get(0) { + Some(row) => row.get::<&str, &str>("sid").to_string(), + None => Err(eyre!("Failed to obtain a client Session ID from Postgres"))?, + }; + + if user != "pgrx" { + query_wrapper( + Some("SET log_min_messages TO 'INFO';".to_string()), + None, + |query, _| client.simple_query(query.unwrap().as_str()), + ) + .wrap_err("Postgres Client setup failed to SET log_min_messages TO 'INFO'")?; + + query_wrapper( + Some("SET log_min_duration_statement TO 1000;".to_string()), + None, + |query, _| client.simple_query(query.unwrap().as_str()), + ) + .wrap_err("Postgres Client setup failed to SET log_min_duration_statement TO 1000;")?; + + query_wrapper( + Some("SET log_statement TO 'all';".to_string()), + None, + |query, _| client.simple_query(query.unwrap().as_str()), + ) + .wrap_err("Postgres Client setup failed to SET log_statement TO 'all';")?; + } + + Ok((client, session_id)) +} + +fn install_extension() -> eyre::Result<()> { + eprintln!("installing extension"); + let profile = std::env::var("PGRX_BUILD_PROFILE").unwrap_or("debug".into()); + let no_schema = std::env::var("PGRX_NO_SCHEMA").unwrap_or("false".into()) == "true"; + let mut features = std::env::var("PGRX_FEATURES") + .unwrap_or("".to_string()) + .split_ascii_whitespace() + .map(|s| s.to_string()) + .collect::>(); + features.insert("pg_test".into()); + + let no_default_features = + std::env::var("PGRX_NO_DEFAULT_FEATURES").unwrap_or("false".to_string()) == "true"; + let all_features = std::env::var("PGRX_ALL_FEATURES").unwrap_or("false".to_string()) == "true"; + + let pg_version = format!("pg{}", pg_sys::get_pg_major_version_string()); + let pgrx = Pgrx::from_config()?; + let pg_config = pgrx.get(&pg_version)?; + let cargo_test_args = get_cargo_test_features()?; + println!("detected cargo args: {:?}", cargo_test_args); + + features.extend(cargo_test_args.features.iter().cloned()); + + let mut command = cargo_pgrx(); + command + .arg("install") + .arg("--test") + .arg("--pg-config") + .arg(pg_config.path().ok_or(eyre!("No pg_config found"))?) + .stdout(Stdio::inherit()) + .stderr(Stdio::piped()) + .env("CARGO_TARGET_DIR", get_target_dir()?); + + if let Ok(manifest_path) = std::env::var("PGRX_MANIFEST_PATH") { + command.arg("--manifest-path"); + command.arg(manifest_path); + } + + if let Ok(rust_log) = std::env::var("RUST_LOG") { + command.env("RUST_LOG", rust_log); + } + + if !features.is_empty() { + command.arg("--features"); + command.arg(features.into_iter().collect::>().join(" ")); + } + + if no_default_features || cargo_test_args.no_default_features { + command.arg("--no-default-features"); + } + + if all_features || cargo_test_args.all_features { + command.arg("--all-features"); + } + + match profile.trim() { + // For legacy reasons, cargo has two names for the debug profile... (We + // also ignore the empty string here, just in case). + "debug" | "dev" | "" => {} + "release" => { + command.arg("--release"); + } + profile => { + command.args(["--profile", profile]); + } + } + + if no_schema { + command.arg("--no-schema"); + } + + let command_str = format!("{:?}", command); + + let child = command.spawn().wrap_err_with(|| { + format!( + "Failed to spawn process for installing extension using command: '{}': ", + command_str + ) + })?; + + let output = child.wait_with_output().wrap_err_with(|| { + format!( + "Failed waiting for spawned process attempting to install extension using command: '{}': ", + command_str + ) + })?; + + if !output.status.success() { + return Err(eyre!( + "Failure installing extension using command: {}\n\n{}{}", + command_str, + String::from_utf8(output.stdout).unwrap(), + String::from_utf8(output.stderr).unwrap() + )); + } + + Ok(()) +} + +fn initdb(postgresql_conf: Vec<&'static str>) -> eyre::Result<()> { + let pgdata = get_pgdata_path()?; + + if !pgdata.is_dir() { + let pg_config = get_pg_config()?; + let mut command = Command::new( + pg_config + .initdb_path() + .wrap_err("unable to determine initdb path")?, + ); + + command + .args(get_c_locale_flags()) + .arg("-D") + .arg(pgdata.to_str().unwrap()) + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()); + + let command_str = format!("{:?}", command); + + let child = command.spawn().wrap_err_with(|| { + format!( + "Failed to spawn process for initializing database using command: '{}': ", + command_str + ) + })?; + + let output = child.wait_with_output().wrap_err_with(|| { + format!( + "Failed waiting for spawned process attempting to initialize database using command: '{}': ", + command_str + ) + })?; + + if !output.status.success() { + return Err(eyre!( + "Failed to initialize database using command: {}\n\n{}{}", + command_str, + String::from_utf8(output.stdout).unwrap(), + String::from_utf8(output.stderr).unwrap() + )); + } + } + + modify_postgresql_conf(pgdata, postgresql_conf) +} + +fn modify_postgresql_conf(pgdata: PathBuf, postgresql_conf: Vec<&'static str>) -> eyre::Result<()> { + let mut postgresql_conf_file = std::fs::OpenOptions::new() + .write(true) + .truncate(true) + .open(format!("{}/postgresql.auto.conf", pgdata.display())) + .wrap_err("couldn't open postgresql.auto.conf")?; + postgresql_conf_file + .write_all("log_line_prefix='[%m] [%p] [%c]: '\n".as_bytes()) + .wrap_err("couldn't append log_line_prefix")?; + + for setting in postgresql_conf { + postgresql_conf_file + .write_all(format!("{setting}\n").as_bytes()) + .wrap_err("couldn't append custom setting to postgresql.conf")?; + } + + postgresql_conf_file + .write_all( + format!( + "unix_socket_directories = '{}'", + Pgrx::home().unwrap().display() + ) + .as_bytes(), + ) + .wrap_err("couldn't append `unix_socket_directories` setting to postgresql.conf")?; + Ok(()) +} + +fn start_pg(loglines: LogLines) -> eyre::Result { + wait_for_pidfile()?; + + let pg_config = get_pg_config()?; + let postmaster_path = pg_config + .postmaster_path() + .wrap_err("unable to determine postmaster path")?; + + let mut command = if use_valgrind() { + let mut cmd = Command::new("valgrind"); + cmd.args([ + "--leak-check=no", + "--gen-suppressions=all", + "--time-stamp=yes", + "--error-markers=VALGRINDERROR-BEGIN,VALGRINDERROR-END", + "--trace-children=yes", + ]); + // Try to provide a suppressions file, we'll likely get false positives + // if we can't, but that might be better than nothing. + if let Ok(path) = valgrind_suppressions_path(&pg_config) { + if path.exists() { + cmd.arg(format!("--suppressions={}", path.display())); + } + } + + cmd.arg(postmaster_path); + cmd + } else { + Command::new(postmaster_path) + }; + command + .arg("-D") + .arg(get_pgdata_path()?.to_str().unwrap()) + .arg("-h") + .arg(pg_config.host()) + .arg("-p") + .arg( + pg_config + .test_port() + .expect("unable to determine test port") + .to_string(), + ) + // Redirecting logs to files can hang the test framework, override it + .args([ + "-c", + "log_destination=stderr", + "-c", + "logging_collector=off", + ]) + .stdout(Stdio::inherit()) + .stderr(Stdio::piped()); + + let command_str = format!("{command:?}"); + + // start Postgres and monitor its stderr in the background + // also notify the main thread when it's ready to accept connections + let session_id = monitor_pg(command, command_str, loglines); + + Ok(session_id) +} + +fn valgrind_suppressions_path(pg_config: &PgConfig) -> Result { + let mut home = Pgrx::home()?; + home.push(pg_config.version()?); + home.push("src/tools/valgrind.supp"); + Ok(home) +} + +fn wait_for_pidfile() -> Result<(), eyre::Report> { + const MAX_PIDFILE_RETRIES: usize = 10; + + let pidfile = get_pid_file()?; + + let mut retries = 0; + while pidfile.exists() { + if retries > MAX_PIDFILE_RETRIES { + // break out and try to start postgres anyways, maybe it'll report a decent error about what's going on + eprintln!("`{}` has existed for ~10s. There might be some problem with the pgrx testing Postgres instance", pidfile.display()); + break; + } + eprintln!("`{}` still exists. Waiting...", pidfile.display()); + std::thread::sleep(Duration::from_secs(1)); + retries += 1; + } + Ok(()) +} + +fn monitor_pg(mut command: Command, cmd_string: String, loglines: LogLines) -> String { + let (sender, receiver) = std::sync::mpsc::channel(); + + std::thread::spawn(move || { + let mut child = command.spawn().expect("postmaster didn't spawn"); + + let pid = child.id(); + // Add a shutdown hook so we can terminate it when the test framework + // exits. TODO: Consider finding a way to handle cases where we fail to + // clean up due to a SIGNAL? + add_shutdown_hook(move || unsafe { + libc::kill(pid as libc::pid_t, libc::SIGTERM); + let message_string = std::ffi::CString::new( + format!("stopping postgres (pid={pid})\n") + .bold() + .blue() + .to_string(), + ) + .unwrap(); + // IMPORTANT: Rust string literals are not naturally null-terminated + libc::printf("%s\0".as_ptr().cast(), message_string.as_ptr()); + }); + + eprintln!( + "{cmd}\npid={p}", + cmd = cmd_string.bold().blue(), + p = pid.to_string().yellow() + ); + eprintln!("{}", pg_sys::get_pg_version_string().bold().purple()); + + // wait for the database to say its ready to start up + let reader = BufReader::new( + child + .stderr + .take() + .expect("couldn't take postmaster stderr"), + ); + + let regex = regex::Regex::new(r#"\[.*?\] \[.*?\] \[(?P.*?)\]"#).unwrap(); + let mut is_started_yet = false; + let mut lines = reader.lines(); + while let Some(Ok(line)) = lines.next() { + let session_id = match get_named_capture(®ex, "session_id", &line) { + Some(sid) => sid, + None => "NONE".to_string(), + }; + + if line.contains("database system is ready to accept connections") { + // Postgres says it's ready to go + sender.send(session_id.clone()).unwrap(); + is_started_yet = true; + } + + if !is_started_yet || line.contains("TMSG: ") { + eprintln!("{}", line.cyan()); + } + + // if line.contains("INFO: ") { + // eprintln!("{}", line.cyan()); + // } else if line.contains("WARNING: ") { + // eprintln!("{}", line.bold().yellow()); + // } else if line.contains("ERROR: ") { + // eprintln!("{}", line.bold().red()); + // } else if line.contains("statement: ") || line.contains("duration: ") { + // eprintln!("{}", line.bold().blue()); + // } else if line.contains("LOG: ") { + // eprintln!("{}", line.dimmed().white()); + // } else { + // eprintln!("{}", line.bold().purple()); + // } + + let mut loglines = loglines.lock().unwrap(); + let session_lines = loglines.entry(session_id).or_insert_with(Vec::new); + session_lines.push(line); + } + + // wait for Postgres to really finish + match child.try_wait() { + Ok(status) => { + if let Some(_status) = status { + // we exited normally + } + } + Err(e) => panic!("was going to let Postgres finish, but errored this time:\n{e}"), + } + }); + + // wait for Postgres to indicate it's ready to accept connection + // and return its pid when it is + receiver.recv().expect("Postgres failed to start") +} + +fn dropdb() -> eyre::Result<()> { + let pg_config = get_pg_config()?; + let output = Command::new( + pg_config + .dropdb_path() + .expect("unable to determine dropdb path"), + ) + .env_remove("PGDATABASE") + .env_remove("PGHOST") + .env_remove("PGPORT") + .env_remove("PGUSER") + .arg("--if-exists") + .arg("-h") + .arg(pg_config.host()) + .arg("-p") + .arg( + pg_config + .test_port() + .expect("unable to determine test port") + .to_string(), + ) + .arg(get_pg_dbname()) + .output() + .unwrap(); + + if !output.status.success() { + // maybe the database didn't exist, and if so that's okay + let stderr = String::from_utf8_lossy(output.stderr.as_slice()); + if !stderr.contains(&format!( + "ERROR: database \"{}\" does not exist", + get_pg_dbname() + )) { + // got some error we didn't expect + let stdout = String::from_utf8_lossy(output.stdout.as_slice()); + eprintln!("unexpected error (stdout):\n{stdout}"); + eprintln!("unexpected error (stderr):\n{stderr}"); + panic!("failed to drop test database"); + } + } + + Ok(()) +} + +fn create_extension() -> eyre::Result<()> { + let (mut client, _) = client(None, &get_pg_user())?; + let extension_name = get_extension_name()?; + + query_wrapper( + Some(format!("CREATE EXTENSION {} CASCADE;", &extension_name)), + None, + |query, _| client.simple_query(query.unwrap().as_str()), + ) + .wrap_err(format!( + "There was an issue creating the extension '{}' in Postgres: ", + &extension_name + ))?; + + Ok(()) +} + +fn get_extension_name() -> eyre::Result { + // We could replace this with the following if cargo adds the lib name on env var on tests/runs. + // https://github.com/rust-lang/cargo/issues/11966 + // std::env::var("CARGO_LIB_NAME") + // .unwrap_or_else(|_| panic!("CARGO_LIB_NAME environment var is unset or invalid UTF-8")) + // .replace("-", "_") + + // CARGO_MANIFEST_DIRR — The directory containing the manifest of your package. + // https://doc.rust-lang.org/cargo/reference/environment-variables.html#environment-variables-cargo-sets-for-crates + let dir = std::env::var("CARGO_MANIFEST_DIR") + .map_err(|_| eyre!("CARGO_MANIFEST_DIR environment var is unset or invalid UTF-8"))?; + + // Cargo.toml is case sensitive atm so this is ok. + // https://github.com/rust-lang/cargo/issues/45 + let path = PathBuf::from(dir).join("Cargo.toml"); + let name = pgrx_pg_config::cargo::read_manifest(path)?.lib_name()?; + Ok(name.replace("-", "_")) +} + +fn get_pgdata_path() -> eyre::Result { + let mut target_dir = get_target_dir()?; + target_dir.push(&format!( + "pgrx-test-data-{}", + pg_sys::get_pg_major_version_num() + )); + Ok(target_dir) +} + +fn get_pid_file() -> eyre::Result { + let mut pgdata = get_pgdata_path()?; + pgdata.push("postmaster.pid"); + return Ok(pgdata); +} + +pub(crate) fn get_pg_dbname() -> &'static str { + "pgrx_tests" +} + +pub(crate) fn get_pg_user() -> String { + std::env::var("USER") + .unwrap_or_else(|_| panic!("USER environment var is unset or invalid UTF-8")) +} + +fn get_named_capture(regex: ®ex::Regex, name: &'static str, against: &str) -> Option { + match regex.captures(against) { + Some(cap) => Some(cap[name].to_string()), + None => None, + } +} + +fn get_cargo_test_features() -> eyre::Result { + let mut features = clap_cargo::Features::default(); + let cargo_user_args = get_cargo_args(); + let mut iter = cargo_user_args.iter(); + while let Some(part) = iter.next() { + match part.as_str() { + "--no-default-features" => features.no_default_features = true, + "--features" => { + let configured_features = iter.next().ok_or(eyre!( + "no `--features` specified in the cargo argument list: {:?}", + cargo_user_args + ))?; + features.features = configured_features + .split(|c: char| c.is_ascii_whitespace() || c == ',') + .map(|s| s.to_string()) + .collect(); + } + "--all-features" => features.all_features = true, + _ => {} + } + } + + Ok(features) +} + +fn get_cargo_args() -> Vec { + // setup the sysinfo crate's "System" + let mut system = System::new_all(); + system.refresh_all(); + + // starting with our process, look for the full set of arguments for the top-most "cargo" command + // in our process tree. + // + // it's possible we've been called by: + // - the user from the command-line via `cargo test ...` + // - `cargo pgrx test ...` + // - `cargo test ...` + // - some other combination with a `cargo ...` in the middle, perhaps + // + // we're interested in the first arguments the **user** gave to cargo, so `framework.rs` + // can later figure out which set of features to pass to `cargo pgrx` + let mut pid = Pid::from(std::process::id() as usize); + while let Some(process) = system.process(pid) { + // only if it's "cargo"... (This works for now, but just because `cargo` + // is at the end of the path. How *should* this handle `CARGO`?) + if process.exe().ends_with("cargo") { + // ... and only if it's "cargo test"... + if process.cmd().iter().any(|arg| arg == "test") + && !process.cmd().iter().any(|arg| arg == "pgrx") + { + // ... do we want its args + return process.cmd().iter().cloned().collect(); + } + } + + // and we want to keep going to find the top-most "cargo" process in our tree + match process.parent() { + Some(parent_pid) => pid = parent_pid, + None => break, + } + } + + Vec::new() +} + +// TODO: this would be a good place to insert a check invoking to see if +// `cargo-pgrx` is a crate in the local workspace, and use it instead. +fn cargo_pgrx() -> std::process::Command { + fn var_path(s: &str) -> Option { + std::env::var_os(s).map(PathBuf::from) + } + // Use `CARGO_PGRX` (set by `cargo-pgrx` on first run), then fall back to + // `cargo-pgrx` if it is on the path, then `$CARGO pgrx` + let cargo_pgrx = var_path("CARGO_PGRX") + .or_else(|| find_on_path("cargo-pgrx")) + .or_else(|| var_path("CARGO")) + .unwrap_or_else(|| "cargo".into()); + let mut cmd = std::process::Command::new(cargo_pgrx); + cmd.arg("pgrx"); + cmd +} + +fn find_on_path(program: &str) -> Option { + assert!(!program.contains('/')); + // Technically we should check `libc::confstr(libc::_CS_PATH)` + // when `PATH` is unset... + let paths = std::env::var_os("PATH")?; + std::env::split_paths(&paths) + .map(|p| p.join(program)) + .find(|abs| abs.exists()) +} + +fn use_valgrind() -> bool { + std::env::var_os("USE_VALGRIND").is_some_and(|s| s.len() > 0) +} diff --git a/pgrx-tests/src/framework/shutdown.rs b/pgrx-tests/src/framework/shutdown.rs new file mode 100644 index 0000000..e2168ca --- /dev/null +++ b/pgrx-tests/src/framework/shutdown.rs @@ -0,0 +1,130 @@ +//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC. +//LICENSE +//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc. +//LICENSE +//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. +//LICENSE +//LICENSE All rights reserved. +//LICENSE +//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file. +use std::panic::{self, AssertUnwindSafe, Location}; +use std::sync::{Mutex, PoisonError}; +use std::{any, io, mem, process}; + +/// Register a shutdown hook to be called when the process exits. +/// +/// Note that shutdown hooks are only run on the client, so must be added from +/// your `setup` function, not the `#[pg_test]` itself. +#[track_caller] +pub fn add_shutdown_hook(func: F) +where + F: Send + 'static, +{ + SHUTDOWN_HOOKS + .lock() + .unwrap_or_else(PoisonError::into_inner) + .push(ShutdownHook { + source: Location::caller(), + callback: Box::new(func), + }); +} + +pub(super) fn register_shutdown_hook() { + unsafe { + libc::atexit(run_shutdown_hooks); + } +} + +/// The `atexit` callback. +/// +/// If we panic from `atexit`, we end up causing `exit` to unwind. Unwinding +/// from a nounwind + noreturn function can cause some destructors to run twice, +/// causing (for example) libtest to SIGSEGV. +/// +/// This ends up looking like a memory bug in either `pgrx` or the user code, and +/// is very hard to track down, so we go to some lengths to prevent it. +/// Essentially: +/// +/// - Panics in each user hook are caught and reported. +/// - As a stop-gap an abort-on-drop panic guard is used to ensure there isn't a +/// place we missed. +/// +/// We also write to stderr directly instead, since otherwise our output will +/// sometimes be redirected. +extern "C" fn run_shutdown_hooks() { + let guard = PanicGuard; + let mut any_panicked = false; + let mut hooks = SHUTDOWN_HOOKS + .lock() + .unwrap_or_else(PoisonError::into_inner); + // Note: run hooks in the opposite order they were registered. + for hook in mem::take(&mut *hooks).into_iter().rev() { + any_panicked |= hook.run().is_err(); + } + if any_panicked { + write_stderr("error: one or more shutdown hooks panicked (see `stderr` for details).\n"); + std::process::abort() + } + mem::forget(guard); +} + +/// Prevent panics in a block of code. +/// +/// Prints a message and aborts in its drop. Intended usage is like: +/// ```ignore +/// let guard = PanicGuard; +/// // ...code that absolutely must never unwind goes here... +/// core::mem::forget(guard); +/// ``` +struct PanicGuard; +impl Drop for PanicGuard { + fn drop(&mut self) { + write_stderr("Failed to catch panic in the `atexit` callback, aborting!\n"); + process::abort(); + } +} + +static SHUTDOWN_HOOKS: Mutex> = Mutex::new(Vec::new()); + +struct ShutdownHook { + source: &'static Location<'static>, + callback: Box, +} + +impl ShutdownHook { + fn run(self) -> Result<(), ()> { + let Self { source, callback } = self; + let result = panic::catch_unwind(AssertUnwindSafe(callback)); + if let Err(e) = result { + let msg = failure_message(&e); + write_stderr(&format!( + "error: shutdown hook (registered at {source}) panicked: {msg}\n" + )); + Err(()) + } else { + Ok(()) + } + } +} + +fn failure_message(e: &(dyn any::Any + Send)) -> &str { + if let Some(&msg) = e.downcast_ref::<&'static str>() { + msg + } else if let Some(msg) = e.downcast_ref::() { + msg.as_str() + } else { + "" + } +} + +/// Write to stderr, bypassing libtest's output redirection. Doesn't append `\n`. +fn write_stderr(s: &str) { + loop { + let res = unsafe { libc::write(libc::STDERR_FILENO, s.as_ptr().cast(), s.len()) }; + // Handle EINTR to ensure we don't drop messages. + // `Error::last_os_error()` just reads from errno, so it's fine to use here. + if res >= 0 || io::Error::last_os_error().kind() != io::ErrorKind::Interrupted { + break; + } + } +} diff --git a/pgrx-tests/src/lib.rs b/pgrx-tests/src/lib.rs new file mode 100644 index 0000000..86905f8 --- /dev/null +++ b/pgrx-tests/src/lib.rs @@ -0,0 +1,11 @@ +//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC. +//LICENSE +//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc. +//LICENSE +//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. +//LICENSE +//LICENSE All rights reserved. +//LICENSE +//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file. +mod framework; +pub use framework::*; diff --git a/src/gucs.rs b/src/gucs.rs new file mode 100644 index 0000000..bfafc73 --- /dev/null +++ b/src/gucs.rs @@ -0,0 +1,29 @@ +use pgrx::*; +use std::ffi::CStr; + +pub static NEON_AUTH_JWK_RUNTIME_PARAM: &str = "neon.auth.jwk"; +pub static NEON_AUTH_JWK: GucSetting> = + GucSetting::>::new(None); +pub static NEON_AUTH_JWT_RUNTIME_PARAM: &str = "neon.auth.jwt"; +pub static NEON_AUTH_JWT: GucSetting> = + GucSetting::>::new(None); + +pub fn init() { + GucRegistry::define_string_guc( + NEON_AUTH_JWK_RUNTIME_PARAM, + "JSON Web Key (JWK) used for JWT validation", + "Generated per connection by Neon local proxy", + &NEON_AUTH_JWK, + GucContext::Backend, + GucFlags::NOT_WHILE_SEC_REST | GucFlags::NO_RESET_ALL, + ); + + GucRegistry::define_string_guc( + NEON_AUTH_JWT_RUNTIME_PARAM, + "JSON Web Token (JWT) used for query authorization", + "Represents authenticated user session related claims like user ID", + &NEON_AUTH_JWT, + GucContext::Userset, + GucFlags::NOT_WHILE_SEC_REST, + ); +} diff --git a/src/lib.rs b/src/lib.rs index d32bd8a..762713b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +mod gucs; + use pgrx::prelude::*; pgrx::pg_module_magic!(); @@ -13,6 +15,12 @@ macro_rules! error_code { }}; } +#[allow(non_snake_case)] +#[pg_guard] +pub unsafe extern "C" fn _PG_init() { + gucs::init(); +} + #[pg_schema] pub mod auth { use std::cell::{OnceCell, RefCell}; @@ -27,20 +35,18 @@ pub mod auth { use pgrx::JsonB; use serde::de::DeserializeOwned; + use crate::gucs::{ + NEON_AUTH_JWK, NEON_AUTH_JWK_RUNTIME_PARAM, NEON_AUTH_JWT, NEON_AUTH_JWT_RUNTIME_PARAM, + }; + type Object = serde_json::Map; thread_local! { - static JWK: OnceCell = const { OnceCell::new() }; + static JWK: OnceCell = const { OnceCell::new() }; static JWT: RefCell> = const { RefCell::new(None) }; static JTI: RefCell = const { RefCell::new(0) }; } - #[derive(Clone)] - struct Key { - kid: i64, - key: VerifyingKey, - } - /// Set the public key and key ID for this postgres session. /// /// # Panics @@ -48,15 +54,32 @@ pub mod auth { /// This function will panic if called multiple times per session. /// This is to prevent replacing the key mid-session. #[pg_extern] - pub fn init(kid: i64, s: JsonB) { - let key: JwkEcKey = serde_json::from_value(s.0).unwrap_or_else(|e| { + pub fn init() { + let jwk = NEON_AUTH_JWK + .get() + .unwrap_or_else(|| { + error_code!( + PgSqlErrorCode::ERRCODE_NO_DATA, + format!("Missing runtime parameter: {}", NEON_AUTH_JWK_RUNTIME_PARAM) + ) + }) + .to_str() + .unwrap_or_else(|e| { + error_code!( + PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, + format!("Couldn't parse {}", NEON_AUTH_JWK_RUNTIME_PARAM), + e.to_string(), + ) + }); + + let jwk: JwkEcKey = serde_json::from_str(jwk).unwrap_or_else(|e| { error_code!( PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, "session init requires an ES256 JWK", e.to_string(), ) }); - let key = PublicKey::from_jwk(&key).unwrap_or_else(|p256::elliptic_curve::Error| { + let key = PublicKey::from_jwk(&jwk).unwrap_or_else(|p256::elliptic_curve::Error| { error_code!( PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, "session init requires an ES256 JWK", @@ -64,7 +87,7 @@ pub mod auth { }); let key = VerifyingKey::from(key); JWK.with(|j| { - if j.set(Key { kid, key }).is_err() { + if j.set(key).is_err() { error_code!( PgSqlErrorCode::ERRCODE_UNIQUE_VIOLATION, "JWK state can only be set once per session.", @@ -73,7 +96,7 @@ pub mod auth { }) } - fn verify_signature(key: &Key, body: &str, sig: &str) { + fn verify_signature(key: &VerifyingKey, body: &str, sig: &str) { let mut sig_bytes = GenericArray::default(); Base64UrlUnpadded::decode(sig, &mut sig_bytes).unwrap_or_else(|_| { error_code!( @@ -88,7 +111,7 @@ pub mod auth { ) }); - key.key.verify(body.as_bytes(), &sig).unwrap_or_else(|_| { + key.verify(body.as_bytes(), &sig).unwrap_or_else(|_| { error_code!( PgSqlErrorCode::ERRCODE_CHECK_VIOLATION, "invalid JWT signature", @@ -96,22 +119,6 @@ pub mod auth { }); } - fn verify_key_id(key: &Key, header: &Object) { - let kid = header - .get("kid") - .and_then(|x| x.as_i64()) - .unwrap_or_else(|| { - error_code!( - PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, - "JWT header must contain a valid 'kid' (key ID)", - ) - }); - - if key.kid != kid { - error_code!(PgSqlErrorCode::ERRCODE_CHECK_VIOLATION, "Key ID mismatch"); - } - } - fn verify_token_id(payload: &Object) -> i64 { let jti = payload .get("jti") @@ -185,32 +192,63 @@ pub mod auth { /// /// This function will panic if the JWT could not be verified. #[pg_extern] - pub fn jwt_session_init(s: &str) { + pub fn jwt_session_init(jwt: &str) { + Spi::run( + format!( + "SET {} = {}", + NEON_AUTH_JWT_RUNTIME_PARAM, + spi::quote_literal(jwt) + ) + .as_str(), + ) + .unwrap_or_else(|e| { + error_code!( + PgSqlErrorCode::ERRCODE_S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED, + format!("Couldn't set {}", NEON_AUTH_JWT_RUNTIME_PARAM), + e.to_string(), + ) + }); + set_jwt_cache() + } + + fn set_jwt_cache() { + let jwt = NEON_AUTH_JWT + .get() + .unwrap_or_else(|| { + error_code!( + PgSqlErrorCode::ERRCODE_NO_DATA, + format!("Missing runtime parameter: {}", NEON_AUTH_JWT_RUNTIME_PARAM) + ) + }) + .to_str() + .unwrap_or_else(|e| { + error_code!( + PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, + format!("Couldn't parse {}", NEON_AUTH_JWT_RUNTIME_PARAM), + e.to_string(), + ) + }); let key = JWK.with(|b| { - b.get() - .unwrap_or_else(|| { - error_code!( - PgSqlErrorCode::ERRCODE_NOT_NULL_VIOLATION, - "JWK state has not been initialised", - ) - }) - .clone() + *b.get().unwrap_or_else(|| { + error_code!( + PgSqlErrorCode::ERRCODE_NOT_NULL_VIOLATION, + "JWK state has not been initialised", + ) + }) }); - let (body, sig) = s.rsplit_once('.').unwrap_or_else(|| { + let (body, sig) = jwt.rsplit_once('.').unwrap_or_else(|| { error_code!( PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, "invalid JWT encoding", ) }); - let (header, payload) = body.split_once('.').unwrap_or_else(|| { + let (_, payload) = body.split_once('.').unwrap_or_else(|| { error_code!( PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, "invalid JWT encoding", ) }); - let header: Object = json_base64_decode(header); - verify_key_id(&key, &header); verify_signature(&key, body, sig); let payload: Object = json_base64_decode(payload); @@ -224,20 +262,28 @@ pub mod auth { /// Extract a value from the shared state. #[pg_extern] - pub fn session(s: &str) -> JsonB { + pub fn session() -> JsonB { + JWK.with(|j| { + if j.get().is_none() { + // assuming that running as bgworker + init(); + set_jwt_cache(); + } + }); + JWT.with_borrow(|j| { JsonB( j.as_ref() - .and_then(|j| j.get(s).cloned()) - .unwrap_or(serde_json::Value::Null), + .cloned() + .map_or(serde_json::Value::Null, serde_json::Value::Object), ) }) } #[pg_extern] - pub fn user_id() -> String { - match session("sub").0 { - serde_json::Value::String(s) => s, + pub fn user_id() -> Option { + match session().0.get("sub")? { + serde_json::Value::String(s) => Some(s.clone()), _ => error_code!( PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, "invalid subject claim in the JWT" @@ -262,162 +308,3 @@ pub mod auth { }) } } - -#[cfg(any(test, feature = "pg_test"))] -#[pg_schema] -mod tests { - use std::fmt::Display; - use std::time::{SystemTime, UNIX_EPOCH}; - - use base64ct::{Base64UrlUnpadded, Encoding}; - use p256::ecdsa::signature::Signer; - use p256::{ - ecdsa::{Signature, SigningKey}, - elliptic_curve::JwkEcKey, - }; - use p256::{NistP256, PublicKey}; - use pgrx::{prelude::*, JsonB}; - use rand::rngs::OsRng; - use serde_json::json; - - use crate::auth; - - fn sign_jwt(sk: &SigningKey, header: &str, payload: impl Display) -> String { - let header = Base64UrlUnpadded::encode_string(header.as_bytes()); - let payload = Base64UrlUnpadded::encode_string(payload.to_string().as_bytes()); - - let message = format!("{header}.{payload}"); - let sig: Signature = sk.sign(message.as_bytes()); - let base64_sig = Base64UrlUnpadded::encode_string(&sig.to_bytes()); - format!("{message}.{base64_sig}") - } - - #[pg_test] - #[should_panic = "JWK state can only be set once per session."] - fn init_jwk_twice() { - let sk = SigningKey::random(&mut OsRng); - let point = sk.verifying_key().to_encoded_point(false); - let jwk = JwkEcKey::from_encoded_point::(&point).unwrap(); - let jwk = serde_json::to_value(&jwk).unwrap(); - - auth::init(1, JsonB(jwk.clone())); - auth::init(2, JsonB(jwk)); - } - - #[pg_test] - #[should_panic = "Key ID mismatch"] - fn wrong_pid() { - let sk = SigningKey::random(&mut OsRng); - let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = JsonB(serde_json::to_value(&jwk).unwrap()); - - auth::init(1, jwk); - auth::jwt_session_init(&sign_jwt(&sk, r#"{"kid":2}"#, r#"{"jti":1}"#)); - } - - #[pg_test] - #[should_panic = "Token ID must be strictly monotonically increasing"] - fn wrong_txid() { - let sk = SigningKey::random(&mut OsRng); - let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = JsonB(serde_json::to_value(&jwk).unwrap()); - - auth::init(1, jwk); - auth::jwt_session_init(&sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":2}"#)); - auth::jwt_session_init(&sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":1}"#)); - } - - #[pg_test] - #[should_panic = "Token used before it is ready"] - fn invalid_nbf() { - let sk = SigningKey::random(&mut OsRng); - let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = JsonB(serde_json::to_value(&jwk).unwrap()); - - auth::init(1, jwk); - - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(); - auth::jwt_session_init(&sign_jwt( - &sk, - r#"{"kid":1}"#, - json!({"jti": 1, "nbf": now + 10}), - )); - } - - #[pg_test] - #[should_panic = "Token used after it has expired"] - fn invalid_exp() { - let sk = SigningKey::random(&mut OsRng); - let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = JsonB(serde_json::to_value(&jwk).unwrap()); - - auth::init(1, jwk); - - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(); - auth::jwt_session_init(&sign_jwt( - &sk, - r#"{"kid":1}"#, - json!({"jti": 1, "nbf": now - 10, "exp": now - 5}), - )); - } - - #[pg_test] - fn valid_time() { - let sk = SigningKey::random(&mut OsRng); - let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = JsonB(serde_json::to_value(&jwk).unwrap()); - - auth::init(1, jwk); - - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(); - - let header = r#"{"kid":1}"#; - - auth::jwt_session_init(&sign_jwt( - &sk, - header, - json!({"jti": 1, "nbf": now - 10, "exp": now + 10}), - )); - auth::jwt_session_init(&sign_jwt(&sk, header, json!({"jti": 2, "nbf": now - 10}))); - auth::jwt_session_init(&sign_jwt(&sk, header, json!({"jti": 3, "exp": now + 10}))); - } - - #[pg_test] - fn test_pg_session_jwt() { - let sk = SigningKey::random(&mut OsRng); - let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = JsonB(serde_json::to_value(&jwk).unwrap()); - - auth::init(1, jwk); - let header = r#"{"kid":1}"#; - - auth::jwt_session_init(&sign_jwt(&sk, header, r#"{"sub":"foo","jti":1}"#)); - assert_eq!(auth::user_id(), "foo"); - - auth::jwt_session_init(&sign_jwt(&sk, header, r#"{"sub":"bar","jti":2}"#)); - assert_eq!(auth::user_id(), "bar"); - } -} - -/// This module is required by `cargo pgrx test` invocations. -/// It must be visible at the root of your extension crate. -#[cfg(test)] -pub mod pg_test { - pub fn setup(_options: Vec<&str>) { - // perform one-off initialization when the pg_test framework starts - } - - pub fn postgresql_conf_options() -> Vec<&'static str> { - // return any postgresql.conf settings that are required for your tests - vec![] - } -} diff --git a/tests/pg_session_jwt.rs b/tests/pg_session_jwt.rs new file mode 100644 index 0000000..b40285c --- /dev/null +++ b/tests/pg_session_jwt.rs @@ -0,0 +1,202 @@ +use std::process::ExitCode; +use std::time::{SystemTime, UNIX_EPOCH}; + +use base64ct::{Base64UrlUnpadded, Encoding}; +use libtest_mimic::{run, Trial}; +use p256::ecdsa::signature::Signer; +use p256::{ + ecdsa::{Signature, SigningKey}, + elliptic_curve::JwkEcKey, + NistP256, +}; +use rand::rngs::OsRng; +use serde_json::json; + +fn main() -> ExitCode { + let mut args = libtest_mimic::Arguments::from_args(); + // fixes concurrent update failures + args.test_threads = Some(1); + + let mut tests = vec![]; + + let err = "Token ID must be strictly monotonically increasing."; + tests.push(test_fn("wrong_txid", Some(err), wrong_txid)); + + let err = "Token used before it is ready"; + tests.push(test_fn("invalid_nbf", Some(err), invalid_nbf)); + + let err = "Token used after it has expired"; + tests.push(test_fn("invalid_exp", Some(err), invalid_exp)); + + tests.push(test_fn("valid_time", None, valid_time)); + tests.push(test_fn("test_pg_session_jwt", None, test_pg_session_jwt)); + tests.push(test_fn("test_bgworker", None, test_bgworker)); + + run(&args, tests).exit_code() +} + +// bgworker process exits after execution, because of that we don't need to test case for more +// than one JWT +fn test_fn(name: &str, error: Option<&'static str>, f: F) -> Trial +where + F: for<'a, 'b> FnOnce(&'a SigningKey, &'b mut postgres::Client) -> Result<(), postgres::Error> + + Send + + 'static, +{ + let sk = SigningKey::random(&mut OsRng); + let jwk = create_jwk(&sk); + let options = format!("-c {NEON_AUTH_JWK_RUNTIME_PARAM}={jwk}"); + + Trial::test(name, move || { + pgrx_tests::run_test(Some(&options), error, vec![], move |tx| f(&sk, tx)) + .map_err(libtest_mimic::Failed::from) + }) +} + +fn wrong_txid(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgres::Error> { + let jwt1 = sign_jwt(sk, r#"{"kid":1}"#, r#"{"jti":1}"#); + let jwt2 = sign_jwt(sk, r#"{"kid":1}"#, r#"{"jti":2}"#); + + tx.execute("select auth.init()", &[])?; + tx.execute("select auth.jwt_session_init($1)", &[&jwt2])?; + tx.execute("select auth.jwt_session_init($1)", &[&jwt1])?; + + Ok(()) +} + +fn invalid_nbf(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgres::Error> { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + let jwt = sign_jwt(sk, r#"{"kid":1}"#, json!({"jti": 1, "nbf": now + 10})); + + tx.execute("select auth.init()", &[])?; + tx.execute("select auth.jwt_session_init($1)", &[&jwt])?; + + Ok(()) +} + +fn invalid_exp(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgres::Error> { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + let jwt = sign_jwt( + sk, + r#"{"kid":1}"#, + json!({"jti": 1, "nbf": now - 10, "exp": now - 5}), + ); + + tx.execute("select auth.init()", &[])?; + tx.execute("select auth.jwt_session_init($1)", &[&jwt])?; + + Ok(()) +} + +fn valid_time(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgres::Error> { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + let header = r#"{"kid":1}"#; + let jwt1 = sign_jwt( + sk, + header, + json!({"jti": 1, "nbf": now - 10, "exp": now + 10}), + ); + let jwt2 = sign_jwt(sk, header, json!({"jti": 2, "nbf": now - 10})); + let jwt3 = sign_jwt(sk, header, json!({"jti": 3, "exp": now + 10})); + + tx.execute("select auth.init()", &[])?; + tx.execute("select auth.jwt_session_init($1)", &[&jwt1])?; + tx.execute("select auth.jwt_session_init($1)", &[&jwt2])?; + tx.execute("select auth.jwt_session_init($1)", &[&jwt3])?; + + Ok(()) +} + +fn test_pg_session_jwt(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgres::Error> { + let header = r#"{"kid":1}"#; + let jwt1 = sign_jwt(sk, header, r#"{"sub":"foo","jti":1}"#); + let jwt2 = sign_jwt(sk, header, r#"{"sub":"bar","jti":2}"#); + + tx.execute("select auth.init()", &[])?; + tx.execute("select auth.jwt_session_init($1)", &[&jwt1])?; + let user_id = tx.query_one("select auth.user_id()", &[])?; + let user_id = user_id.get::<_, String>("user_id"); + assert_eq!(user_id, "foo"); + + tx.execute("select auth.jwt_session_init($1)", &[&jwt2])?; + let user_id = tx.query_one("select auth.user_id()", &[])?; + let user_id = user_id.get::<_, String>("user_id"); + assert_eq!(user_id, "bar"); + + Ok(()) +} + +// bgworker process exits after execution, because of that we don't need to test case for more +// than one JWT +fn test_bgworker(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgres::Error> { + let header = r#"{"kid":1}"#; + let jwt = sign_jwt(sk, header, r#"{"sub":"foo","jti":1}"#); + + tx.execute(&format!("set neon.auth.jwt = '{jwt}'"), &[])?; + let user_id = tx.query_one("select auth.user_id()", &[])?; + let user_id = user_id.get::<_, String>("user_id"); + assert_eq!(user_id, "foo"); + + Ok(()) +} + +// fn discard() -> eyre::Result<()> { +// let sk = SigningKey::random(&mut OsRng); +// let jwk = create_jwk(&sk, 1); +// let options = format!("-c {NEON_AUTH_JWK_RUNTIME_PARAM}={jwk}"); + +// let header = r#"{"kid":1}"#; +// let jwt1 = sign_jwt(&sk, header, r#"{"sub":"foo","jti":1}"#); +// let jwt2 = sign_jwt(&sk, header, r#"{"sub":"bar","jti":2}"#); + +// pgrx_tests::run_test(Some(&options), None, vec![], |tx| { +// tx.execute("select auth.init()", &[])?; +// tx.execute("select auth.jwt_session_init($1)", &[&jwt1])?; +// let user_id = tx.query_one("select auth.user_id()", &[])?; +// let user_id = user_id.get::<_, Option>("user_id"); +// assert_eq!(user_id.as_deref(), Some("foo")); + +// tx.simple_query("reset neon.auth.jwt")?; + +// let user_id = tx.query_one("select auth.user_id()", &[])?; +// let user_id = user_id.get::<_, Option>("user_id"); +// assert_eq!(user_id.as_deref(), None); + +// tx.execute("select auth.jwt_session_init($1)", &[&jwt2])?; +// let user_id = tx.query_one("select auth.user_id()", &[])?; +// let user_id = user_id.get::<_, Option>("user_id"); +// assert_eq!(user_id.as_deref(), Some("bar")); + +// Ok(()) +// }) +// } + +static NEON_AUTH_JWK_RUNTIME_PARAM: &str = "neon.auth.jwk"; + +fn sign_jwt(sk: &SigningKey, header: &str, payload: impl ToString) -> String { + let header = Base64UrlUnpadded::encode_string(header.as_bytes()); + let payload = Base64UrlUnpadded::encode_string(payload.to_string().as_bytes()); + + let message = format!("{header}.{payload}"); + let sig: Signature = sk.sign(message.as_bytes()); + let base64_sig = Base64UrlUnpadded::encode_string(&sig.to_bytes()); + format!("{message}.{base64_sig}") +} + +fn create_jwk(sk: &SigningKey) -> String { + let point = sk.verifying_key().to_encoded_point(false); + let key = JwkEcKey::from_encoded_point::(&point).unwrap(); + serde_json::to_string(&key).unwrap() +}