From 645f6c0f8eae94d111c933a43d2c303c6113e7d7 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Fri, 11 Jul 2025 17:43:30 +0200 Subject: [PATCH 1/9] refactor: drop change.rs <3 --- .claude/settings.local.json | 3 +- Cargo.lock | 146 +- Cargo.toml | 1 + .../execute/process_file/workspace_file.rs | 18 +- crates/pgt_lsp/src/capabilities.rs | 18 +- crates/pgt_lsp/src/handlers/text_document.rs | 37 +- crates/pgt_statement_splitter/Cargo.toml | 7 +- .../benches/splitter.rs | 80 + crates/pgt_workspace/Cargo.toml | 1 + .../src/features/code_actions.rs | 2 +- .../pgt_workspace/src/features/completions.rs | 10 +- crates/pgt_workspace/src/workspace.rs | 27 +- crates/pgt_workspace/src/workspace/server.rs | 59 +- .../src/workspace/server/annotation.rs | 8 - .../src/workspace/server/change.rs | 1648 ----------------- .../src/workspace/server/document.rs | 380 +++- .../src/workspace/server/parsed_document.rs | 442 ----- .../src/workspace/server/pg_query.rs | 11 +- .../workspace/server/statement_identifier.rs | 128 +- .../src/workspace/server/tree_sitter.rs | 131 +- 20 files changed, 670 insertions(+), 2487 deletions(-) create mode 100644 crates/pgt_statement_splitter/benches/splitter.rs delete mode 100644 crates/pgt_workspace/src/workspace/server/change.rs delete mode 100644 crates/pgt_workspace/src/workspace/server/parsed_document.rs diff --git a/.claude/settings.local.json b/.claude/settings.local.json index 85429d0c6..591b91192 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -6,7 +6,8 @@ "Bash(cargo test:*)", "Bash(cargo run:*)", "Bash(cargo check:*)", - "Bash(cargo fmt:*)" + "Bash(cargo fmt:*)", + "Bash(cargo doc:*)" ], "deny": [] } diff --git a/Cargo.lock b/Cargo.lock index 074ed19b0..db1c361d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -290,6 +290,17 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi 0.1.19", + "libc", + "winapi", +] + [[package]] name = "auto_impl" version = "1.2.0" @@ -808,7 +819,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" dependencies = [ "ciborium-io", - "half", + "half 2.6.0", ] [[package]] @@ -822,6 +833,17 @@ dependencies = [ "libloading", ] +[[package]] +name = "clap" +version = "2.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" +dependencies = [ + "bitflags 1.3.2", + "textwrap", + "unicode-width", +] + [[package]] name = "clap" version = "4.5.23" @@ -943,6 +965,32 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b01d6de93b2b6c65e17c634a26653a29d107b3c98c607c765bf38d041531cd8f" +dependencies = [ + "atty", + "cast", + "clap 2.34.0", + "criterion-plot 0.4.5", + "csv", + "itertools 0.10.5", + "lazy_static", + "num-traits", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_cbor", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + [[package]] name = "criterion" version = "0.5.1" @@ -952,8 +1000,8 @@ dependencies = [ "anes", "cast", "ciborium", - "clap", - "criterion-plot", + "clap 4.5.23", + "criterion-plot 0.5.0", "is-terminal", "itertools 0.10.5", "num-traits", @@ -969,6 +1017,16 @@ dependencies = [ "walkdir", ] +[[package]] +name = "criterion-plot" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "criterion-plot" version = "0.5.0" @@ -1051,6 +1109,27 @@ dependencies = [ "typenum", ] +[[package]] +name = "csv" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d" +dependencies = [ + "memchr", +] + [[package]] name = "dashmap" version = "5.5.3" @@ -1375,6 +1454,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -1615,6 +1700,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "half" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" + [[package]] name = "half" version = "2.6.0" @@ -1646,6 +1737,9 @@ name = "hashbrown" version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "foldhash", +] [[package]] name = "hashlink" @@ -1672,6 +1766,15 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + [[package]] name = "hermit-abi" version = "0.3.9" @@ -2643,7 +2746,7 @@ name = "pgt_completions" version = "0.0.0" dependencies = [ "async-std", - "criterion", + "criterion 0.5.1", "fuzzy-matcher", "pgt_schema_cache", "pgt_test_utils", @@ -2859,6 +2962,7 @@ dependencies = [ name = "pgt_statement_splitter" version = "0.0.0" dependencies = [ + "criterion 0.3.6", "ntest", "pgt_diagnostics", "pgt_lexer", @@ -2883,7 +2987,7 @@ name = "pgt_test_utils" version = "0.0.0" dependencies = [ "anyhow", - "clap", + "clap 4.5.23", "dotenv", "sqlx", "tree-sitter", @@ -2922,7 +3026,7 @@ dependencies = [ name = "pgt_treesitter_queries" version = "0.0.0" dependencies = [ - "clap", + "clap 4.5.23", "tree-sitter", "tree_sitter_sql", ] @@ -2984,6 +3088,7 @@ dependencies = [ "serde_json", "slotmap", "sqlx", + "string-interner", "strum", "tempfile", "tokio", @@ -3718,6 +3823,16 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde_cbor" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" +dependencies = [ + "half 1.8.3", + "serde", +] + [[package]] name = "serde_derive" version = "1.0.215" @@ -4152,6 +4267,16 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "string-interner" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23de088478b31c349c9ba67816fa55d9355232d63c3afea8bf513e31f0f1d2c0" +dependencies = [ + "hashbrown 0.15.2", + "serde", +] + [[package]] name = "stringprep" version = "0.1.5" @@ -4305,6 +4430,15 @@ dependencies = [ "syn 2.0.90", ] +[[package]] +name = "textwrap" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" +dependencies = [ + "unicode-width", +] + [[package]] name = "thiserror" version = "1.0.69" diff --git a/Cargo.toml b/Cargo.toml index b5d6dd01f..0ecd4dd06 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,7 @@ serde_json = "1.0.114" similar = "2.6.0" slotmap = "1.0.7" smallvec = { version = "1.13.2", features = ["union", "const_new", "serde"] } +string-interner = "0.19.0" strum = { version = "0.27.1", features = ["derive"] } # this will use tokio if available, otherwise async-std convert_case = "0.6.0" diff --git a/crates/pgt_cli/src/execute/process_file/workspace_file.rs b/crates/pgt_cli/src/execute/process_file/workspace_file.rs index 790176b90..9f78c7cf1 100644 --- a/crates/pgt_cli/src/execute/process_file/workspace_file.rs +++ b/crates/pgt_cli/src/execute/process_file/workspace_file.rs @@ -2,13 +2,14 @@ use crate::execute::diagnostics::{ResultExt, ResultIoExt}; use crate::execute::process_file::SharedTraversalOptions; use pgt_diagnostics::{Error, category}; use pgt_fs::{File, OpenOptions, PgTPath}; -use pgt_workspace::workspace::{ChangeParams, FileGuard, OpenFileParams}; +use pgt_workspace::workspace::{FileGuard, OpenFileParams}; use pgt_workspace::{Workspace, WorkspaceError}; use std::path::{Path, PathBuf}; /// Small wrapper that holds information and operations around the current processed file pub(crate) struct WorkspaceFile<'ctx, 'app> { guard: FileGuard<'app, dyn Workspace + 'ctx>, + #[allow(dead_code)] file: Box, pub(crate) path: PathBuf, } @@ -57,19 +58,4 @@ impl<'ctx, 'app> WorkspaceFile<'ctx, 'app> { pub(crate) fn input(&self) -> Result { self.guard().get_file_content() } - - /// It updates the workspace file with `new_content` - #[allow(dead_code)] - pub(crate) fn update_file(&mut self, new_content: impl Into) -> Result<(), Error> { - let new_content = new_content.into(); - - self.file - .set_content(new_content.as_bytes()) - .with_file_path(self.path.display().to_string())?; - self.guard.change_file( - self.file.file_version(), - vec![ChangeParams::overwrite(new_content)], - )?; - Ok(()) - } } diff --git a/crates/pgt_lsp/src/capabilities.rs b/crates/pgt_lsp/src/capabilities.rs index acfc60edc..b50f2753f 100644 --- a/crates/pgt_lsp/src/capabilities.rs +++ b/crates/pgt_lsp/src/capabilities.rs @@ -1,14 +1,10 @@ use crate::adapters::{PositionEncoding, WideEncoding, negotiated_encoding}; -use pgt_workspace::features::code_actions::CommandActionCategory; -use strum::IntoEnumIterator; use tower_lsp::lsp_types::{ ClientCapabilities, CompletionOptions, ExecuteCommandOptions, PositionEncodingKind, SaveOptions, ServerCapabilities, TextDocumentSyncCapability, TextDocumentSyncKind, TextDocumentSyncOptions, TextDocumentSyncSaveOptions, WorkDoneProgressOptions, }; -use crate::handlers::code_actions::command_id; - /// The capabilities to send from server as part of [`InitializeResult`] /// /// [`InitializeResult`]: lspower::lsp::InitializeResult @@ -51,9 +47,7 @@ pub(crate) fn server_capabilities(capabilities: &ClientCapabilities) -> ServerCa }, }), execute_command_provider: Some(ExecuteCommandOptions { - commands: CommandActionCategory::iter() - .map(|c| command_id(&c)) - .collect::>(), + commands: available_command_ids(), ..Default::default() }), @@ -67,3 +61,13 @@ pub(crate) fn server_capabilities(capabilities: &ClientCapabilities) -> ServerCa ..Default::default() } } + +/// Returns all available command IDs for capability registration. +/// Since CommandActionCategory has variants with data, we can't use strum::IntoEnumIterator. +/// Instead, we manually list the command IDs we want to register. +fn available_command_ids() -> Vec { + vec![ + "postgres-tools.executeStatement".to_string(), + // Add other command IDs here as needed + ] +} diff --git a/crates/pgt_lsp/src/handlers/text_document.rs b/crates/pgt_lsp/src/handlers/text_document.rs index 63250ef5a..318d95724 100644 --- a/crates/pgt_lsp/src/handlers/text_document.rs +++ b/crates/pgt_lsp/src/handlers/text_document.rs @@ -1,10 +1,9 @@ -use crate::adapters::from_lsp; use crate::{ diagnostics::LspError, documents::Document, session::Session, utils::apply_document_changes, }; use anyhow::Result; use pgt_workspace::workspace::{ - ChangeFileParams, ChangeParams, CloseFileParams, GetFileContentParams, OpenFileParams, + ChangeFileParams, CloseFileParams, GetFileContentParams, OpenFileParams, }; use tower_lsp::lsp_types; use tracing::error; @@ -46,42 +45,30 @@ pub(crate) async fn did_change( let url = params.text_document.uri; let version = params.text_document.version; - let pgt_path = session.file_path(&url)?; + let biome_path = session.file_path(&url)?; - let old_doc = session.document(&url)?; let old_text = session.workspace.get_file_content(GetFileContentParams { - path: pgt_path.clone(), + path: biome_path.clone(), })?; - - let start = params - .content_changes - .iter() - .rev() - .position(|change| change.range.is_none()) - .map_or(0, |idx| params.content_changes.len() - idx - 1); + tracing::trace!("old document: {:?}", old_text); + tracing::trace!("content changes: {:?}", params.content_changes); let text = apply_document_changes( session.position_encoding(), old_text, - ¶ms.content_changes[start..], + ¶ms.content_changes, ); + tracing::trace!("new document: {:?}", text); + + session.insert_document(url.clone(), Document::new(version, &text)); + session.workspace.change_file(ChangeFileParams { - path: pgt_path, + path: biome_path, version, - changes: params.content_changes[start..] - .iter() - .map(|c| ChangeParams { - range: c.range.and_then(|r| { - from_lsp::text_range(&old_doc.line_index, r, session.position_encoding()).ok() - }), - text: c.text.clone(), - }) - .collect(), + content: text, })?; - session.insert_document(url.clone(), Document::new(version, &text)); - if let Err(err) = session.update_diagnostics(url).await { error!("Failed to update diagnostics: {}", err); } diff --git a/crates/pgt_statement_splitter/Cargo.toml b/crates/pgt_statement_splitter/Cargo.toml index deea07bb1..bdd892a60 100644 --- a/crates/pgt_statement_splitter/Cargo.toml +++ b/crates/pgt_statement_splitter/Cargo.toml @@ -19,4 +19,9 @@ pgt_text_size.workspace = true regex.workspace = true [dev-dependencies] -ntest = "0.9.3" +criterion = "0.3" +ntest = "0.9.3" + +[[bench]] +harness = false +name = "splitter" diff --git a/crates/pgt_statement_splitter/benches/splitter.rs b/crates/pgt_statement_splitter/benches/splitter.rs new file mode 100644 index 000000000..4a1cd7738 --- /dev/null +++ b/crates/pgt_statement_splitter/benches/splitter.rs @@ -0,0 +1,80 @@ +use criterion::{Criterion, black_box, criterion_group, criterion_main}; +use pgt_statement_splitter::split; + +pub fn splitter_benchmark(c: &mut Criterion) { + c.bench_function("large statement", |b| { + let statement = r#"with + available_tables as ( + select + c.relname as table_name, + c.oid as table_oid, + c.relkind as class_kind, + n.nspname as schema_name + from + pg_catalog.pg_class c + join pg_catalog.pg_namespace n on n.oid = c.relnamespace + where + -- r: normal tables + -- v: views + -- m: materialized views + -- f: foreign tables + -- p: partitioned tables + c.relkind in ('r', 'v', 'm', 'f', 'p') + ), + available_indexes as ( + select + unnest (ix.indkey) as attnum, + ix.indisprimary as is_primary, + ix.indisunique as is_unique, + ix.indrelid as table_oid + from + pg_catalog.pg_class c + join pg_catalog.pg_index ix on c.oid = ix.indexrelid + where + c.relkind = 'i' + ) +select + atts.attname as name, + ts.table_name, + ts.table_oid :: int8 as "table_oid!", + ts.class_kind :: char as "class_kind!", + ts.schema_name, + atts.atttypid :: int8 as "type_id!", + not atts.attnotnull as "is_nullable!", + nullif( + information_schema._pg_char_max_length (atts.atttypid, atts.atttypmod), + -1 + ) as varchar_length, + pg_get_expr (def.adbin, def.adrelid) as default_expr, + coalesce(ix.is_primary, false) as "is_primary_key!", + coalesce(ix.is_unique, false) as "is_unique!", + pg_catalog.col_description (ts.table_oid, atts.attnum) as comment +from + pg_catalog.pg_attribute atts + join available_tables ts on atts.attrelid = ts.table_oid + left join available_indexes ix on atts.attrelid = ix.table_oid + and atts.attnum = ix.attnum + left join pg_catalog.pg_attrdef def on atts.attrelid = def.adrelid + and atts.attnum = def.adnum +where + -- system columns, such as `cmax` or `tableoid`, have negative `attnum`s + atts.attnum >= 0; + +"#; + + let content = statement.repeat(500); + + b.iter(|| black_box(split(&content))); + }); + + c.bench_function("small statement", |b| { + let statement = r#"select 1 from public.user where id = 1"#; + + let content = statement.repeat(500); + + b.iter(|| black_box(split(&content))); + }); +} + +criterion_group!(benches, splitter_benchmark); +criterion_main!(benches); diff --git a/crates/pgt_workspace/Cargo.toml b/crates/pgt_workspace/Cargo.toml index bfa413e31..80356819e 100644 --- a/crates/pgt_workspace/Cargo.toml +++ b/crates/pgt_workspace/Cargo.toml @@ -37,6 +37,7 @@ serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true, features = ["raw_value"] } slotmap = { workspace = true, features = ["serde"] } sqlx.workspace = true +string-interner = { workspace = true, features = ["serde"] } strum = { workspace = true } tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } tracing = { workspace = true, features = ["attributes", "log"] } diff --git a/crates/pgt_workspace/src/features/code_actions.rs b/crates/pgt_workspace/src/features/code_actions.rs index 22223dd3c..025153588 100644 --- a/crates/pgt_workspace/src/features/code_actions.rs +++ b/crates/pgt_workspace/src/features/code_actions.rs @@ -44,7 +44,7 @@ pub struct CommandAction { pub category: CommandActionCategory, } -#[derive(Debug, serde::Serialize, serde::Deserialize, strum::EnumIter)] +#[derive(Debug, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] pub enum CommandActionCategory { ExecuteStatement(StatementId), diff --git a/crates/pgt_workspace/src/features/completions.rs b/crates/pgt_workspace/src/features/completions.rs index 53eb9eab0..2803382ef 100644 --- a/crates/pgt_workspace/src/features/completions.rs +++ b/crates/pgt_workspace/src/features/completions.rs @@ -4,7 +4,7 @@ use pgt_completions::CompletionItem; use pgt_fs::PgTPath; use pgt_text_size::{TextRange, TextSize}; -use crate::workspace::{GetCompletionsFilter, GetCompletionsMapper, ParsedDocument, StatementId}; +use crate::workspace::{Document, GetCompletionsFilter, GetCompletionsMapper, StatementId}; #[derive(Debug, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] @@ -30,7 +30,7 @@ impl IntoIterator for CompletionsResult { } pub(crate) fn get_statement_for_completions( - doc: &ParsedDocument, + doc: &Document, position: TextSize, ) -> Option<(StatementId, TextRange, String, Arc)> { let count = doc.count(); @@ -79,13 +79,13 @@ mod tests { use pgt_fs::PgTPath; use pgt_text_size::TextSize; - use crate::workspace::ParsedDocument; + use crate::workspace::Document; use super::get_statement_for_completions; static CURSOR_POSITION: &str = "€"; - fn get_doc_and_pos(sql: &str) -> (ParsedDocument, TextSize) { + fn get_doc_and_pos(sql: &str) -> (Document, TextSize) { let pos = sql .find(CURSOR_POSITION) .expect("Please add cursor position to test sql"); @@ -93,7 +93,7 @@ mod tests { let pos: u32 = pos.try_into().unwrap(); ( - ParsedDocument::new( + Document::new( PgTPath::new("test.sql"), sql.replace(CURSOR_POSITION, ""), 5, diff --git a/crates/pgt_workspace/src/workspace.rs b/crates/pgt_workspace/src/workspace.rs index 61d60a496..9206b39dc 100644 --- a/crates/pgt_workspace/src/workspace.rs +++ b/crates/pgt_workspace/src/workspace.rs @@ -4,7 +4,6 @@ pub use self::client::{TransportRequest, WorkspaceClient, WorkspaceTransport}; use pgt_analyse::RuleCategories; use pgt_configuration::{PartialConfiguration, RuleSelector}; use pgt_fs::PgTPath; -use pgt_text_size::TextRange; #[cfg(feature = "schema")] use schemars::{JsonSchema, SchemaGenerator, schema::Schema}; use serde::{Deserialize, Serialize}; @@ -25,7 +24,7 @@ mod client; mod server; pub use server::StatementId; -pub(crate) use server::parsed_document::*; +pub(crate) use server::document::*; #[derive(Debug, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] @@ -46,21 +45,7 @@ pub struct CloseFileParams { pub struct ChangeFileParams { pub path: PgTPath, pub version: i32, - pub changes: Vec, -} - -#[derive(Debug, serde::Serialize, serde::Deserialize)] -#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] -pub struct ChangeParams { - /// The range of the file that changed. If `None`, the whole file changed. - pub range: Option, - pub text: String, -} - -impl ChangeParams { - pub fn overwrite(text: String) -> Self { - Self { range: None, text } - } + pub content: String, } #[derive(Debug, serde::Serialize, serde::Deserialize)] @@ -205,15 +190,11 @@ impl<'app, W: Workspace + ?Sized> FileGuard<'app, W> { Ok(Self { workspace, path }) } - pub fn change_file( - &self, - version: i32, - changes: Vec, - ) -> Result<(), WorkspaceError> { + pub fn change_file(&self, version: i32, content: String) -> Result<(), WorkspaceError> { self.workspace.change_file(ChangeFileParams { path: self.path.clone(), version, - changes, + content, }) } diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index d0c8d13a6..b2c94145d 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -9,12 +9,11 @@ use analyser::AnalyserVisitorBuilder; use async_helper::run_async; use connection_manager::ConnectionManager; use dashmap::DashMap; -use document::Document; -use futures::{StreamExt, stream}; -use parsed_document::{ - AsyncDiagnosticsMapper, CursorPositionFilter, DefaultMapper, ExecuteStatementMapper, - ParsedDocument, SyncDiagnosticsMapper, +use document::{ + AsyncDiagnosticsMapper, CursorPositionFilter, DefaultMapper, Document, ExecuteStatementMapper, + SyncDiagnosticsMapper, }; +use futures::{StreamExt, stream}; use pgt_analyse::{AnalyserOptions, AnalysisFilter}; use pgt_analyser::{Analyser, AnalyserConfig, AnalyserContext}; use pgt_diagnostics::{ @@ -51,12 +50,10 @@ pub use statement_identifier::StatementId; mod analyser; mod annotation; mod async_helper; -mod change; mod connection_key; mod connection_manager; pub(crate) mod document; mod migration; -pub(crate) mod parsed_document; mod pg_query; mod schema_cache_manager; mod sql_function; @@ -70,7 +67,7 @@ pub(super) struct WorkspaceServer { /// Stores the schema cache for this workspace schema_cache: SchemaCacheManager, - parsed_documents: DashMap, + documents: DashMap, connection: ConnectionManager, } @@ -92,7 +89,7 @@ impl WorkspaceServer { pub(crate) fn new() -> Self { Self { settings: RwLock::default(), - parsed_documents: DashMap::default(), + documents: DashMap::default(), schema_cache: SchemaCacheManager::new(), connection: ConnectionManager::new(), } @@ -265,11 +262,9 @@ impl Workspace for WorkspaceServer { /// Add a new file to the workspace #[tracing::instrument(level = "info", skip_all, fields(path = params.path.as_path().as_os_str().to_str()), err)] fn open_file(&self, params: OpenFileParams) -> Result<(), WorkspaceError> { - self.parsed_documents + self.documents .entry(params.path.clone()) - .or_insert_with(|| { - ParsedDocument::new(params.path.clone(), params.content, params.version) - }); + .or_insert_with(|| Document::new(params.content, params.version)); if let Some(project_key) = self.path_belongs_to_current_workspace(¶ms.path) { self.set_current_project(project_key); @@ -280,7 +275,7 @@ impl Workspace for WorkspaceServer { /// Remove a file from the workspace fn close_file(&self, params: super::CloseFileParams) -> Result<(), WorkspaceError> { - self.parsed_documents + self.documents .remove(¶ms.path) .ok_or_else(WorkspaceError::not_found)?; @@ -293,16 +288,16 @@ impl Workspace for WorkspaceServer { version = params.version ), err)] fn change_file(&self, params: super::ChangeFileParams) -> Result<(), WorkspaceError> { - let mut parser = - self.parsed_documents - .entry(params.path.clone()) - .or_insert(ParsedDocument::new( - params.path.clone(), - "".to_string(), - params.version, - )); - - parser.apply_change(params); + match self.documents.entry(params.path.clone()) { + dashmap::mapref::entry::Entry::Occupied(mut entry) => { + entry + .get_mut() + .update_content(params.content, params.version); + } + dashmap::mapref::entry::Entry::Vacant(entry) => { + entry.insert(Document::new(params.content, params.version)); + } + } Ok(()) } @@ -313,7 +308,7 @@ impl Workspace for WorkspaceServer { fn get_file_content(&self, params: GetFileContentParams) -> Result { let document = self - .parsed_documents + .documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; Ok(document.get_document_content().to_string()) @@ -328,7 +323,7 @@ impl Workspace for WorkspaceServer { params: code_actions::CodeActionsParams, ) -> Result { let parser = self - .parsed_documents + .documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; @@ -370,7 +365,7 @@ impl Workspace for WorkspaceServer { params: ExecuteStatementParams, ) -> Result { let parser = self - .parsed_documents + .documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; @@ -428,7 +423,7 @@ impl Workspace for WorkspaceServer { }; let parser = self - .parsed_documents + .documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; @@ -448,7 +443,7 @@ impl Workspace for WorkspaceServer { // sorry for the ugly code :( let async_results = run_async(async move { stream::iter(input) - .map(|(_id, range, content, ast, cst, sign)| { + .map(|(id, range, ast, cst, sign)| { let pool = pool.clone(); let path = path_clone.clone(); let schema_cache = Arc::clone(&schema_cache); @@ -456,7 +451,7 @@ impl Workspace for WorkspaceServer { if let Some(ast) = ast { pgt_typecheck::check_sql(TypecheckParams { conn: &pool, - sql: &content, + sql: &id.content(), ast: &ast, tree: &cst, schema_cache: schema_cache.as_ref(), @@ -592,7 +587,7 @@ impl Workspace for WorkspaceServer { params: GetCompletionsParams, ) -> Result { let parsed_doc = self - .parsed_documents + .documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; @@ -623,7 +618,7 @@ impl Workspace for WorkspaceServer { tracing::debug!( "Found {} completion items for statement with id {}", items.len(), - id.raw() + id.content() ); Ok(CompletionsResult { items }) diff --git a/crates/pgt_workspace/src/workspace/server/annotation.rs b/crates/pgt_workspace/src/workspace/server/annotation.rs index db6a8b3b2..f46700dbd 100644 --- a/crates/pgt_workspace/src/workspace/server/annotation.rs +++ b/crates/pgt_workspace/src/workspace/server/annotation.rs @@ -55,14 +55,6 @@ impl AnnotationStore { annotations } - - pub fn clear_statement(&self, id: &StatementId) { - self.db.remove(id); - - if let Some(child_id) = id.get_child_id() { - self.db.remove(&child_id); - } - } } #[cfg(test)] diff --git a/crates/pgt_workspace/src/workspace/server/change.rs b/crates/pgt_workspace/src/workspace/server/change.rs deleted file mode 100644 index cc455134c..000000000 --- a/crates/pgt_workspace/src/workspace/server/change.rs +++ /dev/null @@ -1,1648 +0,0 @@ -use pgt_text_size::{TextLen, TextRange, TextSize}; -use std::ops::{Add, Sub}; - -use crate::workspace::{ChangeFileParams, ChangeParams}; - -use super::{Document, document, statement_identifier::StatementId}; - -#[derive(Debug, PartialEq, Eq)] -pub enum StatementChange { - Added(AddedStatement), - Deleted(StatementId), - Modified(ModifiedStatement), -} - -#[derive(Debug, PartialEq, Eq)] -pub struct AddedStatement { - pub stmt: StatementId, - pub text: String, -} - -#[derive(Debug, PartialEq, Eq)] -pub struct ModifiedStatement { - pub old_stmt: StatementId, - pub old_stmt_text: String, - - pub new_stmt: StatementId, - pub new_stmt_text: String, - - pub change_range: TextRange, - pub change_text: String, -} - -impl StatementChange { - #[allow(dead_code)] - pub fn statement(&self) -> &StatementId { - match self { - StatementChange::Added(stmt) => &stmt.stmt, - StatementChange::Deleted(stmt) => stmt, - StatementChange::Modified(changed) => &changed.new_stmt, - } - } -} - -/// Returns all relevant details about the change and its effects on the current state of the document. -struct Affected { - /// Full range of the change, including the range of all statements that intersect with the change - affected_range: TextRange, - /// All indices of affected statement positions - affected_indices: Vec, - /// The index of the first statement position before the change, if any - prev_index: Option, - /// The index of the first statement position after the change, if any - next_index: Option, - /// the full affected range includng the prev and next statement - full_affected_range: TextRange, -} - -impl Document { - /// Applies a file change to the document and returns the affected statements - pub fn apply_file_change(&mut self, change: &ChangeFileParams) -> Vec { - // cleanup all diagnostics with every change because we cannot guarantee that they are still valid - // this is because we know their ranges only by finding slices within the content which is - // very much not guaranteed to result in correct ranges - self.diagnostics.clear(); - - // when we recieive more than one change, we need to push back the changes based on the - // total range of the previous ones. This is because the ranges are always related to the original state. - // BUT: only for the statement range changes, not for the text changes - // this is why we pass both varaints to apply_change - let mut changes = Vec::new(); - - let mut change_indices: Vec = (0..change.changes.len()).collect(); - change_indices.sort_by(|&a, &b| { - match (change.changes[a].range, change.changes[b].range) { - (Some(range_a), Some(range_b)) => range_b.start().cmp(&range_a.start()), - (Some(_), None) => std::cmp::Ordering::Greater, // full changes will never be sent in a batch so this does not matter - (None, Some(_)) => std::cmp::Ordering::Less, - (None, None) => std::cmp::Ordering::Equal, - } - }); - - // Sort changes by start position and process from last to first to avoid position invalidation - for &idx in &change_indices { - changes.extend(self.apply_change(&change.changes[idx])); - } - - self.version = change.version; - - changes - } - - /// Helper method to drain all positions and return them as deleted statements - fn drain_positions(&mut self) -> Vec { - self.positions - .drain(..) - .map(|(id, _)| StatementChange::Deleted(id)) - .collect() - } - - /// Applies a change to the document and returns the affected statements - /// - /// Will always assume its a full change and reparse the whole document - fn apply_full_change(&mut self, change: &ChangeParams) -> Vec { - let mut changes = Vec::new(); - - changes.extend(self.drain_positions()); - - self.content = change.apply_to_text(&self.content); - - let (ranges, diagnostics) = document::split_with_diagnostics(&self.content, None); - - self.diagnostics = diagnostics; - - // Do not add any statements if there is a fatal error - if self.has_fatal_error() { - return changes; - } - - changes.extend(ranges.into_iter().map(|range| { - let id = self.id_generator.next(); - let text = self.content[range].to_string(); - self.positions.push((id.clone(), range)); - - StatementChange::Added(AddedStatement { stmt: id, text }) - })); - - changes - } - - fn insert_statement(&mut self, range: TextRange) -> StatementId { - let pos = self - .positions - .binary_search_by(|(_, r)| r.start().cmp(&range.start())) - .unwrap_err(); - - let new_id = self.id_generator.next(); - self.positions.insert(pos, (new_id.clone(), range)); - - new_id - } - - /// Returns all relevant details about the change and its effects on the current state of the document. - /// - The affected range is the full range of the change, including the range of all statements that intersect with the change - /// - All indices of affected statement positions - /// - The index of the first statement position before the change, if any - /// - The index of the first statement position after the change, if any - /// - the full affected range includng the prev and next statement - fn get_affected( - &self, - change_range: TextRange, - content_size: TextSize, - diff_size: TextSize, - is_addition: bool, - ) -> Affected { - let mut start = change_range.start(); - let mut end = change_range.end().min(content_size); - - let is_trim = change_range.start() >= content_size; - - let mut affected_indices = Vec::new(); - let mut prev_index = None; - let mut next_index = None; - - for (index, (_, pos_range)) in self.positions.iter().enumerate() { - if pos_range.intersect(change_range).is_some() { - affected_indices.push(index); - start = start.min(pos_range.start()); - end = end.max(pos_range.end()); - } else if pos_range.end() <= change_range.start() { - prev_index = Some(index); - } else if pos_range.start() >= change_range.end() && next_index.is_none() { - next_index = Some(index); - break; - } - } - - if affected_indices.is_empty() && prev_index.is_none() { - // if there is no prev_index and no intersection -> use 0 - start = 0.into(); - } - - if affected_indices.is_empty() && next_index.is_none() { - // if there is no next_index and no intersection -> use content_size - end = content_size; - } - - let first_affected_stmt_start = prev_index - .map(|i| self.positions[i].1.start()) - .unwrap_or(start); - - let mut last_affected_stmt_end = next_index - .map(|i| self.positions[i].1.end()) - .unwrap_or_else(|| end); - - if is_addition { - end = end.add(diff_size); - last_affected_stmt_end = last_affected_stmt_end.add(diff_size); - } else if !is_trim { - end = end.sub(diff_size); - last_affected_stmt_end = last_affected_stmt_end.sub(diff_size) - }; - - Affected { - affected_range: { - let end = end.min(content_size); - TextRange::new(start.min(end), end) - }, - affected_indices, - prev_index, - next_index, - full_affected_range: TextRange::new( - first_affected_stmt_start, - last_affected_stmt_end - .min(content_size) - .max(first_affected_stmt_start), - ), - } - } - - fn move_ranges(&mut self, offset: TextSize, diff_size: TextSize, is_addition: bool) { - self.positions - .iter_mut() - .skip_while(|(_, r)| offset > r.start()) - .for_each(|(_, range)| { - let new_range = if is_addition { - range.add(diff_size) - } else { - range.sub(diff_size) - }; - - *range = new_range; - }); - } - - /// Applies a single change to the document and returns the affected statements - /// - /// * `change`: The range-adjusted change to use for statement changes - /// * `original_change`: The original change to use for text changes (yes, this is a bit confusing, and we might want to refactor this entire thing at some point.) - fn apply_change(&mut self, change: &ChangeParams) -> Vec { - // if range is none, we have a full change - if change.range.is_none() { - // doesnt matter what change since range is null - return self.apply_full_change(change); - } - - // i spent a relatively large amount of time thinking about how to handle range changes - // properly. there are quite a few edge cases to consider. I eventually skipped most of - // them, because the complexity is not worth the return for now. we might want to revisit - // this later though. - - let mut changed: Vec = Vec::with_capacity(self.positions.len()); - - let change_range = change.range.unwrap(); - let previous_content = self.content.clone(); - let new_content = change.apply_to_text(&self.content); - - // we first need to determine the affected range and all affected statements, as well as - // the index of the prev and the next statement, if any. The full affected range is the - // affected range expanded to the start of the previous statement and the end of the next - let Affected { - affected_range, - affected_indices, - prev_index, - next_index, - full_affected_range, - } = self.get_affected( - change_range, - new_content.text_len(), - change.diff_size(), - change.is_addition(), - ); - - // if within a statement, we can modify it if the change results in also a single statement - if affected_indices.len() == 1 { - let changed_content = get_affected(&new_content, affected_range); - - let (new_ranges, diags) = - document::split_with_diagnostics(changed_content, Some(affected_range.start())); - - self.diagnostics = diags; - - if self.has_fatal_error() { - // cleanup all positions if there is a fatal error - changed.extend(self.drain_positions()); - // still process text change - self.content = new_content; - return changed; - } - - if new_ranges.len() == 1 { - let affected_idx = affected_indices[0]; - let new_range = new_ranges[0].add(affected_range.start()); - let (old_id, old_range) = self.positions[affected_idx].clone(); - - // move all statements after the affected range - self.move_ranges(old_range.end(), change.diff_size(), change.is_addition()); - - let new_id = self.id_generator.next(); - self.positions[affected_idx] = (new_id.clone(), new_range); - - changed.push(StatementChange::Modified(ModifiedStatement { - old_stmt: old_id.clone(), - old_stmt_text: previous_content[old_range].to_string(), - - new_stmt: new_id, - new_stmt_text: changed_content[new_ranges[0]].to_string(), - // change must be relative to the statement - change_text: change.text.clone(), - // make sure we always have a valid range >= 0 - change_range: change_range - .checked_sub(old_range.start()) - .unwrap_or(change_range.sub(change_range.start())), - })); - - self.content = new_content; - - return changed; - } - } - - // in any other case, parse the full affected range - let changed_content = get_affected(&new_content, full_affected_range); - - let (new_ranges, diags) = - document::split_with_diagnostics(changed_content, Some(full_affected_range.start())); - - self.diagnostics = diags; - - if self.has_fatal_error() { - // cleanup all positions if there is a fatal error - changed.extend(self.drain_positions()); - // still process text change - self.content = new_content; - return changed; - } - - // delete and add new ones - if let Some(next_index) = next_index { - changed.push(StatementChange::Deleted( - self.positions[next_index].0.clone(), - )); - self.positions.remove(next_index); - } - for idx in affected_indices.iter().rev() { - changed.push(StatementChange::Deleted(self.positions[*idx].0.clone())); - self.positions.remove(*idx); - } - if let Some(prev_index) = prev_index { - changed.push(StatementChange::Deleted( - self.positions[prev_index].0.clone(), - )); - self.positions.remove(prev_index); - } - - new_ranges.iter().for_each(|range| { - let actual_range = range.add(full_affected_range.start()); - let new_id = self.insert_statement(actual_range); - changed.push(StatementChange::Added(AddedStatement { - stmt: new_id, - text: new_content[actual_range].to_string(), - })); - }); - - // move all statements after the afffected range - self.move_ranges( - full_affected_range.end(), - change.diff_size(), - change.is_addition(), - ); - - self.content = new_content; - - changed - } -} - -impl ChangeParams { - /// For lack of a better name, this returns the change in size of the text compared to the range - pub fn change_size(&self) -> i64 { - match self.range { - Some(range) => { - let range_length: usize = range.len().into(); - let text_length = self.text.chars().count(); - text_length as i64 - range_length as i64 - } - None => i64::try_from(self.text.chars().count()).unwrap(), - } - } - - pub fn diff_size(&self) -> TextSize { - match self.range { - Some(range) => { - let range_length: usize = range.len().into(); - let text_length = self.text.chars().count(); - let diff = (text_length as i64 - range_length as i64).abs(); - TextSize::from(u32::try_from(diff).unwrap()) - } - None => TextSize::from(u32::try_from(self.text.chars().count()).unwrap()), - } - } - - pub fn is_addition(&self) -> bool { - self.range.is_some() && self.text.len() > self.range.unwrap().len().into() - } - - pub fn is_deletion(&self) -> bool { - self.range.is_some() && self.text.len() < self.range.unwrap().len().into() - } - - pub fn apply_to_text(&self, text: &str) -> String { - if self.range.is_none() { - return self.text.clone(); - } - - let range = self.range.unwrap(); - let start = usize::from(range.start()); - let end = usize::from(range.end()); - - let mut new_text = String::new(); - new_text.push_str(&text[..start]); - new_text.push_str(&self.text); - if end < text.len() { - new_text.push_str(&text[end..]); - } - - new_text - } -} - -fn get_affected(content: &str, range: TextRange) -> &str { - let start_byte = content - .char_indices() - .nth(usize::from(range.start())) - .map(|(i, _)| i) - .unwrap_or(content.len()); - - let end_byte = content - .char_indices() - .nth(usize::from(range.end())) - .map(|(i, _)| i) - .unwrap_or(content.len()); - - &content[start_byte..end_byte] -} - -#[cfg(test)] -mod tests { - use super::*; - use pgt_text_size::TextRange; - - use crate::workspace::{ChangeFileParams, ChangeParams}; - - use pgt_fs::PgTPath; - - impl Document { - pub fn get_text(&self, idx: usize) -> String { - self.content[self.positions[idx].1.start().into()..self.positions[idx].1.end().into()] - .to_string() - } - } - - fn assert_document_integrity(d: &Document) { - let ranges = pgt_statement_splitter::split(&d.content).ranges; - - assert!( - ranges.len() == d.positions.len(), - "should have the correct amount of positions" - ); - - assert!( - ranges - .iter() - .all(|r| { d.positions.iter().any(|(_, stmt_range)| stmt_range == r) }), - "all ranges should be in positions" - ); - } - - #[test] - fn comments_at_begin() { - let path = PgTPath::new("test.sql"); - let input = "\nselect id from users;\n"; - - let mut d = Document::new(input.to_string(), 0); - - let change1 = ChangeFileParams { - path: path.clone(), - version: 1, - changes: vec![ChangeParams { - text: "-".to_string(), - range: Some(TextRange::new(0.into(), 0.into())), - }], - }; - - let _changed1 = d.apply_file_change(&change1); - - assert_eq!(d.content, "-\nselect id from users;\n"); - assert_eq!(d.positions.len(), 2); - - let change2 = ChangeFileParams { - path: path.clone(), - version: 2, - changes: vec![ChangeParams { - text: "-".to_string(), - range: Some(TextRange::new(1.into(), 1.into())), - }], - }; - - let _changed2 = d.apply_file_change(&change2); - - assert_eq!(d.content, "--\nselect id from users;\n"); - assert_eq!(d.positions.len(), 1); - - let change3 = ChangeFileParams { - path: path.clone(), - version: 3, - changes: vec![ChangeParams { - text: " ".to_string(), - range: Some(TextRange::new(2.into(), 2.into())), - }], - }; - - let _changed3 = d.apply_file_change(&change3); - - assert_eq!(d.content, "-- \nselect id from users;\n"); - assert_eq!(d.positions.len(), 1); - - let change4 = ChangeFileParams { - path: path.clone(), - version: 3, - changes: vec![ChangeParams { - text: "t".to_string(), - range: Some(TextRange::new(3.into(), 3.into())), - }], - }; - - let _changed4 = d.apply_file_change(&change4); - - assert_eq!(d.content, "-- t\nselect id from users;\n"); - assert_eq!(d.positions.len(), 1); - - assert_document_integrity(&d); - } - - #[test] - fn typing_comments() { - let path = PgTPath::new("test.sql"); - let input = "select id from users;\n"; - - let mut d = Document::new(input.to_string(), 0); - - let change1 = ChangeFileParams { - path: path.clone(), - version: 1, - changes: vec![ChangeParams { - text: "-".to_string(), - range: Some(TextRange::new(22.into(), 23.into())), - }], - }; - - let _changed1 = d.apply_file_change(&change1); - - assert_eq!(d.content, "select id from users;\n-"); - assert_eq!(d.positions.len(), 2); - - let change2 = ChangeFileParams { - path: path.clone(), - version: 2, - changes: vec![ChangeParams { - text: "-".to_string(), - range: Some(TextRange::new(23.into(), 24.into())), - }], - }; - - let _changed2 = d.apply_file_change(&change2); - - assert_eq!(d.content, "select id from users;\n--"); - assert_eq!(d.positions.len(), 1); - - let change3 = ChangeFileParams { - path: path.clone(), - version: 3, - changes: vec![ChangeParams { - text: " ".to_string(), - range: Some(TextRange::new(24.into(), 25.into())), - }], - }; - - let _changed3 = d.apply_file_change(&change3); - - assert_eq!(d.content, "select id from users;\n-- "); - assert_eq!(d.positions.len(), 1); - - let change4 = ChangeFileParams { - path: path.clone(), - version: 3, - changes: vec![ChangeParams { - text: "t".to_string(), - range: Some(TextRange::new(25.into(), 26.into())), - }], - }; - - let _changed4 = d.apply_file_change(&change4); - - assert_eq!(d.content, "select id from users;\n-- t"); - assert_eq!(d.positions.len(), 1); - - assert_document_integrity(&d); - } - - #[test] - fn within_statements() { - let path = PgTPath::new("test.sql"); - let input = "select id from users;\n\n\n\nselect * from contacts;"; - - let mut d = Document::new(input.to_string(), 0); - - assert_eq!(d.positions.len(), 2); - - let change = ChangeFileParams { - path: path.clone(), - version: 1, - changes: vec![ChangeParams { - text: "select 1;".to_string(), - range: Some(TextRange::new(23.into(), 23.into())), - }], - }; - - let changed = d.apply_file_change(&change); - - assert_eq!(changed.len(), 5); - assert_eq!( - changed - .iter() - .filter(|c| matches!(c, StatementChange::Deleted(_))) - .count(), - 2 - ); - assert_eq!( - changed - .iter() - .filter(|c| matches!(c, StatementChange::Added(_))) - .count(), - 3 - ); - - assert_document_integrity(&d); - } - - #[test] - fn within_statements_2() { - let path = PgTPath::new("test.sql"); - let input = "alter table deal alter column value drop not null;\n"; - let mut d = Document::new(input.to_string(), 0); - - assert_eq!(d.positions.len(), 1); - - let change1 = ChangeFileParams { - path: path.clone(), - version: 1, - changes: vec![ChangeParams { - text: " ".to_string(), - range: Some(TextRange::new(17.into(), 17.into())), - }], - }; - - let changed1 = d.apply_file_change(&change1); - assert_eq!(changed1.len(), 1); - assert_eq!( - d.content, - "alter table deal alter column value drop not null;\n" - ); - assert_document_integrity(&d); - - let change2 = ChangeFileParams { - path: path.clone(), - version: 2, - changes: vec![ChangeParams { - text: " ".to_string(), - range: Some(TextRange::new(18.into(), 18.into())), - }], - }; - - let changed2 = d.apply_file_change(&change2); - assert_eq!(changed2.len(), 1); - assert_eq!( - d.content, - "alter table deal alter column value drop not null;\n" - ); - assert_document_integrity(&d); - - let change3 = ChangeFileParams { - path: path.clone(), - version: 3, - changes: vec![ChangeParams { - text: " ".to_string(), - range: Some(TextRange::new(19.into(), 19.into())), - }], - }; - - let changed3 = d.apply_file_change(&change3); - assert_eq!(changed3.len(), 1); - assert_eq!( - d.content, - "alter table deal alter column value drop not null;\n" - ); - assert_document_integrity(&d); - - let change4 = ChangeFileParams { - path: path.clone(), - version: 4, - changes: vec![ChangeParams { - text: " ".to_string(), - range: Some(TextRange::new(20.into(), 20.into())), - }], - }; - - let changed4 = d.apply_file_change(&change4); - assert_eq!(changed4.len(), 1); - assert_eq!( - d.content, - "alter table deal alter column value drop not null;\n" - ); - assert_document_integrity(&d); - } - - #[test] - fn julians_sample() { - let path = PgTPath::new("test.sql"); - let input = "select\n *\nfrom\n test;\n\nselect\n\nalter table test\n\ndrop column id;"; - let mut d = Document::new(input.to_string(), 0); - - assert_eq!(d.positions.len(), 4); - - let change1 = ChangeFileParams { - path: path.clone(), - version: 1, - changes: vec![ChangeParams { - text: " ".to_string(), - range: Some(TextRange::new(31.into(), 31.into())), - }], - }; - - let changed1 = d.apply_file_change(&change1); - assert_eq!(changed1.len(), 1); - assert_eq!( - d.content, - "select\n *\nfrom\n test;\n\nselect \n\nalter table test\n\ndrop column id;" - ); - assert_document_integrity(&d); - - // problem: this creates a new statement - let change2 = ChangeFileParams { - path: path.clone(), - version: 2, - changes: vec![ChangeParams { - text: ";".to_string(), - range: Some(TextRange::new(32.into(), 32.into())), - }], - }; - - let changed2 = d.apply_file_change(&change2); - assert_eq!(changed2.len(), 4); - assert_eq!( - changed2 - .iter() - .filter(|c| matches!(c, StatementChange::Deleted(_))) - .count(), - 2 - ); - assert_eq!( - changed2 - .iter() - .filter(|c| matches!(c, StatementChange::Added(_))) - .count(), - 2 - ); - assert_document_integrity(&d); - - let change3 = ChangeFileParams { - path: path.clone(), - version: 3, - changes: vec![ChangeParams { - text: "".to_string(), - range: Some(TextRange::new(32.into(), 33.into())), - }], - }; - - let changed3 = d.apply_file_change(&change3); - assert_eq!(changed3.len(), 1); - assert!(matches!(&changed3[0], StatementChange::Modified(_))); - assert_eq!( - d.content, - "select\n *\nfrom\n test;\n\nselect \n\nalter table test\n\ndrop column id;" - ); - match &changed3[0] { - StatementChange::Modified(changed) => { - assert_eq!(changed.old_stmt_text, "select ;"); - assert_eq!(changed.new_stmt_text, "select"); - assert_eq!(changed.change_text, ""); - assert_eq!(changed.change_range, TextRange::new(7.into(), 8.into())); - } - _ => panic!("expected modified statement"), - } - assert_document_integrity(&d); - } - - #[test] - fn across_statements() { - let path = PgTPath::new("test.sql"); - let input = "select id from users;\nselect * from contacts;"; - - let mut d = Document::new(input.to_string(), 0); - - assert_eq!(d.positions.len(), 2); - - let change = ChangeFileParams { - path: path.clone(), - version: 1, - changes: vec![ChangeParams { - text: ",test from users;\nselect 1;".to_string(), - range: Some(TextRange::new(9.into(), 45.into())), - }], - }; - - let changed = d.apply_file_change(&change); - - assert_eq!(changed.len(), 4); - assert!(matches!(changed[0], StatementChange::Deleted(_))); - assert_eq!(changed[0].statement().raw(), 1); - assert!(matches!( - changed[1], - StatementChange::Deleted(StatementId::Root(_)) - )); - assert_eq!(changed[1].statement().raw(), 0); - assert!( - matches!(&changed[2], StatementChange::Added(AddedStatement { stmt: _, text }) if text == "select id,test from users;") - ); - assert!( - matches!(&changed[3], StatementChange::Added(AddedStatement { stmt: _, text }) if text == "select 1;") - ); - - assert_document_integrity(&d); - } - - #[test] - fn append_whitespace_to_statement() { - let path = PgTPath::new("test.sql"); - let input = "select id"; - - let mut d = Document::new(input.to_string(), 0); - - assert_eq!(d.positions.len(), 1); - - let change = ChangeFileParams { - path: path.clone(), - version: 1, - changes: vec![ChangeParams { - text: " ".to_string(), - range: Some(TextRange::new(9.into(), 10.into())), - }], - }; - - let changed = d.apply_file_change(&change); - - assert_eq!(changed.len(), 1); - - assert_document_integrity(&d); - } - - #[test] - fn apply_changes() { - let path = PgTPath::new("test.sql"); - let input = "select id from users;\nselect * from contacts;"; - - let mut d = Document::new(input.to_string(), 0); - - assert_eq!(d.positions.len(), 2); - - let change = ChangeFileParams { - path: path.clone(), - version: 1, - changes: vec![ChangeParams { - text: ",test from users\nselect 1;".to_string(), - range: Some(TextRange::new(9.into(), 45.into())), - }], - }; - - let changed = d.apply_file_change(&change); - - assert_eq!(changed.len(), 4); - - assert!(matches!( - changed[0], - StatementChange::Deleted(StatementId::Root(_)) - )); - assert_eq!(changed[0].statement().raw(), 1); - assert!(matches!( - changed[1], - StatementChange::Deleted(StatementId::Root(_)) - )); - assert_eq!(changed[1].statement().raw(), 0); - assert_eq!( - changed[2], - StatementChange::Added(AddedStatement { - stmt: StatementId::Root(2.into()), - text: "select id,test from users".to_string() - }) - ); - assert_eq!( - changed[3], - StatementChange::Added(AddedStatement { - stmt: StatementId::Root(3.into()), - text: "select 1;".to_string() - }) - ); - - assert_eq!("select id,test from users\nselect 1;", d.content); - - assert_document_integrity(&d); - } - - #[test] - fn removing_newline_at_the_beginning() { - let path = PgTPath::new("test.sql"); - let input = "\n"; - - let mut d = Document::new(input.to_string(), 1); - - assert_eq!(d.positions.len(), 0); - - let change = ChangeFileParams { - path: path.clone(), - version: 2, - changes: vec![ChangeParams { - text: "\nbegin;\n\nselect 1\n\nrollback;\n".to_string(), - range: Some(TextRange::new(0.into(), 1.into())), - }], - }; - - let changes = d.apply_file_change(&change); - - assert_eq!(changes.len(), 3); - - assert_document_integrity(&d); - - let change2 = ChangeFileParams { - path: path.clone(), - version: 3, - changes: vec![ChangeParams { - text: "".to_string(), - range: Some(TextRange::new(0.into(), 1.into())), - }], - }; - - let changes2 = d.apply_file_change(&change2); - - assert_eq!(changes2.len(), 1); - - assert_document_integrity(&d); - } - - #[test] - fn apply_changes_at_end_of_statement() { - let path = PgTPath::new("test.sql"); - let input = "select id from\nselect * from contacts;"; - - let mut d = Document::new(input.to_string(), 1); - - assert_eq!(d.positions.len(), 2); - - let change = ChangeFileParams { - path: path.clone(), - version: 2, - changes: vec![ChangeParams { - text: " contacts;".to_string(), - range: Some(TextRange::new(14.into(), 14.into())), - }], - }; - - let changes = d.apply_file_change(&change); - - assert_eq!(changes.len(), 1); - - assert!(matches!(changes[0], StatementChange::Modified(_))); - - assert_eq!( - "select id from contacts;\nselect * from contacts;", - d.content - ); - - assert_document_integrity(&d); - } - - #[test] - fn apply_changes_replacement() { - let path = PgTPath::new("test.sql"); - - let mut doc = Document::new("".to_string(), 0); - - let change = ChangeFileParams { - path: path.clone(), - version: 1, - changes: vec![ChangeParams { - text: "select 1;\nselect 2;".to_string(), - range: None, - }], - }; - - doc.apply_file_change(&change); - - assert_eq!(doc.get_text(0), "select 1;".to_string()); - assert_eq!(doc.get_text(1), "select 2;".to_string()); - assert_eq!( - doc.positions[0].1, - TextRange::new(TextSize::new(0), TextSize::new(9)) - ); - assert_eq!( - doc.positions[1].1, - TextRange::new(TextSize::new(10), TextSize::new(19)) - ); - - let change_2 = ChangeFileParams { - path: path.clone(), - version: 2, - changes: vec![ChangeParams { - text: "".to_string(), - range: Some(TextRange::new(7.into(), 8.into())), - }], - }; - - doc.apply_file_change(&change_2); - - assert_eq!(doc.content, "select ;\nselect 2;"); - assert_eq!(doc.positions.len(), 2); - assert_eq!(doc.get_text(0), "select ;".to_string()); - assert_eq!(doc.get_text(1), "select 2;".to_string()); - assert_eq!( - doc.positions[0].1, - TextRange::new(TextSize::new(0), TextSize::new(8)) - ); - assert_eq!( - doc.positions[1].1, - TextRange::new(TextSize::new(9), TextSize::new(18)) - ); - - let change_3 = ChangeFileParams { - path: path.clone(), - version: 3, - changes: vec![ChangeParams { - text: "!".to_string(), - range: Some(TextRange::new(7.into(), 7.into())), - }], - }; - - doc.apply_file_change(&change_3); - - assert_eq!(doc.content, "select !;\nselect 2;"); - assert_eq!(doc.positions.len(), 2); - assert_eq!( - doc.positions[0].1, - TextRange::new(TextSize::new(0), TextSize::new(9)) - ); - assert_eq!( - doc.positions[1].1, - TextRange::new(TextSize::new(10), TextSize::new(19)) - ); - - let change_4 = ChangeFileParams { - path: path.clone(), - version: 4, - changes: vec![ChangeParams { - text: "".to_string(), - range: Some(TextRange::new(7.into(), 8.into())), - }], - }; - - doc.apply_file_change(&change_4); - - assert_eq!(doc.content, "select ;\nselect 2;"); - assert_eq!(doc.positions.len(), 2); - assert_eq!( - doc.positions[0].1, - TextRange::new(TextSize::new(0), TextSize::new(8)) - ); - assert_eq!( - doc.positions[1].1, - TextRange::new(TextSize::new(9), TextSize::new(18)) - ); - - let change_5 = ChangeFileParams { - path: path.clone(), - version: 5, - changes: vec![ChangeParams { - text: "1".to_string(), - range: Some(TextRange::new(7.into(), 7.into())), - }], - }; - - doc.apply_file_change(&change_5); - - assert_eq!(doc.content, "select 1;\nselect 2;"); - assert_eq!(doc.positions.len(), 2); - assert_eq!( - doc.positions[0].1, - TextRange::new(TextSize::new(0), TextSize::new(9)) - ); - assert_eq!( - doc.positions[1].1, - TextRange::new(TextSize::new(10), TextSize::new(19)) - ); - - assert_document_integrity(&doc); - } - - #[test] - fn comment_at_begin() { - let path = PgTPath::new("test.sql"); - - let mut doc = Document::new( - "-- Add new schema named \"private\"\nCREATE SCHEMA \"private\";".to_string(), - 0, - ); - - let change = ChangeFileParams { - path: path.clone(), - version: 1, - changes: vec![ChangeParams { - text: "".to_string(), - range: Some(TextRange::new(0.into(), 1.into())), - }], - }; - - let changed = doc.apply_file_change(&change); - - assert_eq!( - doc.content, - "- Add new schema named \"private\"\nCREATE SCHEMA \"private\";" - ); - assert_eq!(changed.len(), 3); - assert!(matches!(&changed[0], StatementChange::Deleted(_))); - assert!(matches!( - changed[1], - StatementChange::Added(AddedStatement { .. }) - )); - assert!(matches!( - changed[2], - StatementChange::Added(AddedStatement { .. }) - )); - - let change_2 = ChangeFileParams { - path: path.clone(), - version: 2, - changes: vec![ChangeParams { - text: "-".to_string(), - range: Some(TextRange::new(0.into(), 0.into())), - }], - }; - - let changed_2 = doc.apply_file_change(&change_2); - - assert_eq!( - doc.content, - "-- Add new schema named \"private\"\nCREATE SCHEMA \"private\";" - ); - - assert_eq!(changed_2.len(), 3); - assert!(matches!( - changed_2[0], - StatementChange::Deleted(StatementId::Root(_)) - )); - assert!(matches!( - changed_2[1], - StatementChange::Deleted(StatementId::Root(_)) - )); - assert!(matches!( - changed_2[2], - StatementChange::Added(AddedStatement { .. }) - )); - - assert_document_integrity(&doc); - } - - #[test] - fn apply_changes_within_statement() { - let input = "select id from users;\nselect * from contacts;"; - let path = PgTPath::new("test.sql"); - - let mut doc = Document::new(input.to_string(), 0); - - assert_eq!(doc.positions.len(), 2); - - let stmt_1_range = doc.positions[0].clone(); - let stmt_2_range = doc.positions[1].clone(); - - let update_text = ",test"; - - let update_range = TextRange::new(9.into(), 10.into()); - - let update_text_len = u32::try_from(update_text.chars().count()).unwrap(); - let update_addition = update_text_len - u32::from(update_range.len()); - - let change = ChangeFileParams { - path: path.clone(), - version: 1, - changes: vec![ChangeParams { - text: update_text.to_string(), - range: Some(update_range), - }], - }; - - doc.apply_file_change(&change); - - assert_eq!( - "select id,test from users;\nselect * from contacts;", - doc.content - ); - assert_eq!(doc.positions.len(), 2); - assert_eq!(doc.positions[0].1.start(), stmt_1_range.1.start()); - assert_eq!( - u32::from(doc.positions[0].1.end()), - u32::from(stmt_1_range.1.end()) + update_addition - ); - assert_eq!( - u32::from(doc.positions[1].1.start()), - u32::from(stmt_2_range.1.start()) + update_addition - ); - assert_eq!( - u32::from(doc.positions[1].1.end()), - u32::from(stmt_2_range.1.end()) + update_addition - ); - - assert_document_integrity(&doc); - } - - #[test] - fn remove_outside_of_content() { - let path = PgTPath::new("test.sql"); - let input = "select id from contacts;\n\nselect * from contacts;"; - - let mut d = Document::new(input.to_string(), 1); - - assert_eq!(d.positions.len(), 2); - - let change1 = ChangeFileParams { - path: path.clone(), - version: 2, - changes: vec![ChangeParams { - text: "\n".to_string(), - range: Some(TextRange::new(49.into(), 49.into())), - }], - }; - - d.apply_file_change(&change1); - - assert_eq!( - d.content, - "select id from contacts;\n\nselect * from contacts;\n" - ); - - let change2 = ChangeFileParams { - path: path.clone(), - version: 3, - changes: vec![ChangeParams { - text: "\n".to_string(), - range: Some(TextRange::new(50.into(), 50.into())), - }], - }; - - d.apply_file_change(&change2); - - assert_eq!( - d.content, - "select id from contacts;\n\nselect * from contacts;\n\n" - ); - - let change5 = ChangeFileParams { - path: path.clone(), - version: 6, - changes: vec![ChangeParams { - text: "".to_string(), - range: Some(TextRange::new(51.into(), 52.into())), - }], - }; - - let changes = d.apply_file_change(&change5); - - assert!(matches!( - changes[0], - StatementChange::Deleted(StatementId::Root(_)) - )); - - assert!(matches!( - changes[1], - StatementChange::Added(AddedStatement { .. }) - )); - - assert_eq!(changes.len(), 2); - - assert_eq!( - d.content, - "select id from contacts;\n\nselect * from contacts;\n\n" - ); - - assert_document_integrity(&d); - } - - #[test] - fn remove_trailing_whitespace() { - let path = PgTPath::new("test.sql"); - - let mut doc = Document::new("select * from ".to_string(), 0); - - let change = ChangeFileParams { - path: path.clone(), - version: 1, - changes: vec![ChangeParams { - text: "".to_string(), - range: Some(TextRange::new(13.into(), 14.into())), - }], - }; - - let changed = doc.apply_file_change(&change); - - assert_eq!(doc.content, "select * from"); - - assert_eq!(changed.len(), 1); - - match &changed[0] { - StatementChange::Modified(stmt) => { - let ModifiedStatement { - change_range, - change_text, - new_stmt_text, - old_stmt_text, - .. - } = stmt; - - assert_eq!(change_range, &TextRange::new(13.into(), 14.into())); - assert_eq!(change_text, ""); - assert_eq!(new_stmt_text, "select * from"); - - // the whitespace was not considered - // to be a part of the statement - assert_eq!(old_stmt_text, "select * from"); - } - - _ => unreachable!("Did not yield a modified statement."), - } - - assert_document_integrity(&doc); - } - - #[test] - fn remove_trailing_whitespace_and_last_char() { - let path = PgTPath::new("test.sql"); - - let mut doc = Document::new("select * from ".to_string(), 0); - - let change = ChangeFileParams { - path: path.clone(), - version: 1, - changes: vec![ChangeParams { - text: "".to_string(), - range: Some(TextRange::new(12.into(), 14.into())), - }], - }; - - let changed = doc.apply_file_change(&change); - - assert_eq!(doc.content, "select * fro"); - - assert_eq!(changed.len(), 1); - - match &changed[0] { - StatementChange::Modified(stmt) => { - let ModifiedStatement { - change_range, - change_text, - new_stmt_text, - old_stmt_text, - .. - } = stmt; - - assert_eq!(change_range, &TextRange::new(12.into(), 14.into())); - assert_eq!(change_text, ""); - assert_eq!(new_stmt_text, "select * fro"); - - // the whitespace was not considered - // to be a part of the statement - assert_eq!(old_stmt_text, "select * from"); - } - - _ => unreachable!("Did not yield a modified statement."), - } - - assert_document_integrity(&doc); - } - - #[test] - fn multiple_deletions_at_once() { - let path = PgTPath::new("test.sql"); - - let mut doc = Document::new("ALTER TABLE ONLY public.omni_channel_message ADD CONSTRAINT omni_channel_message_organisation_id_fkey FOREIGN KEY (organisation_id) REFERENCES public.organisation(id) ON UPDATE RESTRICT ON DELETE CASCADE;".to_string(), 0); - - let change = ChangeFileParams { - path: path.clone(), - version: 1, - changes: vec![ - ChangeParams { - range: Some(TextRange::new(60.into(), 80.into())), - text: "sendout".to_string(), - }, - ChangeParams { - range: Some(TextRange::new(24.into(), 44.into())), - text: "sendout".to_string(), - }, - ], - }; - - let changed = doc.apply_file_change(&change); - - assert_eq!( - doc.content, - "ALTER TABLE ONLY public.sendout ADD CONSTRAINT sendout_organisation_id_fkey FOREIGN KEY (organisation_id) REFERENCES public.organisation(id) ON UPDATE RESTRICT ON DELETE CASCADE;" - ); - - assert_eq!(changed.len(), 2); - - assert_document_integrity(&doc); - } - - #[test] - fn multiple_additions_at_once() { - let path = PgTPath::new("test.sql"); - - let mut doc = Document::new("ALTER TABLE ONLY public.sendout ADD CONSTRAINT sendout_organisation_id_fkey FOREIGN KEY (organisation_id) REFERENCES public.organisation(id) ON UPDATE RESTRICT ON DELETE CASCADE;".to_string(), 0); - - let change = ChangeFileParams { - path: path.clone(), - version: 1, - changes: vec![ - ChangeParams { - range: Some(TextRange::new(47.into(), 54.into())), - text: "omni_channel_message".to_string(), - }, - ChangeParams { - range: Some(TextRange::new(24.into(), 31.into())), - text: "omni_channel_message".to_string(), - }, - ], - }; - - let changed = doc.apply_file_change(&change); - - assert_eq!( - doc.content, - "ALTER TABLE ONLY public.omni_channel_message ADD CONSTRAINT omni_channel_message_organisation_id_fkey FOREIGN KEY (organisation_id) REFERENCES public.organisation(id) ON UPDATE RESTRICT ON DELETE CASCADE;" - ); - - assert_eq!(changed.len(), 2); - - assert_document_integrity(&doc); - } - - #[test] - fn remove_inbetween_whitespace() { - let path = PgTPath::new("test.sql"); - - let mut doc = Document::new("select * from users".to_string(), 0); - - let change = ChangeFileParams { - path: path.clone(), - version: 1, - changes: vec![ChangeParams { - text: "".to_string(), - range: Some(TextRange::new(9.into(), 11.into())), - }], - }; - - let changed = doc.apply_file_change(&change); - - assert_eq!(doc.content, "select * from users"); - - assert_eq!(changed.len(), 1); - - match &changed[0] { - StatementChange::Modified(stmt) => { - let ModifiedStatement { - change_range, - change_text, - new_stmt_text, - old_stmt_text, - .. - } = stmt; - - assert_eq!(change_range, &TextRange::new(9.into(), 11.into())); - assert_eq!(change_text, ""); - assert_eq!(old_stmt_text, "select * from users"); - assert_eq!(new_stmt_text, "select * from users"); - } - - _ => unreachable!("Did not yield a modified statement."), - } - - assert_document_integrity(&doc); - } - - #[test] - fn test_another_issue() { - let path = PgTPath::new("test.sql"); - let initial_content = r#" - - - -ALTER TABLE ONLY "public"."campaign_contact_list" - ADD CONSTRAINT "campaign_contact_list_contact_list_id_fkey" FOREIGN KEY ("contact_list_id") REFERENCES "public"."contact_list"("id") ON UPDATE RESTRICT ON DELETE CASCADE; -"#; - - let mut doc = Document::new(initial_content.to_string(), 0); - - let change1 = ChangeFileParams { - path: path.clone(), - version: 1, - changes: vec![ - ChangeParams { - range: Some(TextRange::new(31.into(), 39.into())), - text: "journey_node".to_string(), - }, - ChangeParams { - range: Some(TextRange::new(74.into(), 82.into())), - text: "journey_node".to_string(), - }, - ], - }; - - let _changes = doc.apply_file_change(&change1); - - let expected_content = r#" - - - -ALTER TABLE ONLY "public"."journey_node_contact_list" - ADD CONSTRAINT "journey_node_contact_list_contact_list_id_fkey" FOREIGN KEY ("contact_list_id") REFERENCES "public"."contact_list"("id") ON UPDATE RESTRICT ON DELETE CASCADE; -"#; - - assert_eq!(doc.content, expected_content); - - assert_document_integrity(&doc); - } - - #[test] - fn test_comments_only() { - let path = PgTPath::new("test.sql"); - let initial_content = "-- atlas:import async_trigger/setup.sql\n-- atlas:import public/setup.sql\n-- atlas:import private/setup.sql\n-- atlas:import api/setup.sql\n-- atlas:import async_trigger/index.sql\n-- atlas:import public/enums/index.sql\n-- atlas:import public/types/index.sql\n-- atlas:import private/enums/index.sql\n-- atlas:import private/functions/index.sql\n-- atlas:import public/tables/index.sql\n-- atlas:import public/index.sql\n-- atlas:import private/index.sql\n-- atlas:import api/index.sql\n\n\n\n"; - - // Create a new document - let mut doc = Document::new(initial_content.to_string(), 0); - - // First change: Delete some text at line 2, character 24-29 - let change1 = ChangeFileParams { - path: path.clone(), - version: 3, - changes: vec![ChangeParams { - text: "".to_string(), - range: Some(TextRange::new( - // Calculate the correct position based on the content - // Line 2, character 24 - 98.into(), - // Line 2, character 29 - 103.into(), - )), - }], - }; - - let _changes1 = doc.apply_file_change(&change1); - - // Second change: Add 't' at line 2, character 24 - let change2 = ChangeFileParams { - path: path.clone(), - version: 4, - changes: vec![ChangeParams { - text: "t".to_string(), - range: Some(TextRange::new(98.into(), 98.into())), - }], - }; - - let _changes2 = doc.apply_file_change(&change2); - - assert_eq!( - doc.positions.len(), - 0, - "Document should have no statement after adding 't'" - ); - - // Third change: Add 'e' at line 2, character 25 - let change3 = ChangeFileParams { - path: path.clone(), - version: 5, - changes: vec![ChangeParams { - text: "e".to_string(), - range: Some(TextRange::new(99.into(), 99.into())), - }], - }; - - let _changes3 = doc.apply_file_change(&change3); - assert_eq!( - doc.positions.len(), - 0, - "Document should still have no statement" - ); - - // Fourth change: Add 's' at line 2, character 26 - let change4 = ChangeFileParams { - path: path.clone(), - version: 6, - changes: vec![ChangeParams { - text: "s".to_string(), - range: Some(TextRange::new(100.into(), 100.into())), - }], - }; - - let _changes4 = doc.apply_file_change(&change4); - assert_eq!( - doc.positions.len(), - 0, - "Document should still have no statement" - ); - - // Fifth change: Add 't' at line 2, character 27 - let change5 = ChangeFileParams { - path: path.clone(), - version: 7, - changes: vec![ChangeParams { - text: "t".to_string(), - range: Some(TextRange::new(101.into(), 101.into())), - }], - }; - - let _changes5 = doc.apply_file_change(&change5); - assert_eq!( - doc.positions.len(), - 0, - "Document should still have no statement" - ); - - assert_document_integrity(&doc); - } -} diff --git a/crates/pgt_workspace/src/workspace/server/document.rs b/crates/pgt_workspace/src/workspace/server/document.rs index 89516b23c..1d437aa44 100644 --- a/crates/pgt_workspace/src/workspace/server/document.rs +++ b/crates/pgt_workspace/src/workspace/server/document.rs @@ -1,57 +1,329 @@ -use pgt_diagnostics::{Diagnostic, DiagnosticExt, Severity, serde::Diagnostic as SDiagnostic}; -use pgt_text_size::{TextRange, TextSize}; - -use super::statement_identifier::{StatementId, StatementIdGenerator}; +use std::sync::Arc; -type StatementPos = (StatementId, TextRange); - -pub(crate) struct Document { - pub(crate) content: String, - pub(crate) version: i32, +use pgt_diagnostics::{Diagnostic, DiagnosticExt, serde::Diagnostic as SDiagnostic}; +use pgt_query_ext::diagnostics::SyntaxDiagnostic; +use pgt_text_size::{TextRange, TextSize}; - pub(super) diagnostics: Vec, - /// List of statements sorted by range.start() - pub(super) positions: Vec, +use super::{ + annotation::AnnotationStore, + pg_query::PgQueryStore, + sql_function::{SQLFunctionSignature, get_sql_fn_body, get_sql_fn_signature}, + statement_identifier::StatementId, + tree_sitter::TreeSitterStore, +}; - pub(super) id_generator: StatementIdGenerator, +pub struct Document { + content: String, + version: i32, + ranges: Vec, + diagnostics: Vec, + ast_db: PgQueryStore, + cst_db: TreeSitterStore, + #[allow(dead_code)] + annotation_db: AnnotationStore, } impl Document { - pub(crate) fn new(content: String, version: i32) -> Self { - let mut id_generator = StatementIdGenerator::new(); + pub fn new(content: String, version: i32) -> Document { + let cst_db = TreeSitterStore::new(); + let ast_db = PgQueryStore::new(); + let annotation_db = AnnotationStore::new(); let (ranges, diagnostics) = split_with_diagnostics(&content, None); - Self { - positions: ranges - .into_iter() - .map(|range| (id_generator.next(), range)) - .collect(), + Document { + ranges, + diagnostics, content, version, - diagnostics, - id_generator, + ast_db, + cst_db, + annotation_db, + } + } + + pub fn update_content(&mut self, content: String, version: i32) { + self.content = content; + self.version = version; + + let (ranges, diagnostics) = split_with_diagnostics(&self.content, None); + + self.ranges = ranges; + self.diagnostics = diagnostics; + } + + pub fn get_document_content(&self) -> &str { + &self.content + } + + pub fn document_diagnostics(&self) -> &Vec { + &self.diagnostics + } + + pub fn find<'a, M>(&'a self, id: StatementId, mapper: M) -> Option + where + M: StatementMapper<'a>, + { + self.iter_with_filter(mapper, IdFilter::new(id)).next() + } + + pub fn iter<'a, M>(&'a self, mapper: M) -> ParseIterator<'a, M, NoFilter> + where + M: StatementMapper<'a>, + { + self.iter_with_filter(mapper, NoFilter) + } + + pub fn iter_with_filter<'a, M, F>(&'a self, mapper: M, filter: F) -> ParseIterator<'a, M, F> + where + M: StatementMapper<'a>, + F: StatementFilter<'a>, + { + ParseIterator::new(self, mapper, filter) + } + + #[allow(dead_code)] + pub fn count(&self) -> usize { + self.iter(DefaultMapper).count() + } +} + +pub trait StatementMapper<'a> { + type Output; + + fn map(&self, parsed: &'a Document, id: StatementId, range: TextRange) -> Self::Output; +} + +pub trait StatementFilter<'a> { + fn predicate(&self, id: &StatementId, range: &TextRange, content: &str) -> bool; +} + +pub struct ParseIterator<'a, M, F> { + parser: &'a Document, + mapper: M, + filter: F, + ranges: std::slice::Iter<'a, TextRange>, + pending_sub_statements: Vec<(StatementId, TextRange, String)>, +} + +impl<'a, M, F> ParseIterator<'a, M, F> { + pub fn new(parser: &'a Document, mapper: M, filter: F) -> Self { + Self { + parser, + mapper, + filter, + ranges: parser.ranges.iter(), + pending_sub_statements: Vec::new(), } } +} + +impl<'a, M, F> Iterator for ParseIterator<'a, M, F> +where + M: StatementMapper<'a>, + F: StatementFilter<'a>, +{ + type Item = M::Output; + + fn next(&mut self) -> Option { + // First check if we have any pending sub-statements to process + if let Some((id, range, content)) = self.pending_sub_statements.pop() { + if self.filter.predicate(&id, &range, content.as_str()) { + return Some(self.mapper.map(self.parser, id, range)); + } + // If the sub-statement doesn't pass the filter, continue to the next item + return self.next(); + } + + // Process the next top-level statement + let next_range = self.ranges.next(); + + if let Some(range) = next_range { + // If we should include sub-statements and this statement has an AST + + let content = &self.parser.content[*range]; + let root_id = StatementId::new(content); - pub fn statement_content(&self, id: &StatementId) -> Option<&str> { - self.positions - .iter() - .find(|(statement_id, _)| statement_id == id) - .map(|(_, range)| &self.content[*range]) + if let Ok(ast) = self.parser.ast_db.get_or_cache_ast(&root_id).as_ref() { + // Check if this is a SQL function definition with a body + if let Some(sub_statement) = get_sql_fn_body(ast, &content) { + // Add sub-statements to our pending queue + self.pending_sub_statements.push(( + root_id.create_child(&sub_statement.body), + // adjust range to document + sub_statement.range + range.start(), + sub_statement.body.clone(), + )); + } + } + + // Return the current statement if it passes the filter + if self.filter.predicate(&root_id, &range, content) { + return Some(self.mapper.map(self.parser, root_id, *range)); + } + + // If the current statement doesn't pass the filter, try the next one + return self.next(); + } + + None } +} + +pub struct DefaultMapper; +impl<'a> StatementMapper<'a> for DefaultMapper { + type Output = (StatementId, TextRange, String); - /// Returns true if there is at least one fatal error in the diagnostics - /// - /// A fatal error is a scan error that prevents the document from being used - pub(super) fn has_fatal_error(&self) -> bool { - self.diagnostics - .iter() - .any(|d| d.severity() == Severity::Fatal) + fn map(&self, _parser: &'a Document, id: StatementId, range: TextRange) -> Self::Output { + (id.clone(), range, id.content().to_string()) } +} + +pub struct ExecuteStatementMapper; +impl<'a> StatementMapper<'a> for ExecuteStatementMapper { + type Output = ( + StatementId, + TextRange, + String, + Option, + ); + + fn map(&self, parser: &'a Document, id: StatementId, range: TextRange) -> Self::Output { + let ast_result = parser.ast_db.get_or_cache_ast(&id); + let ast_option = match &*ast_result { + Ok(node) => Some(node.clone()), + Err(_) => None, + }; - pub fn iter(&self) -> StatementIterator<'_> { - StatementIterator::new(self) + (id.clone(), range, id.content().to_string(), ast_option) + } +} + +pub struct AsyncDiagnosticsMapper; +impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper { + type Output = ( + StatementId, + TextRange, + Option, + Arc, + Option, + ); + + fn map(&self, parser: &'a Document, id: StatementId, range: TextRange) -> Self::Output { + let ast_result = parser.ast_db.get_or_cache_ast(&id); + + let ast_option = match &*ast_result { + Ok(node) => Some(node.clone()), + Err(_) => None, + }; + + let cst_result = parser.cst_db.get_or_cache_tree(&id); + + let sql_fn_sig = id.parent().and_then(|root| { + let ast_option = parser.ast_db.get_or_cache_ast(&root).as_ref().clone().ok(); + + let ast_option = ast_option.as_ref()?; + + get_sql_fn_signature(ast_option) + }); + + (id.clone(), range, ast_option, cst_result, sql_fn_sig) + } +} + +pub struct SyncDiagnosticsMapper; +impl<'a> StatementMapper<'a> for SyncDiagnosticsMapper { + type Output = ( + StatementId, + TextRange, + Option, + Option, + ); + + fn map(&self, parser: &'a Document, id: StatementId, range: TextRange) -> Self::Output { + let ast_result = parser.ast_db.get_or_cache_ast(&id); + + let (ast_option, diagnostics) = match &*ast_result { + Ok(node) => (Some(node.clone()), None), + Err(diag) => (None, Some(diag.clone())), + }; + + (id.clone(), range, ast_option, diagnostics) + } +} + +pub struct GetCompletionsMapper; +impl<'a> StatementMapper<'a> for GetCompletionsMapper { + type Output = (StatementId, TextRange, String, Arc); + + fn map(&self, parser: &'a Document, id: StatementId, range: TextRange) -> Self::Output { + let tree = parser.cst_db.get_or_cache_tree(&id); + (id.clone(), range, id.content().to_string(), tree) + } +} + +/* + * We allow an offset of two for the statement: + * + * select * from | <-- we want to suggest items for the next token. + * + * However, if the current statement is terminated by a semicolon, we don't apply any + * offset. + * + * select * from users; | <-- no autocompletions here. + */ +pub struct GetCompletionsFilter { + pub cursor_position: TextSize, +} +impl StatementFilter<'_> for GetCompletionsFilter { + fn predicate(&self, _id: &StatementId, range: &TextRange, content: &str) -> bool { + let is_terminated_by_semi = content.chars().last().is_some_and(|c| c == ';'); + + let measuring_range = if is_terminated_by_semi { + *range + } else { + range.checked_expand_end(2.into()).unwrap_or(*range) + }; + measuring_range.contains(self.cursor_position) + } +} + +pub struct NoFilter; +impl StatementFilter<'_> for NoFilter { + fn predicate(&self, _id: &StatementId, _range: &TextRange, _content: &str) -> bool { + true + } +} + +pub struct CursorPositionFilter { + pos: TextSize, +} + +impl CursorPositionFilter { + pub fn new(pos: TextSize) -> Self { + Self { pos } + } +} + +impl StatementFilter<'_> for CursorPositionFilter { + fn predicate(&self, _id: &StatementId, range: &TextRange, _content: &str) -> bool { + range.contains(self.pos) + } +} + +pub struct IdFilter { + id: StatementId, +} + +impl IdFilter { + pub fn new(id: StatementId) -> Self { + Self { id } + } +} + +impl StatementFilter<'_> for IdFilter { + fn predicate(&self, id: &StatementId, _range: &TextRange, _content: &str) -> bool { + *id == self.id } } @@ -79,29 +351,23 @@ pub(crate) fn split_with_diagnostics( ) } -pub struct StatementIterator<'a> { - document: &'a Document, - positions: std::slice::Iter<'a, StatementPos>, -} +#[cfg(test)] +mod tests { + use super::*; -impl<'a> StatementIterator<'a> { - pub fn new(document: &'a Document) -> Self { - Self { - document, - positions: document.positions.iter(), - } - } -} + #[test] + fn sql_function_body() { + let input = "CREATE FUNCTION add(test0 integer, test1 integer) RETURNS integer + AS 'select $1 + $2;' + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT;"; -impl<'a> Iterator for StatementIterator<'a> { - type Item = (StatementId, TextRange, &'a str); + let d = Document::new(input.to_string(), 1); - fn next(&mut self) -> Option { - self.positions.next().map(|(id, range)| { - let range = *range; - let doc = self.document; - let id = id.clone(); - (id, range, &doc.content[range]) - }) + let stmts = d.iter(DefaultMapper).collect::>(); + + assert_eq!(stmts.len(), 2); + assert_eq!(stmts[1].2, "select $1 + $2;"); } } diff --git a/crates/pgt_workspace/src/workspace/server/parsed_document.rs b/crates/pgt_workspace/src/workspace/server/parsed_document.rs deleted file mode 100644 index 2b81faba9..000000000 --- a/crates/pgt_workspace/src/workspace/server/parsed_document.rs +++ /dev/null @@ -1,442 +0,0 @@ -use std::sync::Arc; - -use pgt_diagnostics::serde::Diagnostic as SDiagnostic; -use pgt_fs::PgTPath; -use pgt_query_ext::diagnostics::SyntaxDiagnostic; -use pgt_text_size::{TextRange, TextSize}; - -use crate::workspace::ChangeFileParams; - -use super::{ - annotation::AnnotationStore, - change::StatementChange, - document::{Document, StatementIterator}, - pg_query::PgQueryStore, - sql_function::{SQLFunctionSignature, get_sql_fn_body, get_sql_fn_signature}, - statement_identifier::StatementId, - tree_sitter::TreeSitterStore, -}; - -pub struct ParsedDocument { - #[allow(dead_code)] - path: PgTPath, - - doc: Document, - ast_db: PgQueryStore, - cst_db: TreeSitterStore, - annotation_db: AnnotationStore, -} - -impl ParsedDocument { - pub fn new(path: PgTPath, content: String, version: i32) -> ParsedDocument { - let doc = Document::new(content, version); - - let cst_db = TreeSitterStore::new(); - let ast_db = PgQueryStore::new(); - let annotation_db = AnnotationStore::new(); - - doc.iter().for_each(|(stmt, _, content)| { - cst_db.add_statement(&stmt, content); - }); - - ParsedDocument { - path, - doc, - ast_db, - cst_db, - annotation_db, - } - } - - /// Applies a change to the document and updates the CST and AST databases accordingly. - /// - /// Note that only tree-sitter cares about statement modifications vs remove + add. - /// Hence, we just clear the AST for the old statements and lazily load them when requested. - /// - /// * `params`: ChangeFileParams - The parameters for the change to be applied. - pub fn apply_change(&mut self, params: ChangeFileParams) { - for c in &self.doc.apply_file_change(¶ms) { - match c { - StatementChange::Added(added) => { - tracing::debug!( - "Adding statement: id:{:?}, text:{:?}", - added.stmt, - added.text - ); - self.cst_db.add_statement(&added.stmt, &added.text); - } - StatementChange::Deleted(s) => { - tracing::debug!("Deleting statement: id {:?}", s,); - self.cst_db.remove_statement(s); - self.ast_db.clear_statement(s); - self.annotation_db.clear_statement(s); - } - StatementChange::Modified(s) => { - tracing::debug!( - "Modifying statement with id {:?} (new id {:?}). Range {:?}, Changed from '{:?}' to '{:?}', changed text: {:?}", - s.old_stmt, - s.new_stmt, - s.change_range, - s.old_stmt_text, - s.new_stmt_text, - s.change_text - ); - - self.cst_db.modify_statement(s); - self.ast_db.clear_statement(&s.old_stmt); - self.annotation_db.clear_statement(&s.old_stmt); - } - } - } - } - - pub fn get_document_content(&self) -> &str { - &self.doc.content - } - - pub fn document_diagnostics(&self) -> &Vec { - &self.doc.diagnostics - } - - pub fn find<'a, M>(&'a self, id: StatementId, mapper: M) -> Option - where - M: StatementMapper<'a>, - { - self.iter_with_filter(mapper, IdFilter::new(id)).next() - } - - pub fn iter<'a, M>(&'a self, mapper: M) -> ParseIterator<'a, M, NoFilter> - where - M: StatementMapper<'a>, - { - self.iter_with_filter(mapper, NoFilter) - } - - pub fn iter_with_filter<'a, M, F>(&'a self, mapper: M, filter: F) -> ParseIterator<'a, M, F> - where - M: StatementMapper<'a>, - F: StatementFilter<'a>, - { - ParseIterator::new(self, mapper, filter) - } - - #[allow(dead_code)] - pub fn count(&self) -> usize { - self.iter(DefaultMapper).count() - } -} - -pub trait StatementMapper<'a> { - type Output; - - fn map( - &self, - parsed: &'a ParsedDocument, - id: StatementId, - range: TextRange, - content: &str, - ) -> Self::Output; -} - -pub trait StatementFilter<'a> { - fn predicate(&self, id: &StatementId, range: &TextRange, content: &str) -> bool; -} - -pub struct ParseIterator<'a, M, F> { - parser: &'a ParsedDocument, - statements: StatementIterator<'a>, - mapper: M, - filter: F, - pending_sub_statements: Vec<(StatementId, TextRange, String)>, -} - -impl<'a, M, F> ParseIterator<'a, M, F> { - pub fn new(parser: &'a ParsedDocument, mapper: M, filter: F) -> Self { - Self { - parser, - statements: parser.doc.iter(), - mapper, - filter, - pending_sub_statements: Vec::new(), - } - } -} - -impl<'a, M, F> Iterator for ParseIterator<'a, M, F> -where - M: StatementMapper<'a>, - F: StatementFilter<'a>, -{ - type Item = M::Output; - - fn next(&mut self) -> Option { - // First check if we have any pending sub-statements to process - if let Some((id, range, content)) = self.pending_sub_statements.pop() { - if self.filter.predicate(&id, &range, content.as_str()) { - return Some(self.mapper.map(self.parser, id, range, &content)); - } - // If the sub-statement doesn't pass the filter, continue to the next item - return self.next(); - } - - // Process the next top-level statement - let next_statement = self.statements.next(); - - if let Some((root_id, range, content)) = next_statement { - // If we should include sub-statements and this statement has an AST - let content_owned = content.to_string(); - if let Ok(ast) = self - .parser - .ast_db - .get_or_cache_ast(&root_id, &content_owned) - .as_ref() - { - // Check if this is a SQL function definition with a body - if let Some(sub_statement) = get_sql_fn_body(ast, &content_owned) { - // Add sub-statements to our pending queue - self.pending_sub_statements.push(( - root_id.create_child(), - // adjust range to document - sub_statement.range + range.start(), - sub_statement.body.clone(), - )); - } - } - - // Return the current statement if it passes the filter - if self.filter.predicate(&root_id, &range, content) { - return Some(self.mapper.map(self.parser, root_id, range, content)); - } - - // If the current statement doesn't pass the filter, try the next one - return self.next(); - } - - None - } -} - -pub struct DefaultMapper; -impl<'a> StatementMapper<'a> for DefaultMapper { - type Output = (StatementId, TextRange, String); - - fn map( - &self, - _parser: &'a ParsedDocument, - id: StatementId, - range: TextRange, - content: &str, - ) -> Self::Output { - (id, range, content.to_string()) - } -} - -pub struct ExecuteStatementMapper; -impl<'a> StatementMapper<'a> for ExecuteStatementMapper { - type Output = ( - StatementId, - TextRange, - String, - Option, - ); - - fn map( - &self, - parser: &'a ParsedDocument, - id: StatementId, - range: TextRange, - content: &str, - ) -> Self::Output { - let ast_result = parser.ast_db.get_or_cache_ast(&id, content); - let ast_option = match &*ast_result { - Ok(node) => Some(node.clone()), - Err(_) => None, - }; - - (id, range, content.to_string(), ast_option) - } -} - -pub struct AsyncDiagnosticsMapper; -impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper { - type Output = ( - StatementId, - TextRange, - String, - Option, - Arc, - Option, - ); - - fn map( - &self, - parser: &'a ParsedDocument, - id: StatementId, - range: TextRange, - content: &str, - ) -> Self::Output { - let content_owned = content.to_string(); - let ast_result = parser.ast_db.get_or_cache_ast(&id, &content_owned); - - let ast_option = match &*ast_result { - Ok(node) => Some(node.clone()), - Err(_) => None, - }; - - let cst_result = parser.cst_db.get_or_cache_tree(&id, &content_owned); - - let sql_fn_sig = id - .parent() - .and_then(|root| { - let c = parser.doc.statement_content(&root)?; - Some((root, c)) - }) - .and_then(|(root, c)| { - let ast_option = parser - .ast_db - .get_or_cache_ast(&root, c) - .as_ref() - .clone() - .ok(); - - let ast_option = ast_option.as_ref()?; - - get_sql_fn_signature(ast_option) - }); - - (id, range, content_owned, ast_option, cst_result, sql_fn_sig) - } -} - -pub struct SyncDiagnosticsMapper; -impl<'a> StatementMapper<'a> for SyncDiagnosticsMapper { - type Output = ( - StatementId, - TextRange, - Option, - Option, - ); - - fn map( - &self, - parser: &'a ParsedDocument, - id: StatementId, - range: TextRange, - content: &str, - ) -> Self::Output { - let ast_result = parser.ast_db.get_or_cache_ast(&id, content); - - let (ast_option, diagnostics) = match &*ast_result { - Ok(node) => (Some(node.clone()), None), - Err(diag) => (None, Some(diag.clone())), - }; - - (id, range, ast_option, diagnostics) - } -} - -pub struct GetCompletionsMapper; -impl<'a> StatementMapper<'a> for GetCompletionsMapper { - type Output = (StatementId, TextRange, String, Arc); - - fn map( - &self, - parser: &'a ParsedDocument, - id: StatementId, - range: TextRange, - content: &str, - ) -> Self::Output { - let tree = parser.cst_db.get_or_cache_tree(&id, content); - (id, range, content.into(), tree) - } -} - -/* - * We allow an offset of two for the statement: - * - * select * from | <-- we want to suggest items for the next token. - * - * However, if the current statement is terminated by a semicolon, we don't apply any - * offset. - * - * select * from users; | <-- no autocompletions here. - */ -pub struct GetCompletionsFilter { - pub cursor_position: TextSize, -} -impl StatementFilter<'_> for GetCompletionsFilter { - fn predicate(&self, _id: &StatementId, range: &TextRange, content: &str) -> bool { - let is_terminated_by_semi = content.chars().last().is_some_and(|c| c == ';'); - - let measuring_range = if is_terminated_by_semi { - *range - } else { - range.checked_expand_end(2.into()).unwrap_or(*range) - }; - measuring_range.contains(self.cursor_position) - } -} - -pub struct NoFilter; -impl StatementFilter<'_> for NoFilter { - fn predicate(&self, _id: &StatementId, _range: &TextRange, _content: &str) -> bool { - true - } -} - -pub struct CursorPositionFilter { - pos: TextSize, -} - -impl CursorPositionFilter { - pub fn new(pos: TextSize) -> Self { - Self { pos } - } -} - -impl StatementFilter<'_> for CursorPositionFilter { - fn predicate(&self, _id: &StatementId, range: &TextRange, _content: &str) -> bool { - range.contains(self.pos) - } -} - -pub struct IdFilter { - id: StatementId, -} - -impl IdFilter { - pub fn new(id: StatementId) -> Self { - Self { id } - } -} - -impl StatementFilter<'_> for IdFilter { - fn predicate(&self, id: &StatementId, _range: &TextRange, _content: &str) -> bool { - *id == self.id - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use pgt_fs::PgTPath; - - #[test] - fn sql_function_body() { - let input = "CREATE FUNCTION add(test0 integer, test1 integer) RETURNS integer - AS 'select $1 + $2;' - LANGUAGE SQL - IMMUTABLE - RETURNS NULL ON NULL INPUT;"; - - let path = PgTPath::new("test.sql"); - - let d = ParsedDocument::new(path, input.to_string(), 0); - - let stmts = d.iter(DefaultMapper).collect::>(); - - assert_eq!(stmts.len(), 2); - assert_eq!(stmts[1].2, "select $1 + $2;"); - } -} diff --git a/crates/pgt_workspace/src/workspace/server/pg_query.rs b/crates/pgt_workspace/src/workspace/server/pg_query.rs index e5c0cac8a..45af96e7d 100644 --- a/crates/pgt_workspace/src/workspace/server/pg_query.rs +++ b/crates/pgt_workspace/src/workspace/server/pg_query.rs @@ -17,22 +17,13 @@ impl PgQueryStore { pub fn get_or_cache_ast( &self, statement: &StatementId, - content: &str, ) -> Arc> { if let Some(existing) = self.db.get(statement).map(|x| x.clone()) { return existing; } - let r = Arc::new(pgt_query_ext::parse(content).map_err(SyntaxDiagnostic::from)); + let r = Arc::new(pgt_query_ext::parse(statement.content()).map_err(SyntaxDiagnostic::from)); self.db.insert(statement.clone(), r.clone()); r } - - pub fn clear_statement(&self, id: &StatementId) { - self.db.remove(id); - - if let Some(child_id) = id.get_child_id() { - self.db.remove(&child_id); - } - } } diff --git a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs index 627ff2618..9328671bf 100644 --- a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs +++ b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs @@ -1,24 +1,6 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] -#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] -pub struct RootId { - inner: usize, -} +use std::sync::Arc; -#[cfg(test)] -impl From for usize { - fn from(val: RootId) -> Self { - val.inner - } -} - -#[cfg(test)] -impl From for RootId { - fn from(inner: usize) -> Self { - RootId { inner } - } -} +use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] @@ -35,91 +17,83 @@ impl From for RootId { /// ``` /// /// For now, we only support SQL functions – no complex, nested statements. -/// -/// An SQL function only ever has ONE child, that's why the inner `RootId` of a `Root` -/// is the same as the one of its `Child`. pub enum StatementId { - Root(RootId), - // StatementId is the same as the root id since we can only have a single sql function body per Root - Child(RootId), -} - -impl Default for StatementId { - fn default() -> Self { - StatementId::Root(RootId { inner: 0 }) - } + Root { + content: Arc, + }, + Child { + content: Arc, // child's actual content + parent_content: Arc, // parent's content for lookups + }, } impl StatementId { - pub fn raw(&self) -> usize { - match self { - StatementId::Root(s) => s.inner, - StatementId::Child(s) => s.inner, + pub fn new(statement: &str) -> Self { + StatementId::Root { + content: statement.into(), } } - pub fn is_root(&self) -> bool { - matches!(self, StatementId::Root(_)) - } - - pub fn is_child(&self) -> bool { - matches!(self, StatementId::Child(_)) + /// Creates a child statement ID with the given content and parent content. + pub fn new_child(child_content: &str, parent_content: &str) -> Self { + StatementId::Child { + content: child_content.into(), + parent_content: parent_content.into(), + } } - pub fn is_child_of(&self, maybe_parent: &StatementId) -> bool { + /// Use this if you need to create a matching `StatementId::Child` for `Root`. + /// You cannot create a `Child` of a `Child`. + /// Note: This method requires the child content to be provided. + pub fn create_child(&self, child_content: &str) -> StatementId { match self { - StatementId::Root(_) => false, - StatementId::Child(child_root) => match maybe_parent { - StatementId::Root(parent_rood) => child_root == parent_rood, - // TODO: can we have multiple nested statements? - StatementId::Child(_) => false, + StatementId::Root { content } => StatementId::Child { + content: child_content.into(), + parent_content: content.clone(), }, + StatementId::Child { .. } => panic!("Cannot create child from a child statement id"), } } - pub fn parent(&self) -> Option { + pub fn content(&self) -> &str { match self { - StatementId::Root(_) => None, - StatementId::Child(id) => Some(StatementId::Root(id.clone())), + StatementId::Root { content } => content, + StatementId::Child { content, .. } => content, } } -} -/// Helper struct to generate unique statement ids -pub struct StatementIdGenerator { - next_id: usize, -} + /// Returns the parent content if this is a child statement + pub fn parent_content(&self) -> Option<&str> { + match self { + StatementId::Root { .. } => None, + StatementId::Child { parent_content, .. } => Some(parent_content), + } + } -impl StatementIdGenerator { - pub fn new() -> Self { - Self { next_id: 0 } + pub fn is_root(&self) -> bool { + matches!(self, StatementId::Root { .. }) } - pub fn next(&mut self) -> StatementId { - let id = self.next_id; - self.next_id += 1; - StatementId::Root(RootId { inner: id }) + pub fn is_child(&self) -> bool { + matches!(self, StatementId::Child { .. }) } -} -impl StatementId { - /// Use this to get the matching `StatementId::Child` for - /// a `StatementId::Root`. - /// If the `StatementId` was already a `Child`, this will return `None`. - /// It is not guaranteed that the `Root` actually has a `Child` statement in the workspace. - pub fn get_child_id(&self) -> Option { + pub fn is_child_of(&self, maybe_parent: &StatementId) -> bool { match self { - StatementId::Root(id) => Some(StatementId::Child(RootId { inner: id.inner })), - StatementId::Child(_) => None, + StatementId::Root { .. } => false, + StatementId::Child { parent_content, .. } => match maybe_parent { + StatementId::Root { content } => parent_content == content, + StatementId::Child { .. } => false, + }, } } - /// Use this if you need to create a matching `StatementId::Child` for `Root`. - /// You cannot create a `Child` of a `Child`. - pub fn create_child(&self) -> StatementId { + pub fn parent(&self) -> Option { match self { - StatementId::Root(id) => StatementId::Child(RootId { inner: id.inner }), - StatementId::Child(_) => panic!("Cannot create child from a child statement id"), + StatementId::Root { .. } => None, + StatementId::Child { parent_content, .. } => Some(StatementId::Root { + content: parent_content.clone(), + }), } } } diff --git a/crates/pgt_workspace/src/workspace/server/tree_sitter.rs b/crates/pgt_workspace/src/workspace/server/tree_sitter.rs index a89325356..2cd73133e 100644 --- a/crates/pgt_workspace/src/workspace/server/tree_sitter.rs +++ b/crates/pgt_workspace/src/workspace/server/tree_sitter.rs @@ -1,9 +1,8 @@ use std::sync::{Arc, Mutex}; use dashmap::DashMap; -use tree_sitter::InputEdit; -use super::{change::ModifiedStatement, statement_identifier::StatementId}; +use super::statement_identifier::StatementId; pub struct TreeSitterStore { db: DashMap>, @@ -23,139 +22,15 @@ impl TreeSitterStore { } } - pub fn get_or_cache_tree( - &self, - statement: &StatementId, - content: &str, - ) -> Arc { + pub fn get_or_cache_tree(&self, statement: &StatementId) -> Arc { if let Some(existing) = self.db.get(statement).map(|x| x.clone()) { return existing; } let mut parser = self.parser.lock().expect("Failed to lock parser"); - let tree = Arc::new(parser.parse(content, None).unwrap()); + let tree = Arc::new(parser.parse(statement.content(), None).unwrap()); self.db.insert(statement.clone(), tree.clone()); tree } - - pub fn add_statement(&self, statement: &StatementId, content: &str) { - let mut parser = self.parser.lock().expect("Failed to lock parser"); - let tree = parser.parse(content, None).unwrap(); - self.db.insert(statement.clone(), Arc::new(tree)); - } - - pub fn remove_statement(&self, id: &StatementId) { - self.db.remove(id); - - if let Some(child_id) = id.get_child_id() { - self.db.remove(&child_id); - } - } - - pub fn modify_statement(&self, change: &ModifiedStatement) { - let old = self.db.remove(&change.old_stmt); - - if old.is_none() { - self.add_statement(&change.new_stmt, &change.change_text); - return; - } - - // we clone the three for now, lets see if that is sufficient or if we need to mutate the - // original tree instead but that will require some kind of locking - let mut tree = old.unwrap().1.as_ref().clone(); - - let edit = edit_from_change( - change.old_stmt_text.as_str(), - usize::from(change.change_range.start()), - usize::from(change.change_range.end()), - change.change_text.as_str(), - ); - - tree.edit(&edit); - - let mut parser = self.parser.lock().expect("Failed to lock parser"); - // todo handle error - self.db.insert( - change.new_stmt.clone(), - Arc::new(parser.parse(&change.new_stmt_text, Some(&tree)).unwrap()), - ); - } -} - -// Converts character positions and replacement text into a tree-sitter InputEdit -pub(crate) fn edit_from_change( - text: &str, - start_char: usize, - end_char: usize, - replacement_text: &str, -) -> InputEdit { - let mut start_byte = 0; - let mut end_byte = 0; - let mut chars_counted = 0; - - let mut line = 0; - let mut current_line_char_start = 0; // Track start of the current line in characters - let mut column_start = 0; - let mut column_end = 0; - - // Find the byte positions corresponding to the character positions - for (idx, c) in text.char_indices() { - if chars_counted == start_char { - start_byte = idx; - column_start = chars_counted - current_line_char_start; - } - if chars_counted == end_char { - end_byte = idx; - column_end = chars_counted - current_line_char_start; - break; // Found both start and end - } - if c == '\n' { - line += 1; - current_line_char_start = chars_counted + 1; // Next character starts a new line - } - chars_counted += 1; - } - - // Handle case where end_char is at the end of the text - if end_char == chars_counted && end_byte == 0 { - end_byte = text.len(); - column_end = chars_counted - current_line_char_start; - } - - let start_point = tree_sitter::Point::new(line, column_start); - let old_end_point = tree_sitter::Point::new(line, column_end); - - // Calculate the new end byte after the edit - let new_end_byte = start_byte + replacement_text.len(); - - // Calculate the new end position - let new_lines = replacement_text.matches('\n').count(); - let last_line_length = if new_lines > 0 { - replacement_text - .split('\n') - .next_back() - .unwrap_or("") - .chars() - .count() - } else { - replacement_text.chars().count() - }; - - let new_end_position = if new_lines > 0 { - // If there are new lines, the row is offset by the number of new lines, and the column is the length of the last line - tree_sitter::Point::new(start_point.row + new_lines, last_line_length) - } else { - // If there are no new lines, the row remains the same, and the column is offset by the length of the insertion - tree_sitter::Point::new(start_point.row, start_point.column + last_line_length) - }; - - InputEdit { - start_byte, - old_end_byte: end_byte, - new_end_byte, - start_position: start_point, - old_end_position: old_end_point, - new_end_position, - } } From ccd7e74103b8462319a310c223402bba2c91f527 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Fri, 11 Jul 2025 17:44:45 +0200 Subject: [PATCH 2/9] progress --- crates/pgt_lsp/src/handlers/text_document.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/pgt_lsp/src/handlers/text_document.rs b/crates/pgt_lsp/src/handlers/text_document.rs index 318d95724..cc2efb4b3 100644 --- a/crates/pgt_lsp/src/handlers/text_document.rs +++ b/crates/pgt_lsp/src/handlers/text_document.rs @@ -45,10 +45,10 @@ pub(crate) async fn did_change( let url = params.text_document.uri; let version = params.text_document.version; - let biome_path = session.file_path(&url)?; + let pgt_path = session.file_path(&url)?; let old_text = session.workspace.get_file_content(GetFileContentParams { - path: biome_path.clone(), + path: pgt_path.clone(), })?; tracing::trace!("old document: {:?}", old_text); tracing::trace!("content changes: {:?}", params.content_changes); @@ -64,7 +64,7 @@ pub(crate) async fn did_change( session.insert_document(url.clone(), Document::new(version, &text)); session.workspace.change_file(ChangeFileParams { - path: biome_path, + path: pgt_path, version, content: text, })?; From c3429bb421b62ffd970843259f69a6b0fe4c59c9 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Fri, 11 Jul 2025 17:49:59 +0200 Subject: [PATCH 3/9] progress --- crates/pgt_workspace/src/features/completions.rs | 7 +------ crates/pgt_workspace/src/workspace/server/annotation.rs | 4 ++-- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/crates/pgt_workspace/src/features/completions.rs b/crates/pgt_workspace/src/features/completions.rs index 2803382ef..c6f05c6e2 100644 --- a/crates/pgt_workspace/src/features/completions.rs +++ b/crates/pgt_workspace/src/features/completions.rs @@ -76,7 +76,6 @@ pub(crate) fn get_statement_for_completions( #[cfg(test)] mod tests { - use pgt_fs::PgTPath; use pgt_text_size::TextSize; use crate::workspace::Document; @@ -93,11 +92,7 @@ mod tests { let pos: u32 = pos.try_into().unwrap(); ( - Document::new( - PgTPath::new("test.sql"), - sql.replace(CURSOR_POSITION, ""), - 5, - ), + Document::new(sql.replace(CURSOR_POSITION, ""), 5), TextSize::new(pos), ) } diff --git a/crates/pgt_workspace/src/workspace/server/annotation.rs b/crates/pgt_workspace/src/workspace/server/annotation.rs index f46700dbd..20710521b 100644 --- a/crates/pgt_workspace/src/workspace/server/annotation.rs +++ b/crates/pgt_workspace/src/workspace/server/annotation.rs @@ -76,8 +76,8 @@ mod tests { ("SELECT * FROM foo\n", false), ]; - for (idx, (content, expected)) in test_cases.iter().enumerate() { - let statement_id = StatementId::Root(idx.into()); + for (content, expected) in test_cases.iter() { + let statement_id = StatementId::new(content); let annotations = store.get_annotations(&statement_id, content); From ce0be2cf07c4e4f678207223612701a16ceff316 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Sat, 12 Jul 2025 13:48:07 +0200 Subject: [PATCH 4/9] progress --- crates/pgt_workspace/src/workspace/server.rs | 8 ++++---- crates/pgt_workspace/src/workspace/server/document.rs | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index 1479d467b..b6c846f40 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -13,11 +13,11 @@ use document::{ AsyncDiagnosticsMapper, CursorPositionFilter, DefaultMapper, Document, ExecuteStatementMapper, SyncDiagnosticsMapper, }; -use futures::{stream, StreamExt}; +use futures::{StreamExt, stream}; use pgt_analyse::{AnalyserOptions, AnalysisFilter}; use pgt_analyser::{Analyser, AnalyserConfig, AnalyserContext}; use pgt_diagnostics::{ - serde::Diagnostic as SDiagnostic, Diagnostic, DiagnosticExt, Error, Severity, + Diagnostic, DiagnosticExt, Error, Severity, serde::Diagnostic as SDiagnostic, }; use pgt_fs::{ConfigName, PgTPath}; use pgt_typecheck::{IdentifierType, TypecheckParams, TypedIdentifier}; @@ -26,17 +26,17 @@ use sqlx::{Executor, PgPool}; use tracing::{debug, info}; use crate::{ + WorkspaceError, configuration::to_analyser_rules, features::{ code_actions::{ self, CodeAction, CodeActionKind, CodeActionsResult, CommandAction, CommandActionCategory, ExecuteStatementParams, ExecuteStatementResult, }, - completions::{get_statement_for_completions, CompletionsResult, GetCompletionsParams}, + completions::{CompletionsResult, GetCompletionsParams, get_statement_for_completions}, diagnostics::{PullDiagnosticsParams, PullDiagnosticsResult}, }, settings::{WorkspaceSettings, WorkspaceSettingsHandle, WorkspaceSettingsHandleMut}, - WorkspaceError, }; use super::{ diff --git a/crates/pgt_workspace/src/workspace/server/document.rs b/crates/pgt_workspace/src/workspace/server/document.rs index 7111be165..65368d922 100644 --- a/crates/pgt_workspace/src/workspace/server/document.rs +++ b/crates/pgt_workspace/src/workspace/server/document.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use pgt_diagnostics::{serde::Diagnostic as SDiagnostic, Diagnostic, DiagnosticExt}; +use pgt_diagnostics::{Diagnostic, DiagnosticExt, serde::Diagnostic as SDiagnostic}; use pgt_query_ext::diagnostics::SyntaxDiagnostic; use pgt_suppressions::Suppressions; use pgt_text_size::{TextRange, TextSize}; @@ -8,7 +8,7 @@ use pgt_text_size::{TextRange, TextSize}; use super::{ annotation::AnnotationStore, pg_query::PgQueryStore, - sql_function::{get_sql_fn_body, get_sql_fn_signature, SQLFunctionSignature}, + sql_function::{SQLFunctionSignature, get_sql_fn_body, get_sql_fn_signature}, statement_identifier::StatementId, tree_sitter::TreeSitterStore, }; From f209a2d3d54a39aa97b0a216e1f3291e97ece902 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Sat, 12 Jul 2025 14:00:02 +0200 Subject: [PATCH 5/9] progress --- .../benches/splitter.rs | 27 +++++++++++-------- .../src/workspace/server/document.rs | 2 +- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/crates/pgt_statement_splitter/benches/splitter.rs b/crates/pgt_statement_splitter/benches/splitter.rs index 4a1cd7738..e7cdeeef6 100644 --- a/crates/pgt_statement_splitter/benches/splitter.rs +++ b/crates/pgt_statement_splitter/benches/splitter.rs @@ -2,8 +2,7 @@ use criterion::{Criterion, black_box, criterion_group, criterion_main}; use pgt_statement_splitter::split; pub fn splitter_benchmark(c: &mut Criterion) { - c.bench_function("large statement", |b| { - let statement = r#"with + let large_statement = r#"with available_tables as ( select c.relname as table_name, @@ -62,18 +61,24 @@ where "#; - let content = statement.repeat(500); + let large_content = large_statement.repeat(500); - b.iter(|| black_box(split(&content))); - }); + c.bench_function( + format!("large statement with length {}", large_content.len()).as_str(), + |b| { + b.iter(|| black_box(split(&large_content))); + }, + ); - c.bench_function("small statement", |b| { - let statement = r#"select 1 from public.user where id = 1"#; + let small_statement = r#"select 1 from public.user where id = 1"#; + let small_content = small_statement.repeat(500); - let content = statement.repeat(500); - - b.iter(|| black_box(split(&content))); - }); + c.bench_function( + format!("small statement with length {}", small_content.len()).as_str(), + |b| { + b.iter(|| black_box(split(&small_content))); + }, + ); } criterion_group!(benches, splitter_benchmark); diff --git a/crates/pgt_workspace/src/workspace/server/document.rs b/crates/pgt_workspace/src/workspace/server/document.rs index 65368d922..1574b2348 100644 --- a/crates/pgt_workspace/src/workspace/server/document.rs +++ b/crates/pgt_workspace/src/workspace/server/document.rs @@ -155,7 +155,7 @@ where if let Ok(ast) = self.parser.ast_db.get_or_cache_ast(&root_id).as_ref() { // Check if this is a SQL function definition with a body - if let Some(sub_statement) = get_sql_fn_body(ast, &content) { + if let Some(sub_statement) = get_sql_fn_body(ast, content) { // Add sub-statements to our pending queue self.pending_sub_statements.push(( root_id.create_child(&sub_statement.body), From af5061e8dbb7efb07cf96e122a19093c16deae32 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Sat, 12 Jul 2025 14:04:27 +0200 Subject: [PATCH 6/9] progress --- Cargo.lock | 20 -------------------- Cargo.toml | 1 - crates/pgt_workspace/Cargo.toml | 1 - 3 files changed, 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cad64d65c..a89dbfbe9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1454,12 +1454,6 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" -[[package]] -name = "foldhash" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" - [[package]] name = "form_urlencoded" version = "1.2.1" @@ -1737,9 +1731,6 @@ name = "hashbrown" version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" -dependencies = [ - "foldhash", -] [[package]] name = "hashlink" @@ -3099,7 +3090,6 @@ dependencies = [ "serde_json", "slotmap", "sqlx", - "string-interner", "strum", "tempfile", "tokio", @@ -4278,16 +4268,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" -[[package]] -name = "string-interner" -version = "0.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23de088478b31c349c9ba67816fa55d9355232d63c3afea8bf513e31f0f1d2c0" -dependencies = [ - "hashbrown 0.15.2", - "serde", -] - [[package]] name = "stringprep" version = "0.1.5" diff --git a/Cargo.toml b/Cargo.toml index 6aa51b2f4..15c6f02ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,6 @@ serde_json = "1.0.114" similar = "2.6.0" slotmap = "1.0.7" smallvec = { version = "1.13.2", features = ["union", "const_new", "serde"] } -string-interner = "0.19.0" strum = { version = "0.27.1", features = ["derive"] } # this will use tokio if available, otherwise async-std convert_case = "0.6.0" diff --git a/crates/pgt_workspace/Cargo.toml b/crates/pgt_workspace/Cargo.toml index 27cc1653d..6b0cc0650 100644 --- a/crates/pgt_workspace/Cargo.toml +++ b/crates/pgt_workspace/Cargo.toml @@ -38,7 +38,6 @@ serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true, features = ["raw_value"] } slotmap = { workspace = true, features = ["serde"] } sqlx.workspace = true -string-interner = { workspace = true, features = ["serde"] } strum = { workspace = true } tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } tracing = { workspace = true, features = ["attributes", "log"] } From 64e79d149ac8310f8b909e825455f5c067e84fd4 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Sat, 12 Jul 2025 14:06:04 +0200 Subject: [PATCH 7/9] progress --- crates/pgt_workspace/src/workspace/server.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index b6c846f40..e06d5adfa 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -637,9 +637,9 @@ impl Workspace for WorkspaceServer { }); tracing::debug!( - "Found {} completion items for statement with id {}", + "Found {} completion items for statement with id {:?}", items.len(), - id.content() + id ); Ok(CompletionsResult { items }) From b72c5d379842a573bf56d90b04f8ec8806e24d76 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Sat, 12 Jul 2025 16:22:04 +0200 Subject: [PATCH 8/9] progress --- crates/pgt_workspace/src/workspace/server.rs | 2 +- crates/pgt_workspace/src/workspace/server/document.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index e06d5adfa..a3bccb9f3 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -451,7 +451,7 @@ impl Workspace for WorkspaceServer { if let Some(ast) = ast { pgt_typecheck::check_sql(TypecheckParams { conn: &pool, - sql: &id.content(), + sql: id.content(), ast: &ast, tree: &cst, schema_cache: schema_cache.as_ref(), diff --git a/crates/pgt_workspace/src/workspace/server/document.rs b/crates/pgt_workspace/src/workspace/server/document.rs index 1574b2348..f8ab639d9 100644 --- a/crates/pgt_workspace/src/workspace/server/document.rs +++ b/crates/pgt_workspace/src/workspace/server/document.rs @@ -167,7 +167,7 @@ where } // Return the current statement if it passes the filter - if self.filter.predicate(&root_id, &range, content) { + if self.filter.predicate(&root_id, range, content) { return Some(self.mapper.map(self.parser, root_id, *range)); } From bcb879c86a0073faee5c27ee7c101c1565733109 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Sat, 12 Jul 2025 20:15:13 +0200 Subject: [PATCH 9/9] progress --- crates/pgt_lsp/src/capabilities.rs | 18 ++++++------------ .../pgt_workspace/src/features/code_actions.rs | 2 +- crates/pgt_workspace/src/workspace/server.rs | 8 +------- .../workspace/server/statement_identifier.rs | 15 +++++++-------- 4 files changed, 15 insertions(+), 28 deletions(-) diff --git a/crates/pgt_lsp/src/capabilities.rs b/crates/pgt_lsp/src/capabilities.rs index b50f2753f..3b473eb73 100644 --- a/crates/pgt_lsp/src/capabilities.rs +++ b/crates/pgt_lsp/src/capabilities.rs @@ -1,4 +1,7 @@ use crate::adapters::{PositionEncoding, WideEncoding, negotiated_encoding}; +use crate::handlers::code_actions::command_id; +use pgt_workspace::features::code_actions::CommandActionCategory; +use strum::IntoEnumIterator; use tower_lsp::lsp_types::{ ClientCapabilities, CompletionOptions, ExecuteCommandOptions, PositionEncodingKind, SaveOptions, ServerCapabilities, TextDocumentSyncCapability, TextDocumentSyncKind, @@ -47,8 +50,9 @@ pub(crate) fn server_capabilities(capabilities: &ClientCapabilities) -> ServerCa }, }), execute_command_provider: Some(ExecuteCommandOptions { - commands: available_command_ids(), - + commands: CommandActionCategory::iter() + .map(|c| command_id(&c)) + .collect::>(), ..Default::default() }), document_formatting_provider: None, @@ -61,13 +65,3 @@ pub(crate) fn server_capabilities(capabilities: &ClientCapabilities) -> ServerCa ..Default::default() } } - -/// Returns all available command IDs for capability registration. -/// Since CommandActionCategory has variants with data, we can't use strum::IntoEnumIterator. -/// Instead, we manually list the command IDs we want to register. -fn available_command_ids() -> Vec { - vec![ - "postgres-tools.executeStatement".to_string(), - // Add other command IDs here as needed - ] -} diff --git a/crates/pgt_workspace/src/features/code_actions.rs b/crates/pgt_workspace/src/features/code_actions.rs index 025153588..22223dd3c 100644 --- a/crates/pgt_workspace/src/features/code_actions.rs +++ b/crates/pgt_workspace/src/features/code_actions.rs @@ -44,7 +44,7 @@ pub struct CommandAction { pub category: CommandActionCategory, } -#[derive(Debug, serde::Serialize, serde::Deserialize)] +#[derive(Debug, serde::Serialize, serde::Deserialize, strum::EnumIter)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] pub enum CommandActionCategory { ExecuteStatement(StatementId), diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index a3bccb9f3..81aa99ab4 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -626,7 +626,7 @@ impl Workspace for WorkspaceServer { tracing::debug!("No statement found."); Ok(CompletionsResult::default()) } - Some((id, range, content, cst)) => { + Some((_id, range, content, cst)) => { let position = params.position - range.start(); let items = pgt_completions::complete(pgt_completions::CompletionParams { @@ -636,12 +636,6 @@ impl Workspace for WorkspaceServer { text: content, }); - tracing::debug!( - "Found {} completion items for statement with id {:?}", - items.len(), - id - ); - Ok(CompletionsResult { items }) } } diff --git a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs index 9328671bf..592596902 100644 --- a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs +++ b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs @@ -27,6 +27,13 @@ pub enum StatementId { }, } +// this is only here for strum to work on the code actions enum +impl Default for StatementId { + fn default() -> Self { + StatementId::Root { content: "".into() } + } +} + impl StatementId { pub fn new(statement: &str) -> Self { StatementId::Root { @@ -34,14 +41,6 @@ impl StatementId { } } - /// Creates a child statement ID with the given content and parent content. - pub fn new_child(child_content: &str, parent_content: &str) -> Self { - StatementId::Child { - content: child_content.into(), - parent_content: parent_content.into(), - } - } - /// Use this if you need to create a matching `StatementId::Child` for `Root`. /// You cannot create a `Child` of a `Child`. /// Note: This method requires the child content to be provided.