From 19ccb07e32794a2df0ddb81ff5ca94767923e653 Mon Sep 17 00:00:00 2001 From: Jerome Gravel-Niquet Date: Fri, 27 Oct 2023 12:24:18 -0400 Subject: [PATCH] PostgreSQL wire-compatible API (#83) * pg exploration * improvements * much progress, this is a little insane * this is a little insane, should support more pg clients now * implement an extra column, make it less panicky * clean up mess from failed experiment * remove unused files from old implementation * send subscriptions too * need to unquote a few things sometimes * handle interrupts, fix last few bugs for compliant clients * update changelog and docs * remove unused crates in corro-pg --- CHANGELOG.md | 2 + Cargo.lock | 293 ++- Cargo.toml | 2 +- crates/corro-agent/Cargo.toml | 1 + crates/corro-agent/src/agent.rs | 22 +- crates/corro-agent/src/api/peer.rs | 1 + crates/corro-agent/src/api/public/pubsub.rs | 6 +- crates/corro-agent/src/broadcast/mod.rs | 4 +- crates/corro-pg/Cargo.toml | 28 + crates/corro-pg/src/lib.rs | 2456 +++++++++++++++++++ crates/corro-pg/src/sql_state.rs | 1336 ++++++++++ crates/corro-pg/src/vtab/mod.rs | 2 + crates/corro-pg/src/vtab/pg_range.rs | 92 + crates/corro-pg/src/vtab/pg_type.rs | 324 +++ crates/corro-types/src/agent.rs | 57 +- crates/corro-types/src/config.rs | 9 + crates/corro-types/src/pubsub.rs | 25 +- crates/corro-types/src/schema.rs | 329 +-- doc/SUMMARY.md | 1 + doc/api/pg.md | 15 + 20 files changed, 4809 insertions(+), 196 deletions(-) create mode 100644 crates/corro-pg/Cargo.toml create mode 100644 crates/corro-pg/src/lib.rs create mode 100644 crates/corro-pg/src/sql_state.rs create mode 100644 crates/corro-pg/src/vtab/mod.rs create mode 100644 crates/corro-pg/src/vtab/pg_range.rs create mode 100644 crates/corro-pg/src/vtab/pg_type.rs create mode 100644 doc/api/pg.md diff --git a/CHANGELOG.md b/CHANGELOG.md index cfa84f0d..fc41e49f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ ## Unreleased +- Implement a PostgreSQL wire protocol (v3) compatible API 
([#83](../../pull/83)) +- Accept _all_ JSON types for SQLite params input ([#82](../../pull/82)) - Parallel synchronization w/ many deadlock and bug fixes ([#78](../../pull/78)) - Upgraded to cr-sqlite 0.16.0 (unreleased) ([#75](../../pull/75)) - Rewrite compaction logic to be more correct and efficient ([#74](../../pull/74)) diff --git a/Cargo.lock b/Cargo.lock index a230a58a..b2d73dde 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -50,6 +50,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + [[package]] name = "android_system_properties" version = "0.1.5" @@ -319,6 +325,22 @@ version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a" +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + +[[package]] +name = "bcder" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf16bec990f8ea25cab661199904ef452fcf11f565c404ce6cffbdf3f8cbbc47" +dependencies = [ + "bytes", + "smallvec", +] + [[package]] name = "beef" version = "0.5.2" @@ -503,15 +525,17 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.24" +version = "0.4.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e3c5919066adf22df73762e50cffcde3a758f2a848b113b586d1f86728b673b" +checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" dependencies = [ + "android-tzdata", "iana-time-zone", - "num-integer", + "js-sys", "num-traits", "serde", - "winapi", + "wasm-bindgen", + "windows-targets 0.48.0", ] [[package]] @@ -617,6 
+641,12 @@ dependencies = [ "toml", ] +[[package]] +name = "const-oid" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28c122c3980598d243d63d9a704629a2d748d101f278052ff068be5a4423ab6f" + [[package]] name = "const-random" version = "0.1.15" @@ -709,6 +739,7 @@ dependencies = [ "camino", "compact_str 0.7.0", "config", + "corro-pg", "corro-speedy", "corro-tests", "corro-types", @@ -791,6 +822,32 @@ dependencies = [ "uuid", ] +[[package]] +name = "corro-pg" +version = "0.1.0" +dependencies = [ + "bytes", + "compact_str 0.7.0", + "corro-tests", + "corro-types", + "fallible-iterator", + "futures", + "pgwire", + "postgres-types", + "rusqlite", + "spawn", + "sqlite3-parser", + "tempfile", + "thiserror", + "time", + "tokio", + "tokio-postgres", + "tokio-util", + "tracing", + "tracing-subscriber", + "tripwire", +] + [[package]] name = "corro-speedy" version = "0.8.7" @@ -1167,6 +1224,16 @@ dependencies = [ "thiserror", ] +[[package]] +name = "der" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c" +dependencies = [ + "const-oid", + "zeroize", +] + [[package]] name = "der-parser" version = "8.2.0" @@ -1181,6 +1248,17 @@ dependencies = [ "rusticata-macros", ] +[[package]] +name = "derive-new" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3418329ca0ad70234b9735dc4ceed10af4df60eff9c8e7b06cb5e520d92c3535" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "derive_more" version = "0.99.17" @@ -1214,6 +1292,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", + "subtle", ] [[package]] @@ -1362,6 +1441,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "finl_unicode" +version = "1.2.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fcfdc7a0362c9f4444381a9e697c79d435fe65b52a37466fc2c1184cee9edc6" + [[package]] name = "fnv" version = "1.0.7" @@ -1510,6 +1595,18 @@ dependencies = [ "wasi", ] +[[package]] +name = "getset" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e45727250e75cc04ff2846a66397da8ef2b3db8e40e0cef4df67950a07621eb9" +dependencies = [ + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "gimli" version = "0.27.2" @@ -1629,6 +1726,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "hostname" version = "0.3.1" @@ -2090,6 +2196,22 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40" +[[package]] +name = "md-5" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +dependencies = [ + "cfg-if", + "digest", +] + +[[package]] +name = "md5" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" + [[package]] name = "memchr" version = "2.5.0" @@ -2518,6 +2640,34 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" +[[package]] +name = "pgwire" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"06d04982366efd653d4365175426acbabd55efb07231869e92b9e1f5b3faf7df" +dependencies = [ + "async-trait", + "base64 0.21.0", + "bytes", + "chrono", + "derive-new", + "futures", + "getset", + "hex", + "log", + "md5", + "postgres-types", + "rand", + "ring", + "stringprep", + "thiserror", + "time", + "tokio", + "tokio-rustls", + "tokio-util", + "x509-certificate", +] + [[package]] name = "phf" version = "0.11.1" @@ -2601,6 +2751,37 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f602a0d1e09a48e4f8e8b4d4042e32807c3676da31f2ecabeac9f96226ec6c45" +[[package]] +name = "postgres-protocol" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49b6c5ef183cd3ab4ba005f1ca64c21e8bd97ce4699cfea9e8d9a2c4958ca520" +dependencies = [ + "base64 0.21.0", + "byteorder", + "bytes", + "fallible-iterator", + "hmac", + "md-5", + "memchr", + "rand", + "sha2", + "stringprep", +] + +[[package]] +name = "postgres-types" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d2234cdee9408b523530a9b6d2d6b373d1db34f6a8e51dc03ded1828d7fb67c" +dependencies = [ + "bytes", + "chrono", + "fallible-iterator", + "postgres-protocol", + "time", +] + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -3261,6 +3442,17 @@ dependencies = [ "digest", ] +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.4" @@ -3289,6 +3481,12 @@ dependencies = [ "libc", ] +[[package]] +name = "signature" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e1788eed21689f9cf370582dfc467ef36ed9c707f073528ddafa8d83e3b8500" + [[package]] name = "siphasher" version = "0.3.10" @@ -3389,6 +3587,16 @@ 
version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "spki" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1e996ef02c474957d681f1b05213dfb0abab947b446a62d37770b23500184a" +dependencies = [ + "base64ct", + "der", +] + [[package]] name = "sqlite-pool" version = "0.1.0" @@ -3434,6 +3642,17 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "stringprep" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb41d74e231a107a1b4ee36bd1214b11285b77768d2e3824aedafa988fd36ee6" +dependencies = [ + "finl_unicode", + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "strsim" version = "0.10.0" @@ -3462,6 +3681,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "subtle" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" + [[package]] name = "supports-color" version = "1.3.1" @@ -3732,6 +3957,32 @@ dependencies = [ "tokio-stream", ] +[[package]] +name = "tokio-postgres" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d340244b32d920260ae7448cb72b6e238bddc3d4f7603394e7dd46ed8e48f5b8" +dependencies = [ + "async-trait", + "byteorder", + "bytes", + "fallible-iterator", + "futures-channel", + "futures-util", + "log", + "parking_lot", + "percent-encoding", + "phf", + "pin-project-lite", + "postgres-protocol", + "postgres-types", + "rand", + "socket2 0.5.3", + "tokio", + "tokio-util", + "whoami", +] + [[package]] name = "tokio-rustls" version = "0.24.0" @@ -4322,6 +4573,16 @@ dependencies = [ "untrusted", ] +[[package]] +name = "whoami" +version = 
"1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22fc3756b8a9133049b26c7f61ab35416c130e8c09b660f5b3958b446f52cc50" +dependencies = [ + "wasm-bindgen", + "web-sys", +] + [[package]] name = "widestring" version = "0.5.1" @@ -4524,6 +4785,24 @@ dependencies = [ "winapi", ] +[[package]] +name = "x509-certificate" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5d27c90840e84503cf44364de338794d5d5680bdd1da6272d13f80b0769ee0" +dependencies = [ + "bcder", + "bytes", + "chrono", + "der", + "hex", + "pem", + "ring", + "signature", + "spki", + "thiserror", +] + [[package]] name = "x509-parser" version = "0.15.0" @@ -4565,3 +4844,9 @@ checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" dependencies = [ "time", ] + +[[package]] +name = "zeroize" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" diff --git a/Cargo.toml b/Cargo.toml index 75c430c5..0e77143a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,7 +49,7 @@ rand = { version = "0.8.5", features = ["small_rng"] } rangemap = { version = "1.3.0" } rcgen = { version = "0.11.1", features = ["x509-parser"] } rhai = { version = "1.15.1", features = ["sync"] } -rusqlite = { version = "0.29.0", features = ["serde_json", "time", "bundled", "uuid", "array", "load_extension", "column_decltype"] } +rusqlite = { version = "0.29.0", features = ["serde_json", "time", "bundled", "uuid", "array", "load_extension", "column_decltype", "vtab"] } rustls = { version = "0.21.0", features = ["dangerous_configuration", "quic"] } rustls-pemfile = "1.0.2" seahash = "4.1.0" diff --git a/crates/corro-agent/Cargo.toml b/crates/corro-agent/Cargo.toml index 1e6e2561..eb70a346 100644 --- a/crates/corro-agent/Cargo.toml +++ b/crates/corro-agent/Cargo.toml @@ -52,6 +52,7 @@ tripwire = { path = "../tripwire" } trust-dns-resolver = { 
workspace = true } uhlc = { workspace = true } uuid = { workspace = true } +corro-pg = { path = "../corro-pg" } [dev-dependencies] corro-tests = { path = "../corro-tests" } diff --git a/crates/corro-agent/src/agent.rs b/crates/corro-agent/src/agent.rs index 29b5b5b4..96b8c7c4 100644 --- a/crates/corro-agent/src/agent.rs +++ b/crates/corro-agent/src/agent.rs @@ -128,7 +128,9 @@ pub async fn setup(conf: Config, tripwire: Tripwire) -> eyre::Result<(Agent, Age let schema = { let mut conn = pool.write_priority().await?; migrate(&mut conn)?; - init_schema(&conn)? + let mut schema = init_schema(&conn)?; + schema.constrain()?; + schema }; { @@ -297,7 +299,10 @@ pub async fn setup(conf: Config, tripwire: Tripwire) -> eyre::Result<(Agent, Age pub async fn start(conf: Config, tripwire: Tripwire) -> eyre::Result { let (agent, opts) = setup(conf, tripwire.clone()).await?; - tokio::spawn(run(agent.clone(), opts).inspect(|_| info!("corrosion agent run is done"))); + tokio::spawn(run(agent.clone(), opts).inspect(|res| match res { + Ok(_) => info!("corrosion agent run is done"), + Err(e) => error!("running corrosion agent failed: {e}"), + })); Ok(agent) } @@ -317,6 +322,15 @@ pub async fn run(agent: Agent, opts: AgentOptions) -> eyre::Result<()> { rtt_rx, } = opts; + if let Some(pg_conf) = agent.config().api.pg.clone() { + info!("Starting PostgreSQL wire-compatible server"); + let pg_server = corro_pg::start(agent.clone(), pg_conf, tripwire.clone()).await?; + info!( + "Started PostgreSQL wire-compatible server, listening at {}", + pg_server.local_addr + ); + } + let mut matcher_id_cache = MatcherIdCache::default(); let mut matcher_bcast_cache = MatcherBroadcastCache::default(); @@ -345,7 +359,7 @@ pub async fn run(agent: Agent, opts: AgentOptions) -> eyre::Result<()> { }; for (id, sql) in rows { - let conn = agent.pool().dedicated().await?; + let conn = block_in_place(|| agent.pool().dedicated())?; let (evt_tx, evt_rx) = channel(512); match Matcher::restore(id, 
&agent.schema().read(), conn, evt_tx, &sql) { Ok(handle) => { @@ -2408,6 +2422,7 @@ async fn handle_changes( } // drain and process current changes! + #[allow(clippy::drain_collect)] if let Err(e) = process_multiple_changes(&agent, buf.drain(..).collect()).await { error!("could not process multiple changes: {e}"); } @@ -2426,6 +2441,7 @@ async fn handle_changes( buf.push((change, src)); if count >= MIN_CHANGES_CHUNK { // drain and process current changes! + #[allow(clippy::drain_collect)] if let Err(e) = process_multiple_changes(&agent, buf.drain(..).collect()).await { error!("could not process multiple changes: {e}"); } diff --git a/crates/corro-agent/src/api/peer.rs b/crates/corro-agent/src/api/peer.rs index 3305cdc3..5a0a814e 100644 --- a/crates/corro-agent/src/api/peer.rs +++ b/crates/corro-agent/src/api/peer.rs @@ -1021,6 +1021,7 @@ pub async fn parallel_sync( debug!("collected member needs and such!"); + #[allow(clippy::manual_try_fold)] let syncers = results.into_iter().fold(Ok(vec![]), |agg, (actor_id, addr, res)| { match res { Ok((needs, tx, read)) => { diff --git a/crates/corro-agent/src/api/public/pubsub.rs b/crates/corro-agent/src/api/public/pubsub.rs index bcdfe4d8..02352cd9 100644 --- a/crates/corro-agent/src/api/public/pubsub.rs +++ b/crates/corro-agent/src/api/public/pubsub.rs @@ -186,7 +186,7 @@ pub async fn process_sub_channel( }; // get a dedicated connection - let conn = match agent.pool().dedicated().await { + let conn = match agent.pool().dedicated() { Ok(conn) => conn, Err(e) => { error!("could not acquire dedicated connection for subscription cleanup: {e}"); @@ -443,7 +443,7 @@ pub async fn catch_up_sub( let last_query_event = { let mut buf = BytesMut::new(); - let mut conn = match agent.pool().dedicated().await { + let mut conn = match agent.pool().dedicated() { Ok(conn) => conn, Err(e) => { evt_tx.send(error_to_query_event_bytes(&mut buf, e)).await?; @@ -547,7 +547,7 @@ pub async fn upsert_sub( return 
Err(MatcherUpsertError::SubFromWithoutMatcher); } - let conn = agent.pool().dedicated().await?; + let conn = agent.pool().dedicated()?; let (evt_tx, evt_rx) = mpsc::channel(512); diff --git a/crates/corro-agent/src/broadcast/mod.rs b/crates/corro-agent/src/broadcast/mod.rs index 1170aaa7..b2bbe1fc 100644 --- a/crates/corro-agent/src/broadcast/mod.rs +++ b/crates/corro-agent/src/broadcast/mod.rs @@ -231,7 +231,7 @@ pub fn runtime_loop( .map(serde_json::Value::from) .collect::>() }) - .unwrap_or(vec![]), + .unwrap_or_default(), ), )), Err(e) => { @@ -425,7 +425,7 @@ pub fn runtime_loop( .map(serde_json::Value::from) .collect::>() }) - .unwrap_or(vec![]), + .unwrap_or_default(), ); let upserted = tx.prepare_cached("INSERT INTO __corro_members (actor_id, address, state, foca_state, rtts) diff --git a/crates/corro-pg/Cargo.toml b/crates/corro-pg/Cargo.toml new file mode 100644 index 00000000..72a73b8c --- /dev/null +++ b/crates/corro-pg/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "corro-pg" +version = "0.1.0" +edition = "2021" + +[dependencies] +bytes = { workspace = true } +compact_str = { workspace = true } +corro-types = { path = "../corro-types" } +fallible-iterator = { workspace = true } +futures = { workspace = true } +pgwire = { version = "0.16.1" } +postgres-types = { version = "0.2", features = ["with-time-0_3"] } +rusqlite = { workspace = true } +spawn = { path = "../spawn" } +sqlite3-parser = { workspace = true } +tempfile = { workspace = true } +thiserror = { workspace = true } +time = { workspace = true } +tokio = { workspace = true } +tokio-util = { workspace = true } +tracing = { workspace = true } +tripwire = { path = "../tripwire" } + +[dev-dependencies] +corro-tests = { path = "../corro-tests" } +tokio-postgres = { version = "0.7.10" } +tracing-subscriber = { workspace = true } \ No newline at end of file diff --git a/crates/corro-pg/src/lib.rs b/crates/corro-pg/src/lib.rs new file mode 100644 index 00000000..e742c093 --- /dev/null +++ 
b/crates/corro-pg/src/lib.rs @@ -0,0 +1,2456 @@ +pub mod sql_state; +mod vtab; + +use std::{ + collections::{HashMap, VecDeque}, + future::poll_fn, + net::SocketAddr, + sync::Arc, +}; + +use bytes::Buf; +use compact_str::CompactString; +use corro_types::{ + agent::{Agent, KnownDbVersion}, + broadcast::Timestamp, + config::PgConfig, + schema::{parse_sql, Schema, SchemaError, SqliteType, Table}, + sqlite::SqlitePoolError, +}; +use fallible_iterator::FallibleIterator; +use futures::{SinkExt, StreamExt}; +use pgwire::{ + api::{ + results::{DataRowEncoder, FieldFormat, FieldInfo, Tag}, + ClientInfo, ClientInfoHolder, + }, + error::{ErrorInfo, PgWireError}, + messages::{ + data::{NoData, ParameterDescription, RowDescription}, + extendedquery::{BindComplete, CloseComplete, ParseComplete, PortalSuspended}, + response::{ + EmptyQueryResponse, ReadyForQuery, READY_STATUS_IDLE, READY_STATUS_TRANSACTION_BLOCK, + }, + startup::{ParameterStatus, SslRequest}, + PgWireBackendMessage, PgWireFrontendMessage, + }, + tokio::PgWireMessageServerCodec, +}; +use postgres_types::{FromSql, Type}; +use rusqlite::{named_params, types::ValueRef, vtab::eponymous_only_module, Connection, Statement}; +use spawn::spawn_counted; +use sqlite3_parser::ast::{ + As, Cmd, CreateTableBody, Expr, FromClause, Id, InsertBody, Name, OneSelect, ResultColumn, + Select, SelectTable, Stmt, +}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt, ReadBuf}, + net::{TcpListener, TcpStream}, + sync::mpsc::channel, + task::block_in_place, +}; +use tokio_util::{codec::Framed, sync::CancellationToken}; +use tracing::{debug, error, info, trace, warn}; +use tripwire::{Outcome, PreemptibleFutureExt, Tripwire}; + +use crate::{ + sql_state::SqlState, + vtab::{pg_range::PgRangeTable, pg_type::PgTypeTable}, +}; + +type BoxError = Box; + +pub struct PgServer { + pub local_addr: SocketAddr, +} + +enum BackendResponse { + Message { + message: PgWireBackendMessage, + flush: bool, + }, + Flush, +} + +impl From<(PgWireBackendMessage, 
bool)> for BackendResponse { + fn from((message, flush): (PgWireBackendMessage, bool)) -> Self { + Self::Message { message, flush } + } +} + +#[derive(Clone, Copy, Debug)] +enum StmtTag { + Select, + InsertAsSelect, + + Insert, + Update, + Delete, + + Alter, + Analyze, + Attach, + Begin, + Commit, + Create, + Detach, + Drop, + Pragma, + Reindex, + Release, + Rollback, + Savepoint, + Vacuum, + + Other, +} + +impl StmtTag { + fn returns_rows_affected(&self) -> bool { + matches!(self, StmtTag::Insert | StmtTag::Update | StmtTag::Delete) + } + fn returns_num_rows(&self) -> bool { + matches!(self, StmtTag::Select | StmtTag::InsertAsSelect) + } + pub fn tag(&self, rows: Option) -> Tag { + match self { + StmtTag::Select => Tag::new_for_execution("SELECT", rows), + StmtTag::InsertAsSelect | StmtTag::Insert => Tag::new_for_execution("INSERT", rows), + StmtTag::Update => Tag::new_for_execution("UPDATE", rows), + StmtTag::Delete => Tag::new_for_execution("DELETE", rows), + StmtTag::Alter => Tag::new_for_execution("ALTER", rows), + StmtTag::Analyze => Tag::new_for_execution("ANALYZE", rows), + StmtTag::Attach => Tag::new_for_execution("ATTACH", rows), + StmtTag::Begin => Tag::new_for_execution("BEGIN", rows), + StmtTag::Commit => Tag::new_for_execution("COMMIT", rows), + StmtTag::Create => Tag::new_for_execution("CREATE", rows), + StmtTag::Detach => Tag::new_for_execution("DETACH", rows), + StmtTag::Drop => Tag::new_for_execution("DROP", rows), + StmtTag::Pragma => Tag::new_for_execution("PRAGMA", rows), + StmtTag::Reindex => Tag::new_for_execution("REINDEX", rows), + StmtTag::Release => Tag::new_for_execution("RELEASE", rows), + StmtTag::Rollback => Tag::new_for_execution("ROLLBACK", rows), + StmtTag::Savepoint => Tag::new_for_execution("SAVEPOINT", rows), + StmtTag::Vacuum => Tag::new_for_execution("VACUUM", rows), + StmtTag::Other => Tag::new_for_execution("OK", rows), + } + } +} + +enum Prepared { + Empty, + NonEmpty { + sql: String, + param_types: Vec, + fields: Vec, + 
tag: StmtTag, + }, +} + +enum Portal<'a> { + Empty { + stmt_name: CompactString, + }, + Parsed { + stmt_name: CompactString, + stmt: Statement<'a>, + result_formats: Vec, + tag: StmtTag, + }, +} + +impl<'a> Portal<'a> { + fn stmt_name(&self) -> &str { + match self { + Portal::Empty { stmt_name } | Portal::Parsed { stmt_name, .. } => stmt_name.as_str(), + } + } +} + +#[derive(Clone, Debug)] +struct ParsedCmd(Cmd); + +impl ParsedCmd { + pub fn is_begin(&self) -> bool { + matches!(self.0, Cmd::Stmt(Stmt::Begin(_, _))) + } + pub fn is_commit(&self) -> bool { + matches!(self.0, Cmd::Stmt(Stmt::Commit(_))) + } + pub fn is_rollback(&self) -> bool { + matches!(self.0, Cmd::Stmt(Stmt::Rollback { .. })) + } + + fn tag(&self) -> StmtTag { + match &self.0 { + Cmd::Stmt(stmt) => match stmt { + Stmt::Select(_) => StmtTag::Select, + Stmt::CreateTable { + body: CreateTableBody::AsSelect(_), + .. + } => StmtTag::InsertAsSelect, + Stmt::AlterTable(_, _) => StmtTag::Alter, + Stmt::Analyze(_) => StmtTag::Analyze, + Stmt::Attach { .. } => StmtTag::Attach, + Stmt::Begin(_, _) => StmtTag::Begin, + Stmt::Commit(_) => StmtTag::Commit, + Stmt::CreateIndex { .. } + | Stmt::CreateTable { .. } + | Stmt::CreateTrigger { .. } + | Stmt::CreateView { .. } + | Stmt::CreateVirtualTable { .. } => StmtTag::Create, + Stmt::Delete { .. } => StmtTag::Delete, + Stmt::Detach(_) => StmtTag::Detach, + Stmt::DropIndex { .. } + | Stmt::DropTable { .. } + | Stmt::DropTrigger { .. } + | Stmt::DropView { .. } => StmtTag::Drop, + Stmt::Insert { .. } => StmtTag::Insert, + Stmt::Pragma(_, _) => StmtTag::Pragma, + Stmt::Reindex { .. } => StmtTag::Reindex, + Stmt::Release(_) => StmtTag::Release, + Stmt::Rollback { .. } => StmtTag::Rollback, + Stmt::Savepoint(_) => StmtTag::Savepoint, + + Stmt::Update { .. 
} => StmtTag::Update, + Stmt::Vacuum(_, _) => StmtTag::Vacuum, + }, + _ => StmtTag::Other, + } + } +} + +fn parse_query(sql: &str) -> Result, sqlite3_parser::lexer::sql::Error> { + let mut cmds = VecDeque::new(); + + let mut parser = sqlite3_parser::lexer::sql::Parser::new(sql.as_bytes()); + loop { + match parser.next() { + Ok(Some(cmd)) => { + cmds.push_back(ParsedCmd(cmd)); + } + Ok(None) => { + break; + } + Err(e) => return Err(e), + } + } + + Ok(cmds) +} + +enum OpenTx { + Implicit, + Explicit, +} + +async fn peek_for_sslrequest( + tcp_socket: &mut TcpStream, + ssl_supported: bool, +) -> std::io::Result { + let mut ssl = false; + let mut buf = [0u8; SslRequest::BODY_SIZE]; + let mut buf = ReadBuf::new(&mut buf); + loop { + let size = poll_fn(|cx| tcp_socket.poll_peek(cx, &mut buf)).await?; + if size == 0 { + // the tcp_stream has ended + return Ok(false); + } + if size == SslRequest::BODY_SIZE { + let mut buf_ref = buf.filled(); + // skip first 4 bytes + buf_ref.get_i32(); + if buf_ref.get_i32() == SslRequest::BODY_MAGIC_NUMBER { + // the socket is sending sslrequest, read the first 8 bytes + // skip first 8 bytes + tcp_socket + .read_exact(&mut [0u8; SslRequest::BODY_SIZE]) + .await?; + // ssl configured + if ssl_supported { + ssl = true; + tcp_socket.write_all(b"S").await?; + } else { + tcp_socket.write_all(b"N").await?; + } + } + + return Ok(ssl); + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum PgStartError { + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + Rusqlite(#[from] rusqlite::Error), +} + +pub async fn start( + agent: Agent, + pg: PgConfig, + mut tripwire: Tripwire, +) -> Result { + let server = TcpListener::bind(pg.bind_addr).await?; + let local_addr = server.local_addr()?; + + // let tmp_dir = tempfile::TempDir::new()?; + // let pg_system_path = tmp_dir.path().join("pg_system.sqlite"); + + tokio::spawn(async move { + loop { + let (mut conn, remote_addr) = match server.accept().preemptible(&mut 
tripwire).await { + Outcome::Completed(res) => res?, + Outcome::Preempted(_) => break, + }; + info!("accepted a conn, addr: {remote_addr}"); + + let agent = agent.clone(); + tokio::spawn(async move { + conn.set_nodelay(true)?; + let ssl = peek_for_sslrequest(&mut conn, false).await?; + trace!("SSL? {ssl}"); + + let mut framed = Framed::new( + conn, + PgWireMessageServerCodec::new(ClientInfoHolder::new(remote_addr, false)), + ); + + let msg = framed.next().await.unwrap()?; + trace!("msg: {msg:?}"); + + match msg { + PgWireFrontendMessage::Startup(startup) => { + info!("received startup message: {startup:?}"); + } + _ => { + framed + .send(PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "FATAL".into(), + SqlState::PROTOCOL_VIOLATION.code().into(), + "expected startup message".into(), + ) + .into(), + )) + .await?; + return Ok(()); + } + } + + framed.set_state(pgwire::api::PgWireConnectionState::ReadyForQuery); + + framed + .feed(PgWireBackendMessage::Authentication( + pgwire::messages::startup::Authentication::Ok, + )) + .await?; + + framed + .feed(PgWireBackendMessage::ParameterStatus(ParameterStatus::new( + "server_version".into(), + "14.0.0".into(), + ))) + .await?; + + framed + .feed(PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new( + READY_STATUS_IDLE, + ))) + .await?; + + framed.flush().await?; + + trace!("sent auth ok and ReadyForQuery"); + + let (front_tx, mut front_rx) = channel(1024); + let (back_tx, mut back_rx) = channel(1024); + + let (mut sink, mut stream) = framed.split(); + + let conn = agent.pool().client_dedicated().unwrap(); + trace!("opened connection"); + + let cancel = CancellationToken::new(); + + tokio::spawn({ + let back_tx = back_tx.clone(); + let cancel = cancel.clone(); + async move { + // cancel stuff if this loop breaks + let _drop_guard = cancel.drop_guard(); + + while let Some(decode_res) = stream.next().await { + let msg = match decode_res { + Ok(msg) => msg, + Err(PgWireError::IoError(io_error)) => { + debug!("postgres 
io error: {io_error}"); + break; + } + Err(e) => { + warn!("could not receive pg frontend message: {e}"); + // attempt to send this... + _ = back_tx.try_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "FATAL".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + ); + break; + } + }; + + front_tx.send(msg).await?; + } + debug!("frontend stream is done"); + + Ok::<_, BoxError>(()) + } + }); + + tokio::spawn({ + let cancel = cancel.clone(); + async move { + let _drop_guard = cancel.drop_guard(); + while let Some(back) = back_rx.recv().await { + match back { + BackendResponse::Message { message, flush } => { + debug!("sending: {message:?}"); + sink.feed(message).await?; + if flush { + sink.flush().await?; + } + } + BackendResponse::Flush => { + sink.flush().await?; + } + } + } + debug!("backend stream is done"); + Ok::<_, std::io::Error>(()) + } + }); + + let int_handle = conn.get_interrupt_handle(); + tokio::spawn(async move { + cancel.cancelled().await; + int_handle.interrupt(); + }); + + block_in_place(|| { + conn.create_module("pg_type", eponymous_only_module::(), None)?; + conn.create_module("pg_range", eponymous_only_module::(), None)?; + + let schema = match compute_schema(&conn) { + Ok(schema) => schema, + Err(e) => { + error!("could not parse schema: {e}"); + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "FATAL".into(), + "XX000".into(), + "could not parse database schema".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + return Ok(()); + } + }; + + let mut prepared: HashMap = HashMap::new(); + + let mut portals: HashMap = HashMap::new(); + + let mut open_tx = None; + + 'outer: while let Some(msg) = front_rx.blocking_recv() { + debug!("msg: {msg:?}"); + + match msg { + PgWireFrontendMessage::Startup(_) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "FATAL".into(), + 
SqlState::PROTOCOL_VIOLATION.code().into(), + "unexpected startup message".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + PgWireFrontendMessage::Parse(parse) => { + let name: &str = parse.name().as_deref().unwrap_or(""); + let mut cmds = match parse_query(parse.query()) { + Ok(cmds) => cmds, + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + }; + + match cmds.pop_front() { + None => { + prepared.insert(name.into(), Prepared::Empty); + } + Some(parsed_cmd) => { + if !cmds.is_empty() { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + sql_state::SqlState::PROTOCOL_VIOLATION + .code() + .into(), + "only 1 command per Parse is allowed" + .into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + + trace!("parsed cmd: {parsed_cmd:#?}"); + + let prepped = match conn.prepare(parse.query()) { + Ok(prepped) => prepped, + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + }; + + let mut param_types: Vec = parse + .type_oids() + .iter() + .filter_map(|oid| Type::from_oid(*oid)) + .collect(); + + if param_types.len() != prepped.parameter_count() { + param_types = parameter_types(&schema, &parsed_cmd.0) + .into_iter() + .map(|param| match param { + SqliteType::Null => unreachable!(), + SqliteType::Integer => Type::INT8, + SqliteType::Real => Type::FLOAT8, + SqliteType::Text => Type::TEXT, + SqliteType::Blob => Type::BYTEA, + }) + .collect(); + } + + let mut fields = vec![]; + for col in prepped.columns() { + let col_type = match name_to_type( + col.decl_type().unwrap_or("text"), + ) { + Ok(t) => t, + Err(e) => { + 
back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + e.into(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + }; + fields.push(FieldInfo::new( + col.name().to_string(), + None, + None, + col_type, + FieldFormat::Text, + )); + } + + prepared.insert( + name.into(), + Prepared::NonEmpty { + sql: parse.query().clone(), + param_types, + fields, + tag: parsed_cmd.tag(), + }, + ); + } + } + + back_tx.blocking_send( + ( + PgWireBackendMessage::ParseComplete(ParseComplete::new()), + false, + ) + .into(), + )?; + } + PgWireFrontendMessage::Describe(desc) => { + let name = desc.name().as_deref().unwrap_or(""); + match desc.target_type() { + // statement + b'S' => match prepared.get(name) { + None => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".into(), + "XX000".into(), + "statement not found".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + } + Some(Prepared::Empty) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::NoData(NoData::new()), + false, + ) + .into(), + )?; + } + Some(Prepared::NonEmpty { + param_types, + fields, + .. + }) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ParameterDescription( + ParameterDescription::new( + param_types + .iter() + .map(|t| t.oid()) + .collect(), + ), + ), + false, + ) + .into(), + )?; + + back_tx.blocking_send( + ( + PgWireBackendMessage::RowDescription( + RowDescription::new( + fields.iter().map(Into::into).collect(), + ), + ), + false, + ) + .into(), + )?; + } + }, + // portal + b'P' => match portals.get(name) { + None => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".into(), + "XX000".into(), + "portal not found".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + } + Some(Portal::Empty { .. }) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::NoData(NoData::new()), + false, + ) + .into(), + )?; + } + Some(Portal::Parsed { + stmt, + result_formats, + .. 
+ }) => { + let mut oids = vec![]; + let mut fields = vec![]; + for (i, col) in stmt.columns().into_iter().enumerate() { + let col_type = + match name_to_type( + col.decl_type().unwrap_or("text"), + ) { + Ok(t) => t, + Err(e) => { + back_tx.blocking_send(( + PgWireBackendMessage::ErrorResponse( + e.into(), + ), + true, + ).into())?; + continue 'outer; + } + }; + oids.push(col_type.oid()); + fields.push(FieldInfo::new( + col.name().to_string(), + None, + None, + col_type, + result_formats + .get(i) + .copied() + .unwrap_or(FieldFormat::Text), + )); + } + back_tx.blocking_send( + ( + PgWireBackendMessage::RowDescription( + RowDescription::new( + fields.iter().map(Into::into).collect(), + ), + ), + false, + ) + .into(), + )?; + } + }, + _ => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "FATAL".into(), + SqlState::PROTOCOL_VIOLATION.code().into(), + "unexpected describe type".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + } + } + PgWireFrontendMessage::Bind(bind) => { + let portal_name = bind + .portal_name() + .as_deref() + .map(CompactString::from) + .unwrap_or_default(); + + let stmt_name = bind.statement_name().as_deref().unwrap_or(""); + + match prepared.get(stmt_name) { + None => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + "statement not found".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + Some(Prepared::Empty) => { + portals.insert( + portal_name, + Portal::Empty { + stmt_name: stmt_name.into(), + }, + ); + } + Some(Prepared::NonEmpty { + sql, + param_types, + tag, + .. 
+ }) => { + let mut prepped = match conn.prepare(sql) { + Ok(prepped) => prepped, + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + }; + + trace!( + "bind params count: {}, statement params count: {}", + bind.parameters().len(), + prepped.parameter_count() + ); + + for (i, param) in bind.parameters().iter().enumerate() { + let idx = i + 1; + let b = match param { + None => { + trace!("binding idx {idx} w/ NULL"); + if let Err(e) = prepped.raw_bind_parameter( + idx, + rusqlite::types::Null, + ) { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + continue; + } + Some(b) => b, + }; + + match param_types.get(i) { + None => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + "missing parameter type".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + Some(param_type) => { + match param_type { + &Type::BOOL => { + let value: bool = FromSql::from_sql( + param_type, + b.as_ref(), + )?; + trace!("binding idx {idx} w/ value: {value}"); + prepped + .raw_bind_parameter(idx, value)?; + } + &Type::INT2 => { + let value: i16 = FromSql::from_sql( + param_type, + b.as_ref(), + )?; + trace!("binding idx {idx} w/ value: {value}"); + prepped + .raw_bind_parameter(idx, value)?; + } + &Type::INT4 => { + let value: i32 = FromSql::from_sql( + param_type, + b.as_ref(), + )?; + trace!("binding idx {idx} w/ value: {value}"); + prepped + .raw_bind_parameter(idx, value)?; + } + &Type::INT8 => { + let value: i64 = FromSql::from_sql( + param_type, + b.as_ref(), + )?; + trace!("binding idx {idx} w/ value: {value}"); + prepped + 
.raw_bind_parameter(idx, value)?; + } + &Type::TEXT | &Type::VARCHAR => { + let value: &str = FromSql::from_sql( + param_type, + b.as_ref(), + )?; + trace!("binding idx {idx} w/ value: {value}"); + prepped + .raw_bind_parameter(idx, value)?; + } + &Type::FLOAT4 => { + let value: f32 = FromSql::from_sql( + param_type, + b.as_ref(), + )?; + trace!("binding idx {idx} w/ value: {value}"); + prepped + .raw_bind_parameter(idx, value)?; + } + &Type::FLOAT8 => { + let value: f64 = FromSql::from_sql( + param_type, + b.as_ref(), + )?; + trace!("binding idx {idx} w/ value: {value}"); + prepped + .raw_bind_parameter(idx, value)?; + } + &Type::BYTEA => { + let value: &[u8] = FromSql::from_sql( + param_type, + b.as_ref(), + )?; + trace!("binding idx {idx} w/ value: {value:?}"); + prepped + .raw_bind_parameter(idx, value)?; + } + t => { + warn!("unsupported type: {t:?}"); + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + format!( + "unsupported type {t} at index {i}" + ), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + } + } + } + } + + debug!("EXPANDED SQL: {:?}", prepped.expanded_sql()); + + portals.insert( + portal_name, + Portal::Parsed { + stmt_name: stmt_name.into(), + stmt: prepped, + result_formats: bind + .result_column_format_codes() + .iter() + .copied() + .map(FieldFormat::from) + .collect(), + tag: *tag, + }, + ); + } + } + + back_tx.blocking_send( + ( + PgWireBackendMessage::BindComplete(BindComplete::new()), + false, + ) + .into(), + )?; + } + PgWireFrontendMessage::Sync(_) => { + let ready_status = if open_tx.is_some() { + READY_STATUS_TRANSACTION_BLOCK + } else { + READY_STATUS_IDLE + }; + back_tx.blocking_send( + ( + PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new( + ready_status, + )), + true, + ) + .into(), + )?; + } + PgWireFrontendMessage::Execute(execute) => { + let name = execute.name().as_deref().unwrap_or(""); + let (prepped, result_formats, 
tag) = match portals.get_mut(name) { + Some(Portal::Empty { .. }) => { + trace!("empty portal"); + back_tx.blocking_send( + ( + PgWireBackendMessage::EmptyQueryResponse( + EmptyQueryResponse::new(), + ), + false, + ) + .into(), + )?; + continue; + } + Some(Portal::Parsed { + stmt, + result_formats, + tag, + .. + }) => (stmt, result_formats, tag), + None => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".into(), + "XX000".into(), + "portal not found".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + }; + + trace!("non-empty portal!"); + + // TODO: maybe we don't need to recompute this... + let mut fields = vec![]; + for (i, col) in prepped.columns().into_iter().enumerate() { + let col_type = + match name_to_type(col.decl_type().unwrap_or("text")) { + Ok(t) => t, + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + e.into(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + }; + fields.push(FieldInfo::new( + col.name().to_string(), + None, + None, + col_type, + result_formats.get(i).copied().unwrap_or(FieldFormat::Text), + )); + } + + trace!("fields: {fields:?}"); + + let schema = Arc::new(fields); + + let mut rows = prepped.raw_query(); + let ncols = schema.len(); + + let max_rows = *execute.max_rows(); + let max_rows = if max_rows == 0 { + usize::MAX + } else { + max_rows as usize + }; + let mut count = 0; + + trace!("starting loop"); + + loop { + if count >= max_rows { + trace!("attained max rows"); + // forget the Rows iterator here so as to not reset the statement! 
+ std::mem::forget(rows); + back_tx.blocking_send( + ( + PgWireBackendMessage::PortalSuspended( + PortalSuspended::new(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + let row = match rows.next() { + Ok(Some(row)) => { + trace!("got a row: {row:?}"); + row + } + Ok(None) => { + trace!("done w/ rows"); + break; + } + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + }; + count += 1; + let mut encoder = DataRowEncoder::new(schema.clone()); + for idx in 0..ncols { + let data = row.get_ref_unwrap::(idx); + match data { + ValueRef::Null => { + encoder.encode_field(&None::).unwrap() + } + ValueRef::Integer(i) => { + encoder.encode_field(&i).unwrap(); + } + ValueRef::Real(f) => { + encoder.encode_field(&f).unwrap(); + } + ValueRef::Text(t) => { + encoder + .encode_field( + &String::from_utf8_lossy(t).as_ref(), + ) + .unwrap(); + } + ValueRef::Blob(b) => { + encoder.encode_field(&b).unwrap(); + } + } + } + match encoder.finish() { + Ok(data_row) => { + back_tx.blocking_send( + (PgWireBackendMessage::DataRow(data_row), false) + .into(), + )?; + } + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + } + } + + trace!("done w/ rows, computing tag: {tag:?}"); + + let tag = if tag.returns_num_rows() { + tag.tag(Some(count)) + } else if tag.returns_rows_affected() { + tag.tag(Some(conn.changes() as usize)) + } else { + tag.tag(None) + }; + + // done! 
+ back_tx.blocking_send( + (PgWireBackendMessage::CommandComplete(tag.into()), true) + .into(), + )?; + } + PgWireFrontendMessage::Query(query) => { + let trimmed = query.query().trim_matches(';'); + + let parsed_query = match parse_query(trimmed) { + Ok(q) => q, + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + }; + + if parsed_query.is_empty() { + back_tx.blocking_send( + ( + PgWireBackendMessage::EmptyQueryResponse( + EmptyQueryResponse::new(), + ), + false, + ) + .into(), + )?; + + let ready_status = if open_tx.is_some() { + ReadyForQuery::new(READY_STATUS_TRANSACTION_BLOCK) + } else { + ReadyForQuery::new(READY_STATUS_IDLE) + }; + + back_tx.blocking_send( + (PgWireBackendMessage::ReadyForQuery(ready_status), true) + .into(), + )?; + continue; + } + + for cmd in parsed_query.into_iter() { + // need to start an implicit transaction + if open_tx.is_none() && !cmd.is_begin() { + if let Err(e) = conn.execute_batch("BEGIN") { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + trace!("started IMPLICIT tx"); + open_tx = Some(OpenTx::Implicit); + } + + // close the current implement tx first + if matches!(open_tx, Some(OpenTx::Implicit)) && cmd.is_begin() { + trace!("committing IMPLICIT tx"); + open_tx = None; + + if let Err(e) = handle_commit(&agent, &conn, "COMMIT") { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + + continue 'outer; + } + trace!("committed IMPLICIT tx"); + } + + let count = if cmd.is_commit() { + open_tx = None; + + if let Err(e) = + handle_commit(&agent, &conn, 
&cmd.0.to_string()) + { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + + continue 'outer; + } + None + } else { + let mut prepped = match conn.prepare(&cmd.0.to_string()) { + Ok(prepped) => prepped, + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + }; + + let mut fields = vec![]; + for col in prepped.columns() { + let col_type = match name_to_type( + col.decl_type().unwrap_or("text"), + ) { + Ok(t) => t, + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + e.into(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + }; + fields.push(FieldInfo::new( + col.name().to_string(), + None, + None, + col_type, + FieldFormat::Text, + )); + } + + back_tx.blocking_send( + ( + PgWireBackendMessage::RowDescription( + RowDescription::new( + fields.iter().map(Into::into).collect(), + ), + ), + true, + ) + .into(), + )?; + + let schema = Arc::new(fields); + + let mut rows = prepped.raw_query(); + let ncols = schema.len(); + + let mut count = 0; + while let Ok(Some(row)) = rows.next() { + count += 1; + let mut encoder = DataRowEncoder::new(schema.clone()); + for idx in 0..ncols { + let data = row.get_ref_unwrap::(idx); + match data { + ValueRef::Null => { + encoder.encode_field(&None::).unwrap() + } + ValueRef::Integer(i) => { + encoder.encode_field(&i).unwrap(); + } + ValueRef::Real(f) => { + encoder.encode_field(&f).unwrap(); + } + ValueRef::Text(t) => { + encoder + .encode_field( + &String::from_utf8_lossy(t) + .as_ref(), + ) + .unwrap(); + } + ValueRef::Blob(b) => { + encoder.encode_field(&b).unwrap(); + } + } + } + match encoder.finish() { + Ok(data_row) => { + back_tx.blocking_send( + ( + 
PgWireBackendMessage::DataRow(data_row), + false, + ) + .into(), + )?; + } + Err(e) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue 'outer; + } + } + } + Some(count) + }; + + let tag = cmd.tag(); + + let tag = if tag.returns_num_rows() { + tag.tag(count) + } else if tag.returns_rows_affected() { + tag.tag(Some(conn.changes() as usize)) + } else { + tag.tag(None) + }; + + back_tx.blocking_send( + (PgWireBackendMessage::CommandComplete(tag.into()), true) + .into(), + )?; + + if cmd.is_begin() { + trace!("setting EXPLICIT tx"); + // explicit tx + open_tx = Some(OpenTx::Explicit) + } else if cmd.is_rollback() || cmd.is_commit() { + trace!("clearing current open tx"); + // if this was a rollback, remove the current open tx + open_tx = None; + } + } + + // automatically commit an implicit tx + if matches!(open_tx, Some(OpenTx::Implicit)) { + trace!("committing IMPLICIT tx"); + open_tx = None; + + if let Err(e) = handle_commit(&agent, &conn, "COMMIT") { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + trace!("committed IMPLICIT tx"); + } + + let ready_status = if open_tx.is_some() { + ReadyForQuery::new(READY_STATUS_TRANSACTION_BLOCK) + } else { + ReadyForQuery::new(READY_STATUS_IDLE) + }; + + back_tx.blocking_send( + (PgWireBackendMessage::ReadyForQuery(ready_status), true) + .into(), + )?; + } + PgWireFrontendMessage::Terminate(_) => { + break; + } + + PgWireFrontendMessage::PasswordMessageFamily(_) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".into(), + "XX000".to_owned(), + "PasswordMessage is not implemented".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + 
PgWireFrontendMessage::Close(close) => { + let name = close.name().as_deref().unwrap_or(""); + match close.target_type() { + // statement + b'S' => { + if prepared.remove(name).is_some() { + portals.retain(|_, portal| portal.stmt_name() != name); + } + back_tx.blocking_send( + ( + PgWireBackendMessage::CloseComplete( + CloseComplete::new(), + ), + true, + ) + .into(), + )?; + continue; + } + // portal + b'P' => { + portals.remove(name); + back_tx.blocking_send( + ( + PgWireBackendMessage::CloseComplete( + CloseComplete::new(), + ), + true, + ) + .into(), + )?; + } + _ => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "FATAL".into(), + SqlState::PROTOCOL_VIOLATION.code().into(), + "unexpected Close target_type".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + } + } + PgWireFrontendMessage::Flush(_) => { + back_tx.blocking_send(BackendResponse::Flush)?; + } + PgWireFrontendMessage::CopyData(_) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".into(), + "XX000".to_owned(), + "CopyData is not implemented".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + PgWireFrontendMessage::CopyFail(_) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".into(), + "XX000".to_owned(), + "CopyFail is not implemented".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + PgWireFrontendMessage::CopyDone(_) => { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".into(), + "XX000".to_owned(), + "CopyDone is not implemented".into(), + ) + .into(), + ), + true, + ) + .into(), + )?; + continue; + } + } + } + + Ok::<_, BoxError>(()) + })?; + + Ok::<_, BoxError>(()) + }); + } + + info!("postgres server done"); + + Ok::<_, BoxError>(()) + }); + + Ok(PgServer { local_addr }) +} + +#[allow(clippy::result_large_err)] +fn name_to_type(name: 
&str) -> Result { + match name.to_uppercase().as_ref() { + "ANY" => Ok(Type::ANY), + "INT" | "INTEGER" => Ok(Type::INT8), + "VARCHAR" => Ok(Type::VARCHAR), + "TEXT" => Ok(Type::TEXT), + "BINARY" | "BLOB" => Ok(Type::BYTEA), + "FLOAT" => Ok(Type::FLOAT8), + _ => Err(ErrorInfo::new( + "ERROR".to_owned(), + "42846".to_owned(), + format!("Unsupported data type: {name}"), + )), + } +} + +fn handle_commit(agent: &Agent, conn: &Connection, commit_stmt: &str) -> rusqlite::Result<()> { + let actor_id = agent.actor_id(); + + let ts = Timestamp::from(agent.clock().new_timestamp()); + + let db_version: i64 = conn + .prepare_cached("SELECT crsql_next_db_version()")? + .query_row((), |row| row.get(0))?; + + let has_changes: bool = conn + .prepare_cached( + "SELECT EXISTS(SELECT 1 FROM crsql_changes WHERE site_id IS NULL AND db_version = ?);", + )? + .query_row([db_version], |row| row.get(0))?; + + if !has_changes { + conn.execute_batch(commit_stmt)?; + return Ok(()); + } + + let booked = { + agent + .bookie() + .blocking_write("handle_write_tx(for_actor)") + .for_actor(actor_id) + }; + + let last_seq: i64 = conn + .prepare_cached( + "SELECT MAX(seq) FROM crsql_changes WHERE site_id IS NULL AND db_version = ?", + )? + .query_row([db_version], |row| row.get(0))?; + + let mut book_writer = booked.blocking_write("handle_write_tx(book_writer)"); + + let last_version = book_writer.last().unwrap_or_default(); + trace!("last_version: {last_version}"); + let version = last_version + 1; + trace!("version: {version}"); + + conn.prepare_cached( + r#" + INSERT INTO __corro_bookkeeping (actor_id, start_version, db_version, last_seq, ts) + VALUES (:actor_id, :start_version, :db_version, :last_seq, :ts); + "#, + )? + .execute(named_params! 
{ + ":actor_id": actor_id, + ":start_version": version, + ":db_version": db_version, + ":last_seq": last_seq, + ":ts": ts + })?; + + debug!(%actor_id, %version, %db_version, "inserted local bookkeeping row!"); + + conn.execute_batch(commit_stmt)?; + + trace!("committed tx, db_version: {db_version}, last_seq: {last_seq:?}"); + + book_writer.insert( + version, + KnownDbVersion::Current { + db_version, + last_seq, + ts, + }, + ); + + drop(book_writer); + + spawn_counted({ + let agent = agent.clone(); + async move { + let conn = agent.pool().read().await?; + block_in_place(|| agent.process_subs_by_db_version(&conn, db_version)); + Ok::<_, SqlitePoolError>(()) + } + }); + + Ok(()) +} + +fn compute_schema(conn: &Connection) -> Result { + let mut dump = String::new(); + + let tables: HashMap = conn + .prepare(r#"SELECT name, sql FROM sqlite_schema WHERE type = "table" AND name IS NOT NULL AND sql IS NOT NULL ORDER BY tbl_name"#)? + .query_map((), |row| { + Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?)) + })? + .collect::>()?; + + for sql in tables.values() { + dump.push_str(sql.as_str()); + dump.push(';'); + } + + let indexes: HashMap = conn + .prepare(r#"SELECT name, sql FROM sqlite_schema WHERE type = "index" AND name IS NOT NULL AND sql IS NOT NULL ORDER BY tbl_name"#)? + .query_map((), |row| { + Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?)) + })? 
+ .collect::>()?; + + for sql in indexes.values() { + dump.push_str(sql.as_str()); + dump.push(';'); + } + + parse_sql(dump.as_str()) +} + +fn is_param(expr: &Expr) -> bool { + matches!(expr, Expr::Variable(_)) +} + +enum SqliteNameRef<'a> { + Id(&'a Id), + Name(&'a Name), + Qualified(&'a Name, &'a Name), + DoublyQualified(&'a Name, &'a Name, &'a Name), +} + +impl<'a> SqliteNameRef<'a> { + fn to_owned(&self) -> SqliteName { + match self { + SqliteNameRef::Id(id) => SqliteName::Id((*id).clone()), + SqliteNameRef::Name(name) => SqliteName::Name((*name).clone()), + SqliteNameRef::Qualified(n0, n1) => SqliteName::Qualified((*n0).clone(), (*n1).clone()), + SqliteNameRef::DoublyQualified(n0, n1, n2) => { + SqliteName::DoublyQualified((*n0).clone(), (*n1).clone(), (*n2).clone()) + } + } + } +} + +#[derive(Clone, Debug)] +enum SqliteName { + Id(Id), + Name(Name), + Qualified(Name, Name), + DoublyQualified(Name, Name, Name), +} + +fn expr_to_name(expr: &Expr) -> Option { + match expr { + Expr::Id(id) => Some(SqliteNameRef::Id(id)), + Expr::Name(name) => Some(SqliteNameRef::Name(name)), + Expr::Qualified(n0, n1) => Some(SqliteNameRef::Qualified(n0, n1)), + Expr::DoublyQualified(n0, n1, n2) => Some(SqliteNameRef::DoublyQualified(n0, n1, n2)), + _ => None, + } +} + +// determines the type of a literal type if any +// fn literal_type(expr: &Expr) -> Option { +// match expr { +// Expr::Literal(lit) => match lit { +// Literal::Numeric(num) => { +// if num.parse::().is_ok() { +// Some(SqliteType::Integer) +// } else if num.parse::().is_ok() { +// Some(SqliteType::Real) +// } else { +// // this should be unreachable... +// None +// } +// } +// Literal::String(_) => Some(SqliteType::Text), +// Literal::Blob(_) => Some(SqliteType::Blob), +// Literal::Keyword(keyword) => { +// // TODO: figure out what this is... 
+// warn!("got a keyword: {keyword}"); +// None +// } +// Literal::Null => Some(SqliteType::Null), +// Literal::CurrentDate | Literal::CurrentTime | Literal::CurrentTimestamp => { +// // TODO: make this configurable at connection time or something +// Some(SqliteType::Text) +// } +// }, +// _ => None, +// } +// } + +fn handle_lhs_rhs(lhs: &Expr, rhs: &Expr) -> Option { + match ( + (expr_to_name(lhs), is_param(lhs)), + (expr_to_name(rhs), is_param(rhs)), + ) { + ((Some(name), _), (_, true)) | ((_, true), (Some(name), _)) => Some(name.to_owned()), + _ => None, + } +} + +fn extract_param( + schema: &Schema, + expr: &Expr, + tables: &HashMap, + params: &mut Vec, +) { + match expr { + // expr BETWEEN expr AND expr + Expr::Between { + lhs: _, + start: _, + end: _, + not: _, + } => {} + + // expr operator expr + Expr::Binary(lhs, _, rhs) => { + if let Some(name) = handle_lhs_rhs(lhs, rhs) { + match name { + // not aliased! + SqliteName::Id(id) => { + // find the first one to match + for (_, table) in tables.iter() { + if let Some(col) = table.columns.get(&id.0) { + params.push(col.sql_type); + break; + } + } + } + SqliteName::Name(_) => {} + SqliteName::Qualified(tbl_name, col_name) + | SqliteName::DoublyQualified(_, tbl_name, col_name) => { + trace!("looking tbl {} for col {}", tbl_name.0, col_name.0); + if let Some(table) = tables.get(&tbl_name.0) { + trace!("found table! 
{}", table.name); + let col_name = if col_name.0.starts_with('"') { + rem_first_and_last(&col_name.0) + } else { + &col_name.0 + }; + + if let Some(col) = table.columns.get(col_name) { + params.push(col.sql_type); + } + } + } + } + } + } + + // CASE expr [WHEN expr THEN expr, ..., ELSE expr] + Expr::Case { + base: _, + when_then_pairs: _, + else_expr: _, + } => {} + + // CAST ( expr AS type-name ) + Expr::Cast { + expr: _, + type_name: _, + } => {} + + // expr COLLATE collation-name + Expr::Collate(_, _) => {} + + // schema-name.table-name.column-name + Expr::DoublyQualified(_, _, _) => {} + + // EXISTS ( select ) + Expr::Exists(select) => handle_select(schema, select, params), + + // function-name ( [DISTINCT] expr, ... ) filter-clause over-clause + Expr::FunctionCall { + name: _, + distinctness: _, + args: _, + filter_over: _, + } => {} + + Expr::FunctionCallStar { + name: _, + filter_over: _, + } => {} + + // id + Expr::Id(_) => {} + + // expr IN ( expr, ... ) + Expr::InList { lhs, not: _, rhs } => { + if let Some(rhs) = rhs { + for expr in rhs.iter() { + if let Some(name) = handle_lhs_rhs(lhs, expr) { + trace!("HANDLED LHS RHS: {name:?}"); + match name { + // not aliased! + SqliteName::Id(id) => { + // find the first one to match + for (_, table) in tables.iter() { + if let Some(col) = table.columns.get(&id.0) { + params.push(col.sql_type); + break; + } + } + } + SqliteName::Name(_) => {} + SqliteName::Qualified(tbl_name, col_name) + | SqliteName::DoublyQualified(_, tbl_name, col_name) => { + if let Some(table) = tables.get(&tbl_name.0) { + if let Some(col) = table.columns.get(&col_name.0) { + params.push(col.sql_type); + } + } + } + } + } + } + } + } + + // expr IN ( select ) + Expr::InSelect { + lhs: _, + not: _, + rhs, + } => { + // TODO: check LHS here + handle_select(schema, rhs.as_ref(), params); + } + + // expr IN schema-name.table-name | schema-name.table-function ( expr, ... 
) + Expr::InTable { + lhs: _, + not: _, + rhs: _, + args: _, + } => {} + + // expr IS NULL + Expr::IsNull(_) => {} + + // expr [NOT] LIKE | GLOB | REGEXP | MATCH expr + Expr::Like { + lhs: _, + not: _, + op: _, + rhs: _, + escape: _, + } => {} + + // NULL | integer | float | text | blob + Expr::Literal(_) => { + // nothing to do + } + + // TODO: + Expr::Name(_) => {} + + // expr NOT NULL + Expr::NotNull(_) => {} + + // ( expr, ... ) + Expr::Parenthesized(exprs) => { + for expr in exprs.iter() { + extract_param(schema, expr, tables, params) + } + } + + // schema-name.table-name + Expr::Qualified(_, _) => {} + + // RAISE ( IGNORE | ROLLBACK | ABORT | FAIL [ error ] ) + Expr::Raise(_, _) => {} + + // SELECT + Expr::Subquery(select) => handle_select(schema, select, params), + + // NOT | ~ | - | + expr + Expr::Unary(_, _) => {} + + // ? | $ | : + Expr::Variable(_) => {} + } +} + +fn rem_first_and_last(value: &str) -> &str { + let mut chars = value.chars(); + chars.next(); + chars.next_back(); + chars.as_str() +} + +fn handle_select(schema: &Schema, select: &Select, params: &mut Vec) { + let tables = match &select.body.select { + OneSelect::Select { + columns, + from, + where_clause, + distinctness: _, + group_by: _, + window_clause: _, + } => { + let tables = if let Some(from) = from { + let tables = handle_from(schema, from, params); + if let Some(where_clause) = where_clause { + trace!("WHERE CLAUSE: {where_clause:?}"); + extract_param(schema, where_clause, &tables, params); + } + tables + } else { + HashMap::new() + }; + for col in columns.iter() { + if let ResultColumn::Expr(expr, _) = col { + // TODO: check against table if we can... 
+ if is_param(expr) { + params.push(SqliteType::Text); + } + } + } + tables + } + OneSelect::Values(values_values) => { + for values in values_values.iter() { + for value in values.iter() { + if is_param(value) { + params.push(SqliteType::Text); + } + } + } + HashMap::new() + } + }; + if let Some(limit) = &select.limit { + if is_param(&limit.expr) { + trace!("limit was a param (variable), pushing Integer type"); + params.push(SqliteType::Integer); + } else { + extract_param(schema, &limit.expr, &tables, params); + } + if let Some(offset) = &limit.offset { + if is_param(offset) { + trace!("offset was a param (variable), pushing Integer type"); + params.push(SqliteType::Integer); + } else { + extract_param(schema, offset, &tables, params); + } + } + } +} + +fn handle_from<'a>( + schema: &'a Schema, + from: &FromClause, + params: &mut Vec, +) -> HashMap { + let mut tables: HashMap = HashMap::new(); + if let Some(select) = from.select.as_deref() { + match select { + SelectTable::Table(qname, maybe_alias, _) => { + let actual_tbl_name = if qname.name.0.starts_with('"') { + rem_first_and_last(&qname.name.0) + } else { + &qname.name.0 + }; + + if let Some(table) = schema.tables.get(actual_tbl_name) { + if let Some(alias) = maybe_alias { + let alias = match alias { + As::As(name) | As::Elided(name) => name.0.clone(), + }; + tables.insert(alias, table); + } else { + tables.insert(table.name.clone(), table); + } + } + } + SelectTable::TableCall(_, _, _) => {} + SelectTable::Select(select, _) => { + handle_select(schema, select, params); + } + SelectTable::Sub(_, _) => {} + } + } + if let Some(joins) = &from.joins { + for join in joins.iter() { + match &join.table { + SelectTable::Table(qname, maybe_alias, _) => { + let actual_tbl_name = if qname.name.0.starts_with('"') { + rem_first_and_last(&qname.name.0) + } else { + &qname.name.0 + }; + + if let Some(table) = schema.tables.get(actual_tbl_name) { + if let Some(alias) = maybe_alias { + let alias = match alias { + 
As::As(name) | As::Elided(name) => name.0.clone(), + }; + tables.insert(alias, table); + } else { + tables.insert(table.name.clone(), table); + } + } + } + SelectTable::TableCall(_, _, _) => {} + SelectTable::Select(select, _) => { + handle_select(schema, select, params); + } + SelectTable::Sub(_, _) => {} + } + } + } + tables +} + +fn parameter_types(schema: &Schema, cmd: &Cmd) -> Vec { + let mut params = vec![]; + + if let Cmd::Stmt(stmt) = cmd { + match stmt { + Stmt::Select(select) => handle_select(schema, select, &mut params), + Stmt::Delete { + tbl_name, + where_clause: Some(where_clause), + .. + } => { + let mut tables = HashMap::new(); + if let Some(tbl) = schema.tables.get(&tbl_name.name.0) { + tables.insert(tbl_name.name.0.clone(), tbl); + } + extract_param(schema, where_clause, &tables, &mut params); + } + Stmt::Insert { + tbl_name, + columns, + body, + .. + } => { + trace!("GOT AN INSERT TO {tbl_name:?} on columns: {columns:?} w/ body: {body:?}"); + if let Some(table) = schema.tables.get(&tbl_name.name.0) { + match body { + InsertBody::Select(select, _) => { + if let OneSelect::Values(values_values) = &select.body.select { + for values in values_values.iter() { + for (i, expr) in values.iter().enumerate() { + if is_param(expr) { + // specified columns + let col = if let Some(columns) = columns { + columns + .get(i) + .and_then(|name| table.columns.get(&name.0)) + } else { + table.columns.get_index(i).map(|(_name, col)| col) + }; + if let Some(col) = col { + params.push(col.sql_type); + } + } + } + } + } else { + handle_select(schema, select, &mut params) + } + } + InsertBody::DefaultValues => { + // nothing to do! 
#[cfg(test)]
mod tests {
    use corro_tests::launch_test_agent;
    use spawn::wait_for_all_pending_handles;
    use tokio_postgres::NoTls;
    use tripwire::Tripwire;

    use super::*;

    /// End-to-end smoke test: boots a test agent plus the pg wire server on an
    /// ephemeral port, then drives it with a real `tokio_postgres` client —
    /// prepare/describe, simple and extended queries, batch execution with
    /// embedded BEGIN/COMMIT, explicit transactions, and `?`-style params.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_pg() -> Result<(), BoxError> {
        _ = tracing_subscriber::fmt::try_init();
        let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple();

        let ta = launch_test_agent(|builder| builder.build(), tripwire.clone()).await?;

        // Bind on port 0 so the OS assigns a free port; the chosen address is
        // read back from `server.local_addr` below.
        let server = start(
            ta.agent.clone(),
            PgConfig {
                bind_addr: "127.0.0.1:0".parse()?,
            },
            tripwire,
        )
        .await?;

        let conn_str = format!(
            "host={} port={} user=testuser",
            server.local_addr.ip(),
            server.local_addr.port()
        );

        // Scope the client so it drops (and disconnects) before shutdown.
        {
            let (mut client, client_conn) = tokio_postgres::connect(&conn_str, NoTls).await?;
            // let (mut client, client_conn) =
            //     tokio_postgres::connect("host=localhost port=5432 user=jerome", NoTls).await?;
            println!("client is ready!");
            // The connection future must be polled for the client to make progress.
            tokio::spawn(client_conn);

            // Extended protocol: Parse/Describe round-trip.
            println!("before prepare");
            let stmt = client.prepare("SELECT 1").await?;
            println!(
                "after prepare: params: {:?}, columns: {:?}",
                stmt.params(),
                stmt.columns()
            );

            println!("before query");
            let rows = client.query(&stmt, &[]).await?;

            println!("rows count: {}", rows.len());
            for row in rows {
                println!("ROW!!! {row:?}");
            }

            // Simple write, then read back the resulting cr-sqlite change row.
            println!("before execute");
            let affected = client
                .execute("INSERT INTO tests VALUES (1,2)", &[])
                .await?;
            println!("after execute, affected: {affected}");

            let row = client.query_one("SELECT * FROM crsql_changes", &[]).await?;
            println!("CHANGE ROW: {row:?}");

            // Simple-protocol batches, including BEGIN/COMMIT split across
            // two batches to exercise transaction-state tracking.
            client
                .batch_execute("SELECT 1; SELECT 2; SELECT 3;")
                .await?;
            println!("after batch exec");

            client.batch_execute("SELECT 1; BEGIN; SELECT 3;").await?;
            println!("after batch exec 2");

            client.batch_execute("SELECT 3; COMMIT; SELECT 3;").await?;
            println!("after batch exec 3");

            // Driver-managed transaction with $n-style parameters.
            let tx = client.transaction().await?;
            println!("after begin I assume");
            let res = tx
                .execute(
                    "INSERT INTO tests VALUES ($1, $2)",
                    &[&2i64, &"hello world"],
                )
                .await?;
            println!("res (rows affected): {res}");
            let res = tx
                .execute(
                    "INSERT INTO tests2 VALUES ($1, $2)",
                    &[&2i64, &"hello world 2"],
                )
                .await?;
            println!("res (rows affected): {res}");
            tx.commit().await?;
            println!("after commit");

            // SQLite-style `?` placeholders must also work through the wire.
            let row = client
                .query_one("SELECT * FROM tests t WHERE t.id = ?", &[&2i64])
                .await?;
            println!("ROW: {row:?}");

            let row = client
                .query_one("SELECT * FROM tests t WHERE t.id = ?", &[&2i64])
                .await?;
            println!("ROW: {row:?}");

            let row = client
                .query_one("SELECT * FROM tests t WHERE t.id IN (?)", &[&2i64])
                .await?;
            println!("ROW: {row:?}");

            // Aliased join plus a parameterized LIMIT.
            let row = client
                .query_one("SELECT t.id, t.text, t2.text as t2text FROM tests t LEFT JOIN tests2 t2 WHERE t.id = ? LIMIT ?", &[&2i64, &1i64])
                .await?;
            println!("ROW: {row:?}");

            println!("t.id: {:?}", row.try_get::<_, i64>(0));
            println!("t.text: {:?}", row.try_get::<_, String>(1));
            println!("t2text: {:?}", row.try_get::<_, String>(2));
        }

        // Orderly shutdown: trip the tripwire and wait for spawned handles.
        tripwire_tx.send(()).await.ok();
        tripwire_worker.await;
        wait_for_all_pending_handles().await;

        Ok(())
    }
}
LIMIT ?", &[&2i64, &1i64]) + .await?; + println!("ROW: {row:?}"); + + println!("t.id: {:?}", row.try_get::<_, i64>(0)); + println!("t.text: {:?}", row.try_get::<_, String>(1)); + println!("t2text: {:?}", row.try_get::<_, String>(2)); + } + + tripwire_tx.send(()).await.ok(); + tripwire_worker.await; + wait_for_all_pending_handles().await; + + Ok(()) + } +} diff --git a/crates/corro-pg/src/sql_state.rs b/crates/corro-pg/src/sql_state.rs new file mode 100644 index 00000000..4ddf16ad --- /dev/null +++ b/crates/corro-pg/src/sql_state.rs @@ -0,0 +1,1336 @@ +/// A SQLSTATE error code +#[derive(PartialEq, Eq, Clone, Debug)] +pub struct SqlState(Inner); + +impl SqlState { + /// Returns the error code corresponding to the `SqlState`. + pub fn code(&self) -> &str { + match &self.0 { + Inner::E00000 => "00000", + Inner::E01000 => "01000", + Inner::E0100C => "0100C", + Inner::E01008 => "01008", + Inner::E01003 => "01003", + Inner::E01007 => "01007", + Inner::E01006 => "01006", + Inner::E01004 => "01004", + Inner::E01P01 => "01P01", + Inner::E02000 => "02000", + Inner::E02001 => "02001", + Inner::E03000 => "03000", + Inner::E08000 => "08000", + Inner::E08003 => "08003", + Inner::E08006 => "08006", + Inner::E08001 => "08001", + Inner::E08004 => "08004", + Inner::E08007 => "08007", + Inner::E08P01 => "08P01", + Inner::E09000 => "09000", + Inner::E0A000 => "0A000", + Inner::E0B000 => "0B000", + Inner::E0F000 => "0F000", + Inner::E0F001 => "0F001", + Inner::E0L000 => "0L000", + Inner::E0LP01 => "0LP01", + Inner::E0P000 => "0P000", + Inner::E0Z000 => "0Z000", + Inner::E0Z002 => "0Z002", + Inner::E20000 => "20000", + Inner::E21000 => "21000", + Inner::E22000 => "22000", + Inner::E2202E => "2202E", + Inner::E22021 => "22021", + Inner::E22008 => "22008", + Inner::E22012 => "22012", + Inner::E22005 => "22005", + Inner::E2200B => "2200B", + Inner::E22022 => "22022", + Inner::E22015 => "22015", + Inner::E2201E => "2201E", + Inner::E22014 => "22014", + Inner::E22016 => "22016", + 
Inner::E2201F => "2201F", + Inner::E2201G => "2201G", + Inner::E22018 => "22018", + Inner::E22007 => "22007", + Inner::E22019 => "22019", + Inner::E2200D => "2200D", + Inner::E22025 => "22025", + Inner::E22P06 => "22P06", + Inner::E22010 => "22010", + Inner::E22023 => "22023", + Inner::E22013 => "22013", + Inner::E2201B => "2201B", + Inner::E2201W => "2201W", + Inner::E2201X => "2201X", + Inner::E2202H => "2202H", + Inner::E2202G => "2202G", + Inner::E22009 => "22009", + Inner::E2200C => "2200C", + Inner::E2200G => "2200G", + Inner::E22004 => "22004", + Inner::E22002 => "22002", + Inner::E22003 => "22003", + Inner::E2200H => "2200H", + Inner::E22026 => "22026", + Inner::E22001 => "22001", + Inner::E22011 => "22011", + Inner::E22027 => "22027", + Inner::E22024 => "22024", + Inner::E2200F => "2200F", + Inner::E22P01 => "22P01", + Inner::E22P02 => "22P02", + Inner::E22P03 => "22P03", + Inner::E22P04 => "22P04", + Inner::E22P05 => "22P05", + Inner::E2200L => "2200L", + Inner::E2200M => "2200M", + Inner::E2200N => "2200N", + Inner::E2200S => "2200S", + Inner::E2200T => "2200T", + Inner::E22030 => "22030", + Inner::E22031 => "22031", + Inner::E22032 => "22032", + Inner::E22033 => "22033", + Inner::E22034 => "22034", + Inner::E22035 => "22035", + Inner::E22036 => "22036", + Inner::E22037 => "22037", + Inner::E22038 => "22038", + Inner::E22039 => "22039", + Inner::E2203A => "2203A", + Inner::E2203B => "2203B", + Inner::E2203C => "2203C", + Inner::E2203D => "2203D", + Inner::E2203E => "2203E", + Inner::E2203F => "2203F", + Inner::E2203G => "2203G", + Inner::E23000 => "23000", + Inner::E23001 => "23001", + Inner::E23502 => "23502", + Inner::E23503 => "23503", + Inner::E23505 => "23505", + Inner::E23514 => "23514", + Inner::E23P01 => "23P01", + Inner::E24000 => "24000", + Inner::E25000 => "25000", + Inner::E25001 => "25001", + Inner::E25002 => "25002", + Inner::E25008 => "25008", + Inner::E25003 => "25003", + Inner::E25004 => "25004", + Inner::E25005 => "25005", + 
Inner::E25006 => "25006", + Inner::E25007 => "25007", + Inner::E25P01 => "25P01", + Inner::E25P02 => "25P02", + Inner::E25P03 => "25P03", + Inner::E26000 => "26000", + Inner::E27000 => "27000", + Inner::E28000 => "28000", + Inner::E28P01 => "28P01", + Inner::E2B000 => "2B000", + Inner::E2BP01 => "2BP01", + Inner::E2D000 => "2D000", + Inner::E2F000 => "2F000", + Inner::E2F005 => "2F005", + Inner::E2F002 => "2F002", + Inner::E2F003 => "2F003", + Inner::E2F004 => "2F004", + Inner::E34000 => "34000", + Inner::E38000 => "38000", + Inner::E38001 => "38001", + Inner::E38002 => "38002", + Inner::E38003 => "38003", + Inner::E38004 => "38004", + Inner::E39000 => "39000", + Inner::E39001 => "39001", + Inner::E39004 => "39004", + Inner::E39P01 => "39P01", + Inner::E39P02 => "39P02", + Inner::E39P03 => "39P03", + Inner::E3B000 => "3B000", + Inner::E3B001 => "3B001", + Inner::E3D000 => "3D000", + Inner::E3F000 => "3F000", + Inner::E40000 => "40000", + Inner::E40002 => "40002", + Inner::E40001 => "40001", + Inner::E40003 => "40003", + Inner::E40P01 => "40P01", + Inner::E42000 => "42000", + Inner::E42601 => "42601", + Inner::E42501 => "42501", + Inner::E42846 => "42846", + Inner::E42803 => "42803", + Inner::E42P20 => "42P20", + Inner::E42P19 => "42P19", + Inner::E42830 => "42830", + Inner::E42602 => "42602", + Inner::E42622 => "42622", + Inner::E42939 => "42939", + Inner::E42804 => "42804", + Inner::E42P18 => "42P18", + Inner::E42P21 => "42P21", + Inner::E42P22 => "42P22", + Inner::E42809 => "42809", + Inner::E428C9 => "428C9", + Inner::E42703 => "42703", + Inner::E42883 => "42883", + Inner::E42P01 => "42P01", + Inner::E42P02 => "42P02", + Inner::E42704 => "42704", + Inner::E42701 => "42701", + Inner::E42P03 => "42P03", + Inner::E42P04 => "42P04", + Inner::E42723 => "42723", + Inner::E42P05 => "42P05", + Inner::E42P06 => "42P06", + Inner::E42P07 => "42P07", + Inner::E42712 => "42712", + Inner::E42710 => "42710", + Inner::E42702 => "42702", + Inner::E42725 => "42725", + 
Inner::E42P08 => "42P08", + Inner::E42P09 => "42P09", + Inner::E42P10 => "42P10", + Inner::E42611 => "42611", + Inner::E42P11 => "42P11", + Inner::E42P12 => "42P12", + Inner::E42P13 => "42P13", + Inner::E42P14 => "42P14", + Inner::E42P15 => "42P15", + Inner::E42P16 => "42P16", + Inner::E42P17 => "42P17", + Inner::E44000 => "44000", + Inner::E53000 => "53000", + Inner::E53100 => "53100", + Inner::E53200 => "53200", + Inner::E53300 => "53300", + Inner::E53400 => "53400", + Inner::E54000 => "54000", + Inner::E54001 => "54001", + Inner::E54011 => "54011", + Inner::E54023 => "54023", + Inner::E55000 => "55000", + Inner::E55006 => "55006", + Inner::E55P02 => "55P02", + Inner::E55P03 => "55P03", + Inner::E55P04 => "55P04", + Inner::E57000 => "57000", + Inner::E57014 => "57014", + Inner::E57P01 => "57P01", + Inner::E57P02 => "57P02", + Inner::E57P03 => "57P03", + Inner::E57P04 => "57P04", + Inner::E57P05 => "57P05", + Inner::E58000 => "58000", + Inner::E58030 => "58030", + Inner::E58P01 => "58P01", + Inner::E58P02 => "58P02", + Inner::E72000 => "72000", + Inner::EF0000 => "F0000", + Inner::EF0001 => "F0001", + Inner::EHV000 => "HV000", + Inner::EHV005 => "HV005", + Inner::EHV002 => "HV002", + Inner::EHV010 => "HV010", + Inner::EHV021 => "HV021", + Inner::EHV024 => "HV024", + Inner::EHV007 => "HV007", + Inner::EHV008 => "HV008", + Inner::EHV004 => "HV004", + Inner::EHV006 => "HV006", + Inner::EHV091 => "HV091", + Inner::EHV00B => "HV00B", + Inner::EHV00C => "HV00C", + Inner::EHV00D => "HV00D", + Inner::EHV090 => "HV090", + Inner::EHV00A => "HV00A", + Inner::EHV009 => "HV009", + Inner::EHV014 => "HV014", + Inner::EHV001 => "HV001", + Inner::EHV00P => "HV00P", + Inner::EHV00J => "HV00J", + Inner::EHV00K => "HV00K", + Inner::EHV00Q => "HV00Q", + Inner::EHV00R => "HV00R", + Inner::EHV00L => "HV00L", + Inner::EHV00M => "HV00M", + Inner::EHV00N => "HV00N", + Inner::EP0000 => "P0000", + Inner::EP0001 => "P0001", + Inner::EP0002 => "P0002", + Inner::EP0003 => "P0003", + 
Inner::EP0004 => "P0004", + Inner::EXX000 => "XX000", + Inner::EXX001 => "XX001", + Inner::EXX002 => "XX002", + } + } + + /// 00000 + pub const SUCCESSFUL_COMPLETION: SqlState = SqlState(Inner::E00000); + + /// 01000 + pub const WARNING: SqlState = SqlState(Inner::E01000); + + /// 0100C + pub const WARNING_DYNAMIC_RESULT_SETS_RETURNED: SqlState = SqlState(Inner::E0100C); + + /// 01008 + pub const WARNING_IMPLICIT_ZERO_BIT_PADDING: SqlState = SqlState(Inner::E01008); + + /// 01003 + pub const WARNING_NULL_VALUE_ELIMINATED_IN_SET_FUNCTION: SqlState = SqlState(Inner::E01003); + + /// 01007 + pub const WARNING_PRIVILEGE_NOT_GRANTED: SqlState = SqlState(Inner::E01007); + + /// 01006 + pub const WARNING_PRIVILEGE_NOT_REVOKED: SqlState = SqlState(Inner::E01006); + + /// 01004 + pub const WARNING_STRING_DATA_RIGHT_TRUNCATION: SqlState = SqlState(Inner::E01004); + + /// 01P01 + pub const WARNING_DEPRECATED_FEATURE: SqlState = SqlState(Inner::E01P01); + + /// 02000 + pub const NO_DATA: SqlState = SqlState(Inner::E02000); + + /// 02001 + pub const NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED: SqlState = SqlState(Inner::E02001); + + /// 03000 + pub const SQL_STATEMENT_NOT_YET_COMPLETE: SqlState = SqlState(Inner::E03000); + + /// 08000 + pub const CONNECTION_EXCEPTION: SqlState = SqlState(Inner::E08000); + + /// 08003 + pub const CONNECTION_DOES_NOT_EXIST: SqlState = SqlState(Inner::E08003); + + /// 08006 + pub const CONNECTION_FAILURE: SqlState = SqlState(Inner::E08006); + + /// 08001 + pub const SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION: SqlState = SqlState(Inner::E08001); + + /// 08004 + pub const SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION: SqlState = SqlState(Inner::E08004); + + /// 08007 + pub const TRANSACTION_RESOLUTION_UNKNOWN: SqlState = SqlState(Inner::E08007); + + /// 08P01 + pub const PROTOCOL_VIOLATION: SqlState = SqlState(Inner::E08P01); + + /// 09000 + pub const TRIGGERED_ACTION_EXCEPTION: SqlState = SqlState(Inner::E09000); + + /// 0A000 + pub const 
FEATURE_NOT_SUPPORTED: SqlState = SqlState(Inner::E0A000); + + /// 0B000 + pub const INVALID_TRANSACTION_INITIATION: SqlState = SqlState(Inner::E0B000); + + /// 0F000 + pub const LOCATOR_EXCEPTION: SqlState = SqlState(Inner::E0F000); + + /// 0F001 + pub const L_E_INVALID_SPECIFICATION: SqlState = SqlState(Inner::E0F001); + + /// 0L000 + pub const INVALID_GRANTOR: SqlState = SqlState(Inner::E0L000); + + /// 0LP01 + pub const INVALID_GRANT_OPERATION: SqlState = SqlState(Inner::E0LP01); + + /// 0P000 + pub const INVALID_ROLE_SPECIFICATION: SqlState = SqlState(Inner::E0P000); + + /// 0Z000 + pub const DIAGNOSTICS_EXCEPTION: SqlState = SqlState(Inner::E0Z000); + + /// 0Z002 + pub const STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER: SqlState = + SqlState(Inner::E0Z002); + + /// 20000 + pub const CASE_NOT_FOUND: SqlState = SqlState(Inner::E20000); + + /// 21000 + pub const CARDINALITY_VIOLATION: SqlState = SqlState(Inner::E21000); + + /// 22000 + pub const DATA_EXCEPTION: SqlState = SqlState(Inner::E22000); + + /// 2202E + pub const ARRAY_ELEMENT_ERROR: SqlState = SqlState(Inner::E2202E); + + /// 2202E + pub const ARRAY_SUBSCRIPT_ERROR: SqlState = SqlState(Inner::E2202E); + + /// 22021 + pub const CHARACTER_NOT_IN_REPERTOIRE: SqlState = SqlState(Inner::E22021); + + /// 22008 + pub const DATETIME_FIELD_OVERFLOW: SqlState = SqlState(Inner::E22008); + + /// 22008 + pub const DATETIME_VALUE_OUT_OF_RANGE: SqlState = SqlState(Inner::E22008); + + /// 22012 + pub const DIVISION_BY_ZERO: SqlState = SqlState(Inner::E22012); + + /// 22005 + pub const ERROR_IN_ASSIGNMENT: SqlState = SqlState(Inner::E22005); + + /// 2200B + pub const ESCAPE_CHARACTER_CONFLICT: SqlState = SqlState(Inner::E2200B); + + /// 22022 + pub const INDICATOR_OVERFLOW: SqlState = SqlState(Inner::E22022); + + /// 22015 + pub const INTERVAL_FIELD_OVERFLOW: SqlState = SqlState(Inner::E22015); + + /// 2201E + pub const INVALID_ARGUMENT_FOR_LOG: SqlState = SqlState(Inner::E2201E); + + /// 22014 + pub const 
INVALID_ARGUMENT_FOR_NTILE: SqlState = SqlState(Inner::E22014); + + /// 22016 + pub const INVALID_ARGUMENT_FOR_NTH_VALUE: SqlState = SqlState(Inner::E22016); + + /// 2201F + pub const INVALID_ARGUMENT_FOR_POWER_FUNCTION: SqlState = SqlState(Inner::E2201F); + + /// 2201G + pub const INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION: SqlState = SqlState(Inner::E2201G); + + /// 22018 + pub const INVALID_CHARACTER_VALUE_FOR_CAST: SqlState = SqlState(Inner::E22018); + + /// 22007 + pub const INVALID_DATETIME_FORMAT: SqlState = SqlState(Inner::E22007); + + /// 22019 + pub const INVALID_ESCAPE_CHARACTER: SqlState = SqlState(Inner::E22019); + + /// 2200D + pub const INVALID_ESCAPE_OCTET: SqlState = SqlState(Inner::E2200D); + + /// 22025 + pub const INVALID_ESCAPE_SEQUENCE: SqlState = SqlState(Inner::E22025); + + /// 22P06 + pub const NONSTANDARD_USE_OF_ESCAPE_CHARACTER: SqlState = SqlState(Inner::E22P06); + + /// 22010 + pub const INVALID_INDICATOR_PARAMETER_VALUE: SqlState = SqlState(Inner::E22010); + + /// 22023 + pub const INVALID_PARAMETER_VALUE: SqlState = SqlState(Inner::E22023); + + /// 22013 + pub const INVALID_PRECEDING_OR_FOLLOWING_SIZE: SqlState = SqlState(Inner::E22013); + + /// 2201B + pub const INVALID_REGULAR_EXPRESSION: SqlState = SqlState(Inner::E2201B); + + /// 2201W + pub const INVALID_ROW_COUNT_IN_LIMIT_CLAUSE: SqlState = SqlState(Inner::E2201W); + + /// 2201X + pub const INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE: SqlState = SqlState(Inner::E2201X); + + /// 2202H + pub const INVALID_TABLESAMPLE_ARGUMENT: SqlState = SqlState(Inner::E2202H); + + /// 2202G + pub const INVALID_TABLESAMPLE_REPEAT: SqlState = SqlState(Inner::E2202G); + + /// 22009 + pub const INVALID_TIME_ZONE_DISPLACEMENT_VALUE: SqlState = SqlState(Inner::E22009); + + /// 2200C + pub const INVALID_USE_OF_ESCAPE_CHARACTER: SqlState = SqlState(Inner::E2200C); + + /// 2200G + pub const MOST_SPECIFIC_TYPE_MISMATCH: SqlState = SqlState(Inner::E2200G); + + /// 22004 + pub const NULL_VALUE_NOT_ALLOWED: 
SqlState = SqlState(Inner::E22004); + + /// 22002 + pub const NULL_VALUE_NO_INDICATOR_PARAMETER: SqlState = SqlState(Inner::E22002); + + /// 22003 + pub const NUMERIC_VALUE_OUT_OF_RANGE: SqlState = SqlState(Inner::E22003); + + /// 2200H + pub const SEQUENCE_GENERATOR_LIMIT_EXCEEDED: SqlState = SqlState(Inner::E2200H); + + /// 22026 + pub const STRING_DATA_LENGTH_MISMATCH: SqlState = SqlState(Inner::E22026); + + /// 22001 + pub const STRING_DATA_RIGHT_TRUNCATION: SqlState = SqlState(Inner::E22001); + + /// 22011 + pub const SUBSTRING_ERROR: SqlState = SqlState(Inner::E22011); + + /// 22027 + pub const TRIM_ERROR: SqlState = SqlState(Inner::E22027); + + /// 22024 + pub const UNTERMINATED_C_STRING: SqlState = SqlState(Inner::E22024); + + /// 2200F + pub const ZERO_LENGTH_CHARACTER_STRING: SqlState = SqlState(Inner::E2200F); + + /// 22P01 + pub const FLOATING_POINT_EXCEPTION: SqlState = SqlState(Inner::E22P01); + + /// 22P02 + pub const INVALID_TEXT_REPRESENTATION: SqlState = SqlState(Inner::E22P02); + + /// 22P03 + pub const INVALID_BINARY_REPRESENTATION: SqlState = SqlState(Inner::E22P03); + + /// 22P04 + pub const BAD_COPY_FILE_FORMAT: SqlState = SqlState(Inner::E22P04); + + /// 22P05 + pub const UNTRANSLATABLE_CHARACTER: SqlState = SqlState(Inner::E22P05); + + /// 2200L + pub const NOT_AN_XML_DOCUMENT: SqlState = SqlState(Inner::E2200L); + + /// 2200M + pub const INVALID_XML_DOCUMENT: SqlState = SqlState(Inner::E2200M); + + /// 2200N + pub const INVALID_XML_CONTENT: SqlState = SqlState(Inner::E2200N); + + /// 2200S + pub const INVALID_XML_COMMENT: SqlState = SqlState(Inner::E2200S); + + /// 2200T + pub const INVALID_XML_PROCESSING_INSTRUCTION: SqlState = SqlState(Inner::E2200T); + + /// 22030 + pub const DUPLICATE_JSON_OBJECT_KEY_VALUE: SqlState = SqlState(Inner::E22030); + + /// 22031 + pub const INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION: SqlState = SqlState(Inner::E22031); + + /// 22032 + pub const INVALID_JSON_TEXT: SqlState = SqlState(Inner::E22032); + + 
/// 22033 + pub const INVALID_SQL_JSON_SUBSCRIPT: SqlState = SqlState(Inner::E22033); + + /// 22034 + pub const MORE_THAN_ONE_SQL_JSON_ITEM: SqlState = SqlState(Inner::E22034); + + /// 22035 + pub const NO_SQL_JSON_ITEM: SqlState = SqlState(Inner::E22035); + + /// 22036 + pub const NON_NUMERIC_SQL_JSON_ITEM: SqlState = SqlState(Inner::E22036); + + /// 22037 + pub const NON_UNIQUE_KEYS_IN_A_JSON_OBJECT: SqlState = SqlState(Inner::E22037); + + /// 22038 + pub const SINGLETON_SQL_JSON_ITEM_REQUIRED: SqlState = SqlState(Inner::E22038); + + /// 22039 + pub const SQL_JSON_ARRAY_NOT_FOUND: SqlState = SqlState(Inner::E22039); + + /// 2203A + pub const SQL_JSON_MEMBER_NOT_FOUND: SqlState = SqlState(Inner::E2203A); + + /// 2203B + pub const SQL_JSON_NUMBER_NOT_FOUND: SqlState = SqlState(Inner::E2203B); + + /// 2203C + pub const SQL_JSON_OBJECT_NOT_FOUND: SqlState = SqlState(Inner::E2203C); + + /// 2203D + pub const TOO_MANY_JSON_ARRAY_ELEMENTS: SqlState = SqlState(Inner::E2203D); + + /// 2203E + pub const TOO_MANY_JSON_OBJECT_MEMBERS: SqlState = SqlState(Inner::E2203E); + + /// 2203F + pub const SQL_JSON_SCALAR_REQUIRED: SqlState = SqlState(Inner::E2203F); + + /// 2203G + pub const SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE: SqlState = SqlState(Inner::E2203G); + + /// 23000 + pub const INTEGRITY_CONSTRAINT_VIOLATION: SqlState = SqlState(Inner::E23000); + + /// 23001 + pub const RESTRICT_VIOLATION: SqlState = SqlState(Inner::E23001); + + /// 23502 + pub const NOT_NULL_VIOLATION: SqlState = SqlState(Inner::E23502); + + /// 23503 + pub const FOREIGN_KEY_VIOLATION: SqlState = SqlState(Inner::E23503); + + /// 23505 + pub const UNIQUE_VIOLATION: SqlState = SqlState(Inner::E23505); + + /// 23514 + pub const CHECK_VIOLATION: SqlState = SqlState(Inner::E23514); + + /// 23P01 + pub const EXCLUSION_VIOLATION: SqlState = SqlState(Inner::E23P01); + + /// 24000 + pub const INVALID_CURSOR_STATE: SqlState = SqlState(Inner::E24000); + + /// 25000 + pub const INVALID_TRANSACTION_STATE: 
SqlState = SqlState(Inner::E25000); + + /// 25001 + pub const ACTIVE_SQL_TRANSACTION: SqlState = SqlState(Inner::E25001); + + /// 25002 + pub const BRANCH_TRANSACTION_ALREADY_ACTIVE: SqlState = SqlState(Inner::E25002); + + /// 25008 + pub const HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL: SqlState = SqlState(Inner::E25008); + + /// 25003 + pub const INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION: SqlState = SqlState(Inner::E25003); + + /// 25004 + pub const INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION: SqlState = + SqlState(Inner::E25004); + + /// 25005 + pub const NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION: SqlState = SqlState(Inner::E25005); + + /// 25006 + pub const READ_ONLY_SQL_TRANSACTION: SqlState = SqlState(Inner::E25006); + + /// 25007 + pub const SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED: SqlState = SqlState(Inner::E25007); + + /// 25P01 + pub const NO_ACTIVE_SQL_TRANSACTION: SqlState = SqlState(Inner::E25P01); + + /// 25P02 + pub const IN_FAILED_SQL_TRANSACTION: SqlState = SqlState(Inner::E25P02); + + /// 25P03 + pub const IDLE_IN_TRANSACTION_SESSION_TIMEOUT: SqlState = SqlState(Inner::E25P03); + + /// 26000 + pub const INVALID_SQL_STATEMENT_NAME: SqlState = SqlState(Inner::E26000); + + /// 26000 + pub const UNDEFINED_PSTATEMENT: SqlState = SqlState(Inner::E26000); + + /// 27000 + pub const TRIGGERED_DATA_CHANGE_VIOLATION: SqlState = SqlState(Inner::E27000); + + /// 28000 + pub const INVALID_AUTHORIZATION_SPECIFICATION: SqlState = SqlState(Inner::E28000); + + /// 28P01 + pub const INVALID_PASSWORD: SqlState = SqlState(Inner::E28P01); + + /// 2B000 + pub const DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST: SqlState = SqlState(Inner::E2B000); + + /// 2BP01 + pub const DEPENDENT_OBJECTS_STILL_EXIST: SqlState = SqlState(Inner::E2BP01); + + /// 2D000 + pub const INVALID_TRANSACTION_TERMINATION: SqlState = SqlState(Inner::E2D000); + + /// 2F000 + pub const SQL_ROUTINE_EXCEPTION: SqlState = SqlState(Inner::E2F000); + + /// 2F005 + pub const 
S_R_E_FUNCTION_EXECUTED_NO_RETURN_STATEMENT: SqlState = SqlState(Inner::E2F005); + + /// 2F002 + pub const S_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E2F002); + + /// 2F003 + pub const S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED: SqlState = SqlState(Inner::E2F003); + + /// 2F004 + pub const S_R_E_READING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E2F004); + + /// 34000 + pub const INVALID_CURSOR_NAME: SqlState = SqlState(Inner::E34000); + + /// 34000 + pub const UNDEFINED_CURSOR: SqlState = SqlState(Inner::E34000); + + /// 38000 + pub const EXTERNAL_ROUTINE_EXCEPTION: SqlState = SqlState(Inner::E38000); + + /// 38001 + pub const E_R_E_CONTAINING_SQL_NOT_PERMITTED: SqlState = SqlState(Inner::E38001); + + /// 38002 + pub const E_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E38002); + + /// 38003 + pub const E_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED: SqlState = SqlState(Inner::E38003); + + /// 38004 + pub const E_R_E_READING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E38004); + + /// 39000 + pub const EXTERNAL_ROUTINE_INVOCATION_EXCEPTION: SqlState = SqlState(Inner::E39000); + + /// 39001 + pub const E_R_I_E_INVALID_SQLSTATE_RETURNED: SqlState = SqlState(Inner::E39001); + + /// 39004 + pub const E_R_I_E_NULL_VALUE_NOT_ALLOWED: SqlState = SqlState(Inner::E39004); + + /// 39P01 + pub const E_R_I_E_TRIGGER_PROTOCOL_VIOLATED: SqlState = SqlState(Inner::E39P01); + + /// 39P02 + pub const E_R_I_E_SRF_PROTOCOL_VIOLATED: SqlState = SqlState(Inner::E39P02); + + /// 39P03 + pub const E_R_I_E_EVENT_TRIGGER_PROTOCOL_VIOLATED: SqlState = SqlState(Inner::E39P03); + + /// 3B000 + pub const SAVEPOINT_EXCEPTION: SqlState = SqlState(Inner::E3B000); + + /// 3B001 + pub const S_E_INVALID_SPECIFICATION: SqlState = SqlState(Inner::E3B001); + + /// 3D000 + pub const INVALID_CATALOG_NAME: SqlState = SqlState(Inner::E3D000); + + /// 3D000 + pub const UNDEFINED_DATABASE: SqlState = SqlState(Inner::E3D000); + + /// 3F000 + pub const 
INVALID_SCHEMA_NAME: SqlState = SqlState(Inner::E3F000); + + /// 3F000 + pub const UNDEFINED_SCHEMA: SqlState = SqlState(Inner::E3F000); + + /// 40000 + pub const TRANSACTION_ROLLBACK: SqlState = SqlState(Inner::E40000); + + /// 40002 + pub const T_R_INTEGRITY_CONSTRAINT_VIOLATION: SqlState = SqlState(Inner::E40002); + + /// 40001 + pub const T_R_SERIALIZATION_FAILURE: SqlState = SqlState(Inner::E40001); + + /// 40003 + pub const T_R_STATEMENT_COMPLETION_UNKNOWN: SqlState = SqlState(Inner::E40003); + + /// 40P01 + pub const T_R_DEADLOCK_DETECTED: SqlState = SqlState(Inner::E40P01); + + /// 42000 + pub const SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION: SqlState = SqlState(Inner::E42000); + + /// 42601 + pub const SYNTAX_ERROR: SqlState = SqlState(Inner::E42601); + + /// 42501 + pub const INSUFFICIENT_PRIVILEGE: SqlState = SqlState(Inner::E42501); + + /// 42846 + pub const CANNOT_COERCE: SqlState = SqlState(Inner::E42846); + + /// 42803 + pub const GROUPING_ERROR: SqlState = SqlState(Inner::E42803); + + /// 42P20 + pub const WINDOWING_ERROR: SqlState = SqlState(Inner::E42P20); + + /// 42P19 + pub const INVALID_RECURSION: SqlState = SqlState(Inner::E42P19); + + /// 42830 + pub const INVALID_FOREIGN_KEY: SqlState = SqlState(Inner::E42830); + + /// 42602 + pub const INVALID_NAME: SqlState = SqlState(Inner::E42602); + + /// 42622 + pub const NAME_TOO_LONG: SqlState = SqlState(Inner::E42622); + + /// 42939 + pub const RESERVED_NAME: SqlState = SqlState(Inner::E42939); + + /// 42804 + pub const DATATYPE_MISMATCH: SqlState = SqlState(Inner::E42804); + + /// 42P18 + pub const INDETERMINATE_DATATYPE: SqlState = SqlState(Inner::E42P18); + + /// 42P21 + pub const COLLATION_MISMATCH: SqlState = SqlState(Inner::E42P21); + + /// 42P22 + pub const INDETERMINATE_COLLATION: SqlState = SqlState(Inner::E42P22); + + /// 42809 + pub const WRONG_OBJECT_TYPE: SqlState = SqlState(Inner::E42809); + + /// 428C9 + pub const GENERATED_ALWAYS: SqlState = SqlState(Inner::E428C9); + + /// 42703 + pub 
const UNDEFINED_COLUMN: SqlState = SqlState(Inner::E42703); + + /// 42883 + pub const UNDEFINED_FUNCTION: SqlState = SqlState(Inner::E42883); + + /// 42P01 + pub const UNDEFINED_TABLE: SqlState = SqlState(Inner::E42P01); + + /// 42P02 + pub const UNDEFINED_PARAMETER: SqlState = SqlState(Inner::E42P02); + + /// 42704 + pub const UNDEFINED_OBJECT: SqlState = SqlState(Inner::E42704); + + /// 42701 + pub const DUPLICATE_COLUMN: SqlState = SqlState(Inner::E42701); + + /// 42P03 + pub const DUPLICATE_CURSOR: SqlState = SqlState(Inner::E42P03); + + /// 42P04 + pub const DUPLICATE_DATABASE: SqlState = SqlState(Inner::E42P04); + + /// 42723 + pub const DUPLICATE_FUNCTION: SqlState = SqlState(Inner::E42723); + + /// 42P05 + pub const DUPLICATE_PSTATEMENT: SqlState = SqlState(Inner::E42P05); + + /// 42P06 + pub const DUPLICATE_SCHEMA: SqlState = SqlState(Inner::E42P06); + + /// 42P07 + pub const DUPLICATE_TABLE: SqlState = SqlState(Inner::E42P07); + + /// 42712 + pub const DUPLICATE_ALIAS: SqlState = SqlState(Inner::E42712); + + /// 42710 + pub const DUPLICATE_OBJECT: SqlState = SqlState(Inner::E42710); + + /// 42702 + pub const AMBIGUOUS_COLUMN: SqlState = SqlState(Inner::E42702); + + /// 42725 + pub const AMBIGUOUS_FUNCTION: SqlState = SqlState(Inner::E42725); + + /// 42P08 + pub const AMBIGUOUS_PARAMETER: SqlState = SqlState(Inner::E42P08); + + /// 42P09 + pub const AMBIGUOUS_ALIAS: SqlState = SqlState(Inner::E42P09); + + /// 42P10 + pub const INVALID_COLUMN_REFERENCE: SqlState = SqlState(Inner::E42P10); + + /// 42611 + pub const INVALID_COLUMN_DEFINITION: SqlState = SqlState(Inner::E42611); + + /// 42P11 + pub const INVALID_CURSOR_DEFINITION: SqlState = SqlState(Inner::E42P11); + + /// 42P12 + pub const INVALID_DATABASE_DEFINITION: SqlState = SqlState(Inner::E42P12); + + /// 42P13 + pub const INVALID_FUNCTION_DEFINITION: SqlState = SqlState(Inner::E42P13); + + /// 42P14 + pub const INVALID_PSTATEMENT_DEFINITION: SqlState = SqlState(Inner::E42P14); + + /// 42P15 + pub 
const INVALID_SCHEMA_DEFINITION: SqlState = SqlState(Inner::E42P15); + + /// 42P16 + pub const INVALID_TABLE_DEFINITION: SqlState = SqlState(Inner::E42P16); + + /// 42P17 + pub const INVALID_OBJECT_DEFINITION: SqlState = SqlState(Inner::E42P17); + + /// 44000 + pub const WITH_CHECK_OPTION_VIOLATION: SqlState = SqlState(Inner::E44000); + + /// 53000 + pub const INSUFFICIENT_RESOURCES: SqlState = SqlState(Inner::E53000); + + /// 53100 + pub const DISK_FULL: SqlState = SqlState(Inner::E53100); + + /// 53200 + pub const OUT_OF_MEMORY: SqlState = SqlState(Inner::E53200); + + /// 53300 + pub const TOO_MANY_CONNECTIONS: SqlState = SqlState(Inner::E53300); + + /// 53400 + pub const CONFIGURATION_LIMIT_EXCEEDED: SqlState = SqlState(Inner::E53400); + + /// 54000 + pub const PROGRAM_LIMIT_EXCEEDED: SqlState = SqlState(Inner::E54000); + + /// 54001 + pub const STATEMENT_TOO_COMPLEX: SqlState = SqlState(Inner::E54001); + + /// 54011 + pub const TOO_MANY_COLUMNS: SqlState = SqlState(Inner::E54011); + + /// 54023 + pub const TOO_MANY_ARGUMENTS: SqlState = SqlState(Inner::E54023); + + /// 55000 + pub const OBJECT_NOT_IN_PREREQUISITE_STATE: SqlState = SqlState(Inner::E55000); + + /// 55006 + pub const OBJECT_IN_USE: SqlState = SqlState(Inner::E55006); + + /// 55P02 + pub const CANT_CHANGE_RUNTIME_PARAM: SqlState = SqlState(Inner::E55P02); + + /// 55P03 + pub const LOCK_NOT_AVAILABLE: SqlState = SqlState(Inner::E55P03); + + /// 55P04 + pub const UNSAFE_NEW_ENUM_VALUE_USAGE: SqlState = SqlState(Inner::E55P04); + + /// 57000 + pub const OPERATOR_INTERVENTION: SqlState = SqlState(Inner::E57000); + + /// 57014 + pub const QUERY_CANCELED: SqlState = SqlState(Inner::E57014); + + /// 57P01 + pub const ADMIN_SHUTDOWN: SqlState = SqlState(Inner::E57P01); + + /// 57P02 + pub const CRASH_SHUTDOWN: SqlState = SqlState(Inner::E57P02); + + /// 57P03 + pub const CANNOT_CONNECT_NOW: SqlState = SqlState(Inner::E57P03); + + /// 57P04 + pub const DATABASE_DROPPED: SqlState = SqlState(Inner::E57P04); + 
+ /// 57P05 + pub const IDLE_SESSION_TIMEOUT: SqlState = SqlState(Inner::E57P05); + + /// 58000 + pub const SYSTEM_ERROR: SqlState = SqlState(Inner::E58000); + + /// 58030 + pub const IO_ERROR: SqlState = SqlState(Inner::E58030); + + /// 58P01 + pub const UNDEFINED_FILE: SqlState = SqlState(Inner::E58P01); + + /// 58P02 + pub const DUPLICATE_FILE: SqlState = SqlState(Inner::E58P02); + + /// 72000 + pub const SNAPSHOT_TOO_OLD: SqlState = SqlState(Inner::E72000); + + /// F0000 + pub const CONFIG_FILE_ERROR: SqlState = SqlState(Inner::EF0000); + + /// F0001 + pub const LOCK_FILE_EXISTS: SqlState = SqlState(Inner::EF0001); + + /// HV000 + pub const FDW_ERROR: SqlState = SqlState(Inner::EHV000); + + /// HV005 + pub const FDW_COLUMN_NAME_NOT_FOUND: SqlState = SqlState(Inner::EHV005); + + /// HV002 + pub const FDW_DYNAMIC_PARAMETER_VALUE_NEEDED: SqlState = SqlState(Inner::EHV002); + + /// HV010 + pub const FDW_FUNCTION_SEQUENCE_ERROR: SqlState = SqlState(Inner::EHV010); + + /// HV021 + pub const FDW_INCONSISTENT_DESCRIPTOR_INFORMATION: SqlState = SqlState(Inner::EHV021); + + /// HV024 + pub const FDW_INVALID_ATTRIBUTE_VALUE: SqlState = SqlState(Inner::EHV024); + + /// HV007 + pub const FDW_INVALID_COLUMN_NAME: SqlState = SqlState(Inner::EHV007); + + /// HV008 + pub const FDW_INVALID_COLUMN_NUMBER: SqlState = SqlState(Inner::EHV008); + + /// HV004 + pub const FDW_INVALID_DATA_TYPE: SqlState = SqlState(Inner::EHV004); + + /// HV006 + pub const FDW_INVALID_DATA_TYPE_DESCRIPTORS: SqlState = SqlState(Inner::EHV006); + + /// HV091 + pub const FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER: SqlState = SqlState(Inner::EHV091); + + /// HV00B + pub const FDW_INVALID_HANDLE: SqlState = SqlState(Inner::EHV00B); + + /// HV00C + pub const FDW_INVALID_OPTION_INDEX: SqlState = SqlState(Inner::EHV00C); + + /// HV00D + pub const FDW_INVALID_OPTION_NAME: SqlState = SqlState(Inner::EHV00D); + + /// HV090 + pub const FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH: SqlState = SqlState(Inner::EHV090); + + 
/// HV00A + pub const FDW_INVALID_STRING_FORMAT: SqlState = SqlState(Inner::EHV00A); + + /// HV009 + pub const FDW_INVALID_USE_OF_NULL_POINTER: SqlState = SqlState(Inner::EHV009); + + /// HV014 + pub const FDW_TOO_MANY_HANDLES: SqlState = SqlState(Inner::EHV014); + + /// HV001 + pub const FDW_OUT_OF_MEMORY: SqlState = SqlState(Inner::EHV001); + + /// HV00P + pub const FDW_NO_SCHEMAS: SqlState = SqlState(Inner::EHV00P); + + /// HV00J + pub const FDW_OPTION_NAME_NOT_FOUND: SqlState = SqlState(Inner::EHV00J); + + /// HV00K + pub const FDW_REPLY_HANDLE: SqlState = SqlState(Inner::EHV00K); + + /// HV00Q + pub const FDW_SCHEMA_NOT_FOUND: SqlState = SqlState(Inner::EHV00Q); + + /// HV00R + pub const FDW_TABLE_NOT_FOUND: SqlState = SqlState(Inner::EHV00R); + + /// HV00L + pub const FDW_UNABLE_TO_CREATE_EXECUTION: SqlState = SqlState(Inner::EHV00L); + + /// HV00M + pub const FDW_UNABLE_TO_CREATE_REPLY: SqlState = SqlState(Inner::EHV00M); + + /// HV00N + pub const FDW_UNABLE_TO_ESTABLISH_CONNECTION: SqlState = SqlState(Inner::EHV00N); + + /// P0000 + pub const PLPGSQL_ERROR: SqlState = SqlState(Inner::EP0000); + + /// P0001 + pub const RAISE_EXCEPTION: SqlState = SqlState(Inner::EP0001); + + /// P0002 + pub const NO_DATA_FOUND: SqlState = SqlState(Inner::EP0002); + + /// P0003 + pub const TOO_MANY_ROWS: SqlState = SqlState(Inner::EP0003); + + /// P0004 + pub const ASSERT_FAILURE: SqlState = SqlState(Inner::EP0004); + + /// XX000 + pub const INTERNAL_ERROR: SqlState = SqlState(Inner::EXX000); + + /// XX001 + pub const DATA_CORRUPTED: SqlState = SqlState(Inner::EXX001); + + /// XX002 + pub const INDEX_CORRUPTED: SqlState = SqlState(Inner::EXX002); +} + +#[derive(PartialEq, Eq, Clone, Debug)] +#[allow(clippy::upper_case_acronyms)] +enum Inner { + E00000, + E01000, + E0100C, + E01008, + E01003, + E01007, + E01006, + E01004, + E01P01, + E02000, + E02001, + E03000, + E08000, + E08003, + E08006, + E08001, + E08004, + E08007, + E08P01, + E09000, + E0A000, + E0B000, + E0F000, + 
E0F001, + E0L000, + E0LP01, + E0P000, + E0Z000, + E0Z002, + E20000, + E21000, + E22000, + E2202E, + E22021, + E22008, + E22012, + E22005, + E2200B, + E22022, + E22015, + E2201E, + E22014, + E22016, + E2201F, + E2201G, + E22018, + E22007, + E22019, + E2200D, + E22025, + E22P06, + E22010, + E22023, + E22013, + E2201B, + E2201W, + E2201X, + E2202H, + E2202G, + E22009, + E2200C, + E2200G, + E22004, + E22002, + E22003, + E2200H, + E22026, + E22001, + E22011, + E22027, + E22024, + E2200F, + E22P01, + E22P02, + E22P03, + E22P04, + E22P05, + E2200L, + E2200M, + E2200N, + E2200S, + E2200T, + E22030, + E22031, + E22032, + E22033, + E22034, + E22035, + E22036, + E22037, + E22038, + E22039, + E2203A, + E2203B, + E2203C, + E2203D, + E2203E, + E2203F, + E2203G, + E23000, + E23001, + E23502, + E23503, + E23505, + E23514, + E23P01, + E24000, + E25000, + E25001, + E25002, + E25008, + E25003, + E25004, + E25005, + E25006, + E25007, + E25P01, + E25P02, + E25P03, + E26000, + E27000, + E28000, + E28P01, + E2B000, + E2BP01, + E2D000, + E2F000, + E2F005, + E2F002, + E2F003, + E2F004, + E34000, + E38000, + E38001, + E38002, + E38003, + E38004, + E39000, + E39001, + E39004, + E39P01, + E39P02, + E39P03, + E3B000, + E3B001, + E3D000, + E3F000, + E40000, + E40002, + E40001, + E40003, + E40P01, + E42000, + E42601, + E42501, + E42846, + E42803, + E42P20, + E42P19, + E42830, + E42602, + E42622, + E42939, + E42804, + E42P18, + E42P21, + E42P22, + E42809, + E428C9, + E42703, + E42883, + E42P01, + E42P02, + E42704, + E42701, + E42P03, + E42P04, + E42723, + E42P05, + E42P06, + E42P07, + E42712, + E42710, + E42702, + E42725, + E42P08, + E42P09, + E42P10, + E42611, + E42P11, + E42P12, + E42P13, + E42P14, + E42P15, + E42P16, + E42P17, + E44000, + E53000, + E53100, + E53200, + E53300, + E53400, + E54000, + E54001, + E54011, + E54023, + E55000, + E55006, + E55P02, + E55P03, + E55P04, + E57000, + E57014, + E57P01, + E57P02, + E57P03, + E57P04, + E57P05, + E58000, + E58030, + E58P01, + E58P02, + E72000, + 
EF0000, + EF0001, + EHV000, + EHV005, + EHV002, + EHV010, + EHV021, + EHV024, + EHV007, + EHV008, + EHV004, + EHV006, + EHV091, + EHV00B, + EHV00C, + EHV00D, + EHV090, + EHV00A, + EHV009, + EHV014, + EHV001, + EHV00P, + EHV00J, + EHV00K, + EHV00Q, + EHV00R, + EHV00L, + EHV00M, + EHV00N, + EP0000, + EP0001, + EP0002, + EP0003, + EP0004, + EXX000, + EXX001, + EXX002, +} diff --git a/crates/corro-pg/src/vtab/mod.rs b/crates/corro-pg/src/vtab/mod.rs new file mode 100644 index 00000000..d5ce8e87 --- /dev/null +++ b/crates/corro-pg/src/vtab/mod.rs @@ -0,0 +1,2 @@ +pub mod pg_range; +pub mod pg_type; diff --git a/crates/corro-pg/src/vtab/pg_range.rs b/crates/corro-pg/src/vtab/pg_range.rs new file mode 100644 index 00000000..cc702a42 --- /dev/null +++ b/crates/corro-pg/src/vtab/pg_range.rs @@ -0,0 +1,92 @@ +use std::{marker::PhantomData, os::raw::c_int}; + +use rusqlite::vtab::{ + sqlite3_vtab, sqlite3_vtab_cursor, IndexInfo, VTab, VTabConnection, VTabCursor, Values, +}; + +#[repr(C)] +pub struct PgRangeTable { + /// Base class. Must be first + base: sqlite3_vtab, +} + +unsafe impl<'vtab> VTab<'vtab> for PgRangeTable { + type Aux = (); + type Cursor = PgRangeTableCursor<'vtab>; + + fn connect( + _: &mut VTabConnection, + _aux: Option<&()>, + args: &[&[u8]], + ) -> rusqlite::Result<(String, PgRangeTable)> { + let vtab = PgRangeTable { + base: sqlite3_vtab::default(), + }; + + for arg in args { + println!("arg {:?}", std::str::from_utf8(arg)); + } + + Ok(( + "CREATE TABLE pg_range ( + rngtypid INTEGER, + rngsubtype INTEGER, + rngmultitypid INTEGER, + rngcollation INTEGER, + rngsubopc INTEGER, + rngcanonical TEXT, + rngsubdiff TEXT + )" + .into(), + vtab, + )) + } + + fn best_index(&self, info: &mut IndexInfo) -> rusqlite::Result<()> { + info.set_estimated_cost(1.); + Ok(()) + } + + fn open(&'vtab mut self) -> rusqlite::Result> { + Ok(PgRangeTableCursor::default()) + } +} + +#[derive(Default)] +#[repr(C)] +pub struct PgRangeTableCursor<'vtab> { + /// Base class. 
Must be first + base: sqlite3_vtab_cursor, + /// The rowid + row_id: i64, + phantom: PhantomData<&'vtab PgRangeTable>, +} + +unsafe impl VTabCursor for PgRangeTableCursor<'_> { + fn filter( + &mut self, + _idx_num: c_int, + _idx_str: Option<&str>, + _args: &Values<'_>, + ) -> rusqlite::Result<()> { + self.row_id = 1; + Ok(()) + } + + fn next(&mut self) -> rusqlite::Result<()> { + self.row_id += 1; + Ok(()) + } + + fn eof(&self) -> bool { + true // no rows... + } + + fn column(&self, _ctx: &mut rusqlite::vtab::Context, _col: c_int) -> rusqlite::Result<()> { + Ok(()) + } + + fn rowid(&self) -> rusqlite::Result { + Ok(self.row_id) + } +} diff --git a/crates/corro-pg/src/vtab/pg_type.rs b/crates/corro-pg/src/vtab/pg_type.rs new file mode 100644 index 00000000..9bbecee7 --- /dev/null +++ b/crates/corro-pg/src/vtab/pg_type.rs @@ -0,0 +1,324 @@ +use std::{marker::PhantomData, os::raw::c_int}; + +use postgres_types::Type; +use rusqlite::vtab::{ + sqlite3_vtab, sqlite3_vtab_cursor, IndexInfo, VTab, VTabConnection, VTabCursor, Values, +}; + +#[repr(C)] +pub struct PgTypeTable { + /// Base class. 
Must be first + base: sqlite3_vtab, +} + +unsafe impl<'vtab> VTab<'vtab> for PgTypeTable { + type Aux = (); + type Cursor = PgTypeTableCursor<'vtab>; + + fn connect( + _: &mut VTabConnection, + _aux: Option<&()>, + args: &[&[u8]], + ) -> rusqlite::Result<(String, PgTypeTable)> { + let vtab = PgTypeTable { + base: sqlite3_vtab::default(), + }; + + for arg in args { + println!("arg {:?}", std::str::from_utf8(arg)); + } + + Ok(( + "CREATE TABLE pg_type ( + oid INTEGER, + typname TEXT, + typnamespace INTEGER, + typowner INTEGER, + typlen INTEGER, + typbyval INTEGER, + typtype TEXT, + typcategory TEXT, + typispreferred INTEGER, + typisdefined INTEGER, + typdelim TEXT, + typrelid INTEGER, + typelem INTEGER, + typarray INTEGER, + typinput TEXT, + typoutput TEXT, + typreceive TEXT, + typsend TEXT, + typmodin TEXT, + typmodout TEXT, + typanalyze TEXT, + typalign TEXT, + typstorage TEXT, + typnotnull INTEGER, + typbasetype INTEGER, + typtypmod INTEGER, + typndims INTEGER, + typcollation INTEGER, + typdefaultbin TEXT, + typdefault TEXT, + typacl TEXT + )" + .into(), + vtab, + )) + } + + fn best_index(&self, info: &mut IndexInfo) -> rusqlite::Result<()> { + info.set_estimated_cost(1.); + Ok(()) + } + + fn open(&'vtab mut self) -> rusqlite::Result> { + Ok(PgTypeTableCursor::default()) + } +} + +#[derive(Default)] +#[repr(C)] +pub struct PgTypeTableCursor<'vtab> { + /// Base class. 
Must be first + base: sqlite3_vtab_cursor, + /// The rowid + row_id: i64, + phantom: PhantomData<&'vtab PgTypeTable>, +} + +struct PgType(Type); + +impl PgType { + fn oid(&self) -> u32 { + self.0.oid() + } + + fn typname(&self) -> &str { + self.0.name() + } + + fn typnamespace(&self) -> &'static str { + "11" + } + fn typowner(&self) -> &'static str { + "10" + } + fn typlen(&self) -> i16 { + match self.0 { + Type::BOOL => 1, + Type::BYTEA => -1, + Type::INT2 => 2, + Type::INT4 => 4, + Type::INT8 => 8, + Type::TEXT => -1, + Type::VARCHAR => -1, + Type::FLOAT4 => 4, + Type::FLOAT8 => 8, + _ => { + // TODO: not default... + Default::default() + } + } + } + fn typbyval(&self) -> bool { + match self.0 { + Type::BOOL => true, + Type::BYTEA => false, + Type::INT2 => true, + Type::INT4 => true, + Type::INT8 => true, + Type::TEXT => false, + Type::VARCHAR => false, + Type::FLOAT4 => true, + Type::FLOAT8 => true, + _ => { + // TODO: not default... + Default::default() + } + } + } + fn typtype(&self) -> &'static str { + "b" + } + fn typcategory(&self) -> &'static str { + match self.0 { + Type::BOOL => "B", + Type::BYTEA => "U", + Type::INT2 => "N", + Type::INT4 => "N", + Type::INT8 => "N", + Type::TEXT => "S", + Type::VARCHAR => "S", + Type::FLOAT4 => "N", + Type::FLOAT8 => "N", + _ => { + // TODO: not default... + Default::default() + } + } + } + fn typispreferred(&self) -> bool { + // TODO: not default... + Default::default() + } + fn typisdefined(&self) -> bool { + true + } + fn typdelim(&self) -> &'static str { + // TODO: not default... + Default::default() + } + fn typrelid(&self) -> i64 { + 0 + } + fn typelem(&self) -> &'static str { + "0" + } + fn typarray(&self) -> &'static str { + // TODO: not default... 
+ Default::default() + } + fn typinput(&self) -> String { + format!("{}in", self.0.name()) + } + fn typoutput(&self) -> String { + format!("{}out", self.0.name()) + } + fn typreceive(&self) -> String { + format!("{}recv", self.0.name()) + } + fn typsend(&self) -> String { + format!("{}send", self.0.name()) + } + fn typmodin(&self) -> &'static str { + // TODO: not default... + Default::default() + } + fn typmodout(&self) -> &'static str { + // TODO: not default... + Default::default() + } + fn typanalyze(&self) -> &'static str { + "-" + } + fn typalign(&self) -> &'static str { + // TODO: not default... + Default::default() + } + fn typstorage(&self) -> &'static str { + // TODO: not default... + Default::default() + } + fn typnotnull(&self) -> bool { + false + } + fn typbasetype(&self) -> &'static str { + "0" + } + fn typtypmod(&self) -> i32 { + -1 + } + fn typndims(&self) -> i32 { + 0 + } + fn typcollation(&self) -> &'static str { + // TODO: not default... + Default::default() + } + fn typdefaultbin(&self) -> rusqlite::types::Null { + rusqlite::types::Null + } + fn typdefault(&self) -> Option<&'static str> { + None + } + fn typacl(&self) -> rusqlite::types::Null { + rusqlite::types::Null + } +} + +const PG_TYPES: &[PgType] = &[ + // TINY INT + PgType(Type::BOOL), + // BLOB + PgType(Type::BYTEA), + // INTS + PgType(Type::INT2), + PgType(Type::INT4), + PgType(Type::INT8), + // TEXT + PgType(Type::TEXT), + PgType(Type::VARCHAR), + // REAL + PgType(Type::FLOAT4), + PgType(Type::FLOAT8), +]; + +unsafe impl VTabCursor for PgTypeTableCursor<'_> { + fn filter( + &mut self, + _idx_num: c_int, + _idx_str: Option<&str>, + _args: &Values<'_>, + ) -> rusqlite::Result<()> { + self.row_id = 0; + Ok(()) + } + + fn next(&mut self) -> rusqlite::Result<()> { + self.row_id += 1; + Ok(()) + } + + fn eof(&self) -> bool { + self.row_id >= PG_TYPES.len() as i64 + } + + fn column(&self, ctx: &mut rusqlite::vtab::Context, col: c_int) -> rusqlite::Result<()> { + if let Some(pg_type) = 
PG_TYPES.get(self.row_id as usize) { + match col { + 0 => ctx.set_result(&pg_type.oid()), + 1 => ctx.set_result(&pg_type.typname()), // pg_type.typname + 2 => ctx.set_result(&pg_type.typnamespace()), // pg_type.typnamespace + 3 => ctx.set_result(&pg_type.typowner()), // pg_type.typowner + 4 => ctx.set_result(&pg_type.typlen()), // pg_type.typlen + 5 => ctx.set_result(&pg_type.typbyval()), // pg_type.typbyval + 6 => ctx.set_result(&pg_type.typtype()), // pg_type.typtype + 7 => ctx.set_result(&pg_type.typcategory()), // pg_type.typcategory + 8 => ctx.set_result(&pg_type.typispreferred()), // pg_type.typispreferred + 9 => ctx.set_result(&pg_type.typisdefined()), // pg_type.typisdefined + 10 => ctx.set_result(&pg_type.typdelim()), // pg_type.typdelim + 11 => ctx.set_result(&pg_type.typrelid()), // pg_type.typrelid + 12 => ctx.set_result(&pg_type.typelem()), // pg_type.typelem + 13 => ctx.set_result(&pg_type.typarray()), // pg_type.typarray + 14 => ctx.set_result(&pg_type.typinput()), // pg_type.typinput + 15 => ctx.set_result(&pg_type.typoutput()), // pg_type.typoutput + 16 => ctx.set_result(&pg_type.typreceive()), // pg_type.typreceive + 17 => ctx.set_result(&pg_type.typsend()), // pg_type.typsend + 18 => ctx.set_result(&pg_type.typmodin()), // pg_type.typmodin + 19 => ctx.set_result(&pg_type.typmodout()), // pg_type.typmodout + 20 => ctx.set_result(&pg_type.typanalyze()), // pg_type.typanalyze + 21 => ctx.set_result(&pg_type.typalign()), // pg_type.typalign + 22 => ctx.set_result(&pg_type.typstorage()), // pg_type.typstorage + 23 => ctx.set_result(&pg_type.typnotnull()), // pg_type.typnotnull + 24 => ctx.set_result(&pg_type.typbasetype()), // pg_type.typbasetype + 25 => ctx.set_result(&pg_type.typtypmod()), // pg_type.typtypmod + 26 => ctx.set_result(&pg_type.typndims()), // pg_type.typndims + 27 => ctx.set_result(&pg_type.typcollation()), // pg_type.typcollation + 28 => ctx.set_result(&pg_type.typdefaultbin()), // pg_type.typdefaultbin + 29 => 
ctx.set_result(&pg_type.typdefault()), // pg_type.typdefault + 30 => ctx.set_result(&pg_type.typacl()), // pg_type.typacl + _ => Err(rusqlite::Error::InvalidColumnIndex(col as usize)), + } + } else { + Err(rusqlite::Error::ModuleError(format!( + "pg type out of bound (row id: {})", + self.row_id + ))) + } + } + + fn rowid(&self) -> rusqlite::Result { + Ok(self.row_id) + } +} diff --git a/crates/corro-types/src/agent.rs b/crates/corro-types/src/agent.rs index 5ff857e2..16545bf6 100644 --- a/crates/corro-types/src/agent.rs +++ b/crates/corro-types/src/agent.rs @@ -20,6 +20,10 @@ use parking_lot::RwLock; use rangemap::RangeInclusiveSet; use rusqlite::{Connection, InterruptHandle}; use serde::{Deserialize, Serialize}; +use tokio::sync::{ + OwnedRwLockWriteGuard as OwnedTokioRwLockWriteGuard, RwLock as TokioRwLock, + RwLockReadGuard as TokioRwLockReadGuard, RwLockWriteGuard as TokioRwLockWriteGuard, +}; use tokio::{ runtime::Handle, sync::{ @@ -27,15 +31,8 @@ use tokio::{ oneshot, Semaphore, }, }; -use tokio::{ - sync::{ - OwnedRwLockWriteGuard as OwnedTokioRwLockWriteGuard, RwLock as TokioRwLock, - RwLockReadGuard as TokioRwLockReadGuard, RwLockWriteGuard as TokioRwLockWriteGuard, - }, - task::block_in_place, -}; use tokio_util::sync::{CancellationToken, DropGuard}; -use tracing::{debug, error, info, Instrument}; +use tracing::{debug, error, info, trace, Instrument}; use tripwire::Tripwire; use crate::{ @@ -43,7 +40,7 @@ use crate::{ broadcast::{BroadcastInput, ChangeSource, ChangeV1, FocaInput, Timestamp}, config::Config, pubsub::MatcherHandle, - schema::NormalizedSchema, + schema::Schema, sqlite::{rusqlite_to_crsqlite, setup_conn, AttachMap, CrConn, SqlitePool, SqlitePoolError}, }; @@ -70,7 +67,7 @@ pub struct AgentConfig { pub tx_changes: Sender<(ChangeV1, ChangeSource)>, pub tx_foca: Sender, - pub schema: RwLock, + pub schema: RwLock, pub tripwire: Tripwire, } @@ -89,7 +86,7 @@ pub struct AgentInner { tx_empty: Sender<(ActorId, RangeInclusive)>, tx_changes: 
Sender<(ChangeV1, ChangeSource)>, tx_foca: Sender, - schema: RwLock, + schema: RwLock, limits: Limits, } @@ -174,7 +171,7 @@ impl Agent { &self.0.members } - pub fn schema(&self) -> &RwLock { + pub fn schema(&self) -> &RwLock { &self.0.schema } @@ -197,6 +194,26 @@ impl Agent { pub fn limits(&self) -> &Limits { &self.0.limits } + + pub fn process_subs_by_db_version(&self, conn: &Connection, db_version: i64) { + trace!("process subs by db version..."); + + let mut matchers_to_delete = vec![]; + + { + let matchers = self.matchers().read(); + for (id, matcher) in matchers.iter() { + if let Err(e) = matcher.process_changes_from_db_version(conn, db_version) { + error!("could not process change w/ matcher {id}, it is probably defunct! {e}"); + matchers_to_delete.push(*id); + } + } + } + + for id in matchers_to_delete { + self.matchers().write().remove(&id); + } + } } #[derive(Debug, Clone)] @@ -363,12 +380,16 @@ impl SplitPool { } #[tracing::instrument(skip(self), level = "debug")] - pub async fn dedicated(&self) -> rusqlite::Result { - block_in_place(|| { - let mut conn = rusqlite::Connection::open(&self.0.path)?; - setup_conn(&mut conn, &self.0.attachments)?; - Ok(conn) - }) + pub fn dedicated(&self) -> rusqlite::Result { + let mut conn = rusqlite::Connection::open(&self.0.path)?; + setup_conn(&mut conn, &self.0.attachments)?; + Ok(conn) + } + + #[tracing::instrument(skip(self), level = "debug")] + pub fn client_dedicated(&self) -> rusqlite::Result { + let conn = rusqlite::Connection::open(&self.0.path)?; + rusqlite_to_crsqlite(conn) } // get a high priority write connection (e.g. 
client input) diff --git a/crates/corro-types/src/config.rs b/crates/corro-types/src/config.rs index ef56afd6..f7586101 100644 --- a/crates/corro-types/src/config.rs +++ b/crates/corro-types/src/config.rs @@ -87,6 +87,14 @@ pub struct ApiConfig { pub bind_addr: SocketAddr, #[serde(alias = "authz", default)] pub authorization: Option, + #[serde(default)] + pub pg: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PgConfig { + #[serde(alias = "addr")] + pub bind_addr: SocketAddr, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -273,6 +281,7 @@ impl ConfigBuilder { api: ApiConfig { bind_addr: self.api_addr.ok_or(ConfigBuilderError::ApiAddrRequired)?, authorization: None, + pg: None, }, gossip: GossipConfig { bind_addr: self diff --git a/crates/corro-types/src/pubsub.rs b/crates/corro-types/src/pubsub.rs index 04b20fe3..176ddec6 100644 --- a/crates/corro-types/src/pubsub.rs +++ b/crates/corro-types/src/pubsub.rs @@ -25,7 +25,7 @@ use uuid::Uuid; use crate::{ api::QueryEvent, - schema::{NormalizedSchema, NormalizedTable}, + schema::{Schema, Table}, sqlite::Migration, }; @@ -356,7 +356,7 @@ const CHANGE_TYPE_COL: &str = "type"; impl Matcher { fn new( id: Uuid, - schema: &NormalizedSchema, + schema: &Schema, conn: &Connection, evt_tx: mpsc::Sender, sql: &str, @@ -530,7 +530,7 @@ impl Matcher { pub fn restore( id: Uuid, - schema: &NormalizedSchema, + schema: &Schema, conn: Connection, evt_tx: mpsc::Sender, sql: &str, @@ -544,7 +544,7 @@ impl Matcher { pub fn create( id: Uuid, - schema: &NormalizedSchema, + schema: &Schema, mut conn: Connection, evt_tx: mpsc::Sender, sql: &str, @@ -1040,10 +1040,7 @@ pub struct ParsedSelect { children: Vec, } -fn extract_select_columns( - select: &Select, - schema: &NormalizedSchema, -) -> Result { +fn extract_select_columns(select: &Select, schema: &Schema) -> Result { let mut parsed = ParsedSelect::default(); if let OneSelect::Select { @@ -1140,7 +1137,7 @@ fn extract_select_columns( fn extract_expr_columns( 
expr: &Expr, - schema: &NormalizedSchema, + schema: &Schema, parsed: &mut ParsedSelect, ) -> Result<(), MatcherError> { match expr { @@ -1318,7 +1315,7 @@ fn extract_expr_columns( fn extract_columns( columns: &[ResultColumn], from: Option<&Name>, - schema: &NormalizedSchema, + schema: &Schema, parsed: &mut ParsedSelect, ) -> Result<(), MatcherError> { let mut i = 0; @@ -1382,7 +1379,7 @@ fn extract_columns( fn table_to_expr( aliases: &HashMap, - tbl: &NormalizedTable, + tbl: &Table, table: &str, id: Uuid, ) -> Result { @@ -1521,7 +1518,7 @@ mod tests { { let tx = conn.transaction()?; - apply_schema(&tx, &NormalizedSchema::default(), &mut schema)?; + apply_schema(&tx, &Schema::default(), &mut schema)?; tx.commit()?; } @@ -1653,7 +1650,7 @@ mod tests { { let tx = conn.transaction().unwrap(); - apply_schema(&tx, &NormalizedSchema::default(), &mut schema).unwrap(); + apply_schema(&tx, &Schema::default(), &mut schema).unwrap(); tx.commit().unwrap(); } @@ -1695,7 +1692,7 @@ mod tests { { let tx = conn2.transaction().unwrap(); - apply_schema(&tx, &NormalizedSchema::default(), &mut schema).unwrap(); + apply_schema(&tx, &Schema::default(), &mut schema).unwrap(); tx.commit().unwrap(); } diff --git a/crates/corro-types/src/schema.rs b/crates/corro-types/src/schema.rs index 5b4745ba..3352c826 100644 --- a/crates/corro-types/src/schema.rs +++ b/crates/corro-types/src/schema.rs @@ -12,10 +12,10 @@ use sqlite3_parser::ast::{ Cmd, ColumnConstraint, ColumnDefinition, CreateTableBody, Expr, Name, NamedTableConstraint, QualifiedName, SortedColumn, Stmt, TableConstraint, TableOptions, ToTokens, }; -use tracing::{debug, info}; +use tracing::{debug, info, trace}; #[derive(Debug, Clone, Eq, PartialEq)] -pub struct NormalizedColumn { +pub struct Column { pub name: String, pub sql_type: SqliteType, pub nullable: bool, @@ -25,7 +25,7 @@ pub struct NormalizedColumn { pub raw: ColumnDefinition, } -impl std::hash::Hash for NormalizedColumn { +impl std::hash::Hash for Column { fn hash(&self, 
state: &mut H) { self.name.hash(state); self.sql_type.hash(state); @@ -36,7 +36,7 @@ impl std::hash::Hash for NormalizedColumn { } } -impl fmt::Display for NormalizedColumn { +impl fmt::Display for Column { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.raw.to_fmt(f) } @@ -44,7 +44,7 @@ impl fmt::Display for NormalizedColumn { /// SQLite data types. /// See [Fundamental Datatypes](https://sqlite.org/c3ref/c_blob.html). -#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum SqliteType { /// NULL @@ -60,15 +60,15 @@ pub enum SqliteType { } #[derive(Debug, Clone)] -pub struct NormalizedTable { +pub struct Table { pub name: String, pub pk: IndexSet, - pub columns: IndexMap, - pub indexes: IndexMap, + pub columns: IndexMap, + pub indexes: IndexMap, pub raw: CreateTableBody, } -impl fmt::Display for NormalizedTable { +impl fmt::Display for Table { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { Cmd::Stmt(Stmt::CreateTable { temporary: false, @@ -81,16 +81,78 @@ impl fmt::Display for NormalizedTable { } #[derive(Debug, Clone, Eq, PartialEq)] -pub struct NormalizedIndex { +pub struct Index { pub name: String, pub tbl_name: String, pub columns: Vec, pub where_clause: Option, + pub unique: bool, } #[derive(Debug, Clone, Default)] -pub struct NormalizedSchema { - pub tables: IndexMap, +pub struct Schema { + pub tables: IndexMap, +} + +impl Schema { + pub fn constrain(&mut self) -> Result<(), ConstrainedSchemaError> { + self.tables.retain(|name, _table| { + !(name.contains("crsql") && name.contains("sqlite") && name.starts_with("__corro")) + }); + + for (tbl_name, table) in self.tables.iter() { + // this should always be the case... 
+ if let CreateTableBody::ColumnsAndConstraints { + columns: _, + constraints, + options: _, + } = &table.raw + { + if let Some(constraints) = constraints { + for named in constraints.iter() { + if let TableConstraint::PrimaryKey { columns, .. } = &named.constraint { + for column in columns.iter() { + if !matches!(column.expr, Expr::Id(_)) { + return Err(ConstrainedSchemaError::PrimaryKeyExpr); + } + } + } + } + } + } else { + // error here! + } + + for (name, column) in table.columns.iter() { + if !column.primary_key && !column.nullable && column.default_value.is_none() { + return Err(ConstrainedSchemaError::NotNullableColumnNeedsDefault { + tbl_name: tbl_name.clone(), + name: name.clone(), + }); + } + + if column + .raw + .constraints + .iter() + .any(|named| matches!(named.constraint, ColumnConstraint::ForeignKey { .. })) + { + return Err(ConstrainedSchemaError::ForeignKey { + tbl_name: tbl_name.clone(), + name: name.clone(), + }); + } + } + + for (name, index) in table.indexes.iter() { + if index.unique { + return Err(ConstrainedSchemaError::UniqueIndex(name.clone())); + } + } + } + + Ok(()) + } } #[derive(Debug, thiserror::Error)] @@ -103,51 +165,30 @@ pub enum SchemaError { Parse(#[from] sqlite3_parser::lexer::sql::Error), #[error("nothing to parse")] NothingParsed, - #[error("unsupported statement: {0}")] + #[error("unsupported command: {0}")] UnsupportedCmd(Cmd), - #[error("unique indexes are not supported: {0}")] - UniqueIndex(Cmd), + #[error("missing table for index (table: '{tbl_name}', index: '{name}')")] + IndexWithoutTable { tbl_name: String, name: String }, #[error("temporary tables are not supported: {0}")] TemporaryTable(Cmd), +} + +#[derive(Debug, thiserror::Error)] +pub enum ConstrainedSchemaError { + #[error("unique indexes are not supported: {0}")] + UniqueIndex(String), #[error("table as select arenot supported: {0}")] TableAsSelect(Cmd), #[error("not nullable column '{name}' on table '{tbl_name}' needs a default value for forward schema 
compatibility")] NotNullableColumnNeedsDefault { tbl_name: String, name: String }, #[error("foreign keys are not supported (table: '{tbl_name}', column: '{name}')")] ForeignKey { tbl_name: String, name: String }, - #[error("missing table for index (table: '{tbl_name}', index: '{name}')")] - IndexWithoutTable { tbl_name: String, name: String }, #[error("expr used as primary")] PrimaryKeyExpr, - #[error("won't drop table without the destructive flag set (table: '{0}')")] - DropTableWithoutDestructiveFlag(String), - #[error("won't drop table without the destructive flag set (table: '{0}', column: '{1}')")] - DropColumnWithoutDestructiveFlag(String, String), - #[error("can't add a primary key (table: '{0}', column: '{1}')")] - AddPrimaryKey(String, String), - #[error("can't modify primary keys (table: '{0}')")] - ModifyPrimaryKeys(String), - - #[error("tried importing an existing schema for table '{0}' due to a failed CREATE TABLE but didn't find anything (this should never happen)")] - ImportedSchemaNotFound(String), - - #[error("existing schema for table '{tbl_name}' primary keys mismatched, expected: {expected:?}, got: {got:?}")] - ImportedSchemaPkMismatch { - tbl_name: String, - expected: IndexSet, - got: IndexSet, - }, - - #[error("existing schema for table '{tbl_name}' columns mismatched, expected: {expected:?}, got: {got:?}")] - ImportedSchemaColumnsMismatch { - tbl_name: String, - expected: IndexMap, - got: IndexMap, - }, } #[allow(clippy::result_large_err)] -pub fn init_schema(conn: &Connection) -> Result { +pub fn init_schema(conn: &Connection) -> Result { let mut dump = String::new(); let tables: HashMap = conn @@ -177,12 +218,47 @@ pub fn init_schema(conn: &Connection) -> Result { parse_sql(dump.as_str()) } +#[derive(Debug, thiserror::Error)] +pub enum ApplySchemaError { + #[error(transparent)] + Rusqlite(#[from] rusqlite::Error), + #[error(transparent)] + Schema(#[from] SchemaError), + #[error(transparent)] + ConstrainedSchema(#[from] 
ConstrainedSchemaError), + #[error("won't drop table without the destructive flag set (table: '{0}')")] + DropTableWithoutDestructiveFlag(String), + #[error("won't drop table without the destructive flag set (table: '{0}', column: '{1}')")] + DropColumnWithoutDestructiveFlag(String, String), + #[error("can't add a primary key (table: '{0}', column: '{1}')")] + AddPrimaryKey(String, String), + #[error("can't modify primary keys (table: '{0}')")] + ModifyPrimaryKeys(String), + + #[error("tried importing an existing schema for table '{0}' due to a failed CREATE TABLE but didn't find anything (this should never happen)")] + ImportedSchemaNotFound(String), + + #[error("existing schema for table '{tbl_name}' primary keys mismatched, expected: {expected:?}, got: {got:?}")] + ImportedSchemaPkMismatch { + tbl_name: String, + expected: IndexSet, + got: IndexSet, + }, + + #[error("existing schema for table '{tbl_name}' columns mismatched, expected: {expected:?}, got: {got:?}")] + ImportedSchemaColumnsMismatch { + tbl_name: String, + expected: IndexMap, + got: IndexMap, + }, +} + #[allow(clippy::result_large_err)] pub fn apply_schema( tx: &Transaction, - schema: &NormalizedSchema, - new_schema: &mut NormalizedSchema, -) -> Result<(), SchemaError> { + schema: &Schema, + new_schema: &mut Schema, +) -> Result<(), ApplySchemaError> { if let Some(name) = schema .tables .keys() @@ -191,12 +267,12 @@ pub fn apply_schema( .next() { // TODO: add options and check flag - return Err(SchemaError::DropTableWithoutDestructiveFlag( + return Err(ApplySchemaError::DropTableWithoutDestructiveFlag( (*name).clone(), )); } - let mut schema_to_merge = NormalizedSchema::default(); + let mut schema_to_merge = Schema::default(); { let new_table_names = new_schema @@ -245,10 +321,10 @@ pub fn apply_schema( let parsed_table = parse_sql(&sql)? 
.tables .remove(name) - .ok_or_else(|| SchemaError::ImportedSchemaNotFound(name.clone()))?; + .ok_or_else(|| ApplySchemaError::ImportedSchemaNotFound(name.clone()))?; if parsed_table.pk != table.pk { - return Err(SchemaError::ImportedSchemaPkMismatch { + return Err(ApplySchemaError::ImportedSchemaPkMismatch { tbl_name: name.clone(), expected: table.pk.clone(), got: parsed_table.pk, @@ -256,7 +332,7 @@ pub fn apply_schema( } if parsed_table.columns != table.columns { - return Err(SchemaError::ImportedSchemaColumnsMismatch { + return Err(ApplySchemaError::ImportedSchemaColumnsMismatch { tbl_name: name.clone(), expected: table.columns.clone(), got: parsed_table.columns, @@ -327,7 +403,7 @@ pub fn apply_schema( debug!("dropped cols: {dropped_cols:?}"); if let Some(col_name) = dropped_cols.into_iter().next() { - return Err(SchemaError::DropColumnWithoutDestructiveFlag( + return Err(ApplySchemaError::DropColumnWithoutDestructiveFlag( name.clone(), col_name.clone(), )); @@ -335,7 +411,7 @@ pub fn apply_schema( // 2. 
check for changed columns - let changed_cols: HashMap = table + let changed_cols: HashMap = table .columns .iter() .filter_map(|(name, col)| { @@ -379,13 +455,17 @@ pub fn apply_schema( for (col_name, col) in new_cols_iter { info!("adding column '{col_name}'"); if col.primary_key { - return Err(SchemaError::AddPrimaryKey(name.clone(), col_name.clone())); + return Err(ApplySchemaError::AddPrimaryKey( + name.clone(), + col_name.clone(), + )); } if !col.nullable && col.default_value.is_none() { - return Err(SchemaError::NotNullableColumnNeedsDefault { + return Err(ConstrainedSchemaError::NotNullableColumnNeedsDefault { tbl_name: name.clone(), name: col_name.clone(), - }); + } + .into()); } tx.execute_batch(&format!("ALTER TABLE {name} ADD COLUMN {}", col))?; } @@ -415,7 +495,7 @@ pub fn apply_schema( .collect::>(); if primary_keys != new_primary_keys { - return Err(SchemaError::ModifyPrimaryKeys(name.clone())); + return Err(ApplySchemaError::ModifyPrimaryKeys(name.clone())); } // "12-step" process to modifying a table @@ -537,8 +617,8 @@ pub fn apply_schema( } #[allow(clippy::result_large_err)] -pub fn parse_sql_to_schema(schema: &mut NormalizedSchema, sql: &str) -> Result<(), SchemaError> { - debug!("parsing {sql}"); +pub fn parse_sql_to_schema(schema: &mut Schema, sql: &str) -> Result<(), SchemaError> { + trace!("parsing {sql}"); let mut parser = sqlite3_parser::lexer::sql::Parser::new(sql.as_bytes()); loop { @@ -549,9 +629,6 @@ pub fn parse_sql_to_schema(schema: &mut NormalizedSchema, sql: &str) -> Result<( return Err(err.into()); } Ok(Some(ref cmd @ Cmd::Stmt(ref stmt))) => match stmt { - Stmt::CreateIndex { unique: true, .. } => { - return Err(SchemaError::UniqueIndex(cmd.clone())) - } Stmt::CreateTable { temporary: true, .. 
} => return Err(SchemaError::TemporaryTable(cmd.clone())), @@ -570,29 +647,31 @@ pub fn parse_sql_to_schema(schema: &mut NormalizedSchema, sql: &str) -> Result<( options, }, } => { - if let Some(table) = - prepare_table(tbl_name, columns, constraints.as_ref(), options)? - { - schema.tables.insert(tbl_name.name.0.clone(), table); - debug!("inserted table: {}", tbl_name.name.0); - } else { - debug!("skipped table: {}", tbl_name.name.0); - } + schema.tables.insert( + tbl_name.name.0.clone(), + prepare_table(tbl_name, columns, constraints.as_ref(), options), + ); + trace!("inserted table: {}", tbl_name.name.0); } Stmt::CreateIndex { - unique: false, - if_not_exists: _, + unique, idx_name, tbl_name, columns, where_clause, + .. } => { if let Some(table) = schema.tables.get_mut(tbl_name.0.as_str()) { - if let Some(index) = - prepare_index(idx_name, tbl_name, columns, where_clause.as_ref())? - { - table.indexes.insert(idx_name.name.0.clone(), index); - } + table.indexes.insert( + idx_name.name.0.clone(), + Index { + name: idx_name.name.0.clone(), + tbl_name: tbl_name.0.clone(), + columns: columns.to_vec(), + where_clause: where_clause.clone(), + unique: *unique, + }, + ); } else { return Err(SchemaError::IndexWithoutTable { tbl_name: tbl_name.0.clone(), @@ -610,53 +689,21 @@ pub fn parse_sql_to_schema(schema: &mut NormalizedSchema, sql: &str) -> Result<( } #[allow(clippy::result_large_err)] -pub fn parse_sql(sql: &str) -> Result { - let mut schema = NormalizedSchema::default(); +pub fn parse_sql(sql: &str) -> Result { + let mut schema = Schema::default(); parse_sql_to_schema(&mut schema, sql)?; Ok(schema) } -#[allow(clippy::result_large_err)] -fn prepare_index( - name: &QualifiedName, - tbl_name: &Name, - columns: &[SortedColumn], - where_clause: Option<&Expr>, -) -> Result, SchemaError> { - debug!("preparing index: {}", name.name.0); - if tbl_name.0.contains("crsql") - & tbl_name.0.contains("sqlite") - & tbl_name.0.starts_with("__corro") - { - return Ok(None); - } - - 
Ok(Some(NormalizedIndex { - name: name.name.0.clone(), - tbl_name: tbl_name.0.clone(), - columns: columns.to_vec(), - where_clause: where_clause.cloned(), - })) -} - #[allow(clippy::result_large_err)] fn prepare_table( tbl_name: &QualifiedName, columns: &[ColumnDefinition], constraints: Option<&Vec>, options: &TableOptions, -) -> Result, SchemaError> { - debug!("preparing table: {}", tbl_name.name.0); - if tbl_name.name.0.contains("crsql") - & tbl_name.name.0.contains("sqlite") - & tbl_name.name.0.starts_with("__corro") - { - debug!("skipping table because of name"); - return Ok(None); - } - +) -> Table { let pk = constraints .and_then(|constraints| { constraints @@ -665,36 +712,34 @@ fn prepare_table( TableConstraint::PrimaryKey { columns, .. } => Some( columns .iter() - .map(|col| match &col.expr { - Expr::Id(id) => Ok(id.0.clone()), - _ => Err(SchemaError::PrimaryKeyExpr), + .filter_map(|col| match &col.expr { + Expr::Id(id) => Some(id.0.clone()), + _ => None, }) - .collect::, SchemaError>>(), + .collect::>(), ), _ => None, }) }) .unwrap_or_else(|| { - Ok(columns + columns .iter() - .filter_map(|def| { - def.constraints - .iter() - .any(|named| { - matches!(named.constraint, ColumnConstraint::PrimaryKey { .. }) - }) - .then(|| def.col_name.0.clone()) + .filter(|&def| { + def.constraints.iter().any(|named| { + matches!(named.constraint, ColumnConstraint::PrimaryKey { .. 
}) + }) }) - .collect()) - })?; + .map(|def| def.col_name.0.clone()) + .collect() + }); - Ok(Some(NormalizedTable { + Table { name: tbl_name.name.0.clone(), indexes: IndexMap::new(), columns: columns .iter() .map(|def| { - debug!("visiting column: {}", def.col_name.0); + trace!("visiting column: {}", def.col_name.0); let default_value = def.constraints.iter().find_map(|named| { if let ColumnConstraint::Default(ref expr) = named.constraint { Some(expr.to_string()) @@ -716,27 +761,9 @@ fn prepare_table( let primary_key = pk.contains(&def.col_name.0); - if !primary_key && (!nullable && default_value.is_none()) { - return Err(SchemaError::NotNullableColumnNeedsDefault { - tbl_name: tbl_name.name.0.clone(), - name: def.col_name.0.clone(), - }); - } - - if def - .constraints - .iter() - .any(|named| matches!(named.constraint, ColumnConstraint::ForeignKey { .. })) - { - return Err(SchemaError::ForeignKey { - tbl_name: tbl_name.name.0.clone(), - name: def.col_name.0.clone(), - }); - } - - Ok(( + ( def.col_name.0.clone(), - NormalizedColumn { + Column { name: def.col_name.0.clone(), sql_type: match def .col_type @@ -785,14 +812,14 @@ fn prepare_table( }), raw: def.clone(), }, - )) + ) }) - .collect::, SchemaError>>()?, + .collect::>(), pk, raw: CreateTableBody::ColumnsAndConstraints { columns: columns.to_vec(), constraints: constraints.cloned(), options: *options, }, - })) + } } diff --git a/doc/SUMMARY.md b/doc/SUMMARY.md index 9a49bd6e..74bac6bf 100644 --- a/doc/SUMMARY.md +++ b/doc/SUMMARY.md @@ -26,6 +26,7 @@ - [POST /v1/transactions](api/transactions.md) - [POST /v1/queries](api/queries.md) - [POST /v1/subscriptions](api/subscriptions.md) + - [PostgreSQL Wire Protocol](api/pg.md) - [Command-line Interface](cli/README.md) - [agent](cli/agent.md) - [backup](cli/backup.md) diff --git a/doc/api/pg.md b/doc/api/pg.md new file mode 100644 index 00000000..6646c657 --- /dev/null +++ b/doc/api/pg.md @@ -0,0 +1,15 @@ +# PostgreSQL Wire Protocol v3 API (experimental) + +It's 
possible to configure a PostgreSQL wire protocol compatible API listener via the `api.pg.addr` setting. + +This is currently experimental, but it does work for most queries that are SQLite-flavored SQL. + +## What works + +- Read and write queries, parsable as SQLite-flavored SQL +- Most parameter bindings, but not all (work in progress) + +## Does not work + +- Any PostgreSQL-only SQL syntax +- Some placement of variable parameters (when binding) \ No newline at end of file