diff --git a/Cargo.lock b/Cargo.lock
index a187c97..4e75964 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4254,9 +4254,13 @@ name = "vectorize-proxy"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "clap",
  "log",
  "pgwire",
+ "rand 0.8.5",
  "regex",
+ "reqwest 0.12.28",
+ "serde",
  "serde_json",
  "sqlx",
  "thiserror 2.0.18",
diff --git a/core/src/query.rs b/core/src/query.rs
index 20afdb0..fdea0c0 100644
--- a/core/src/query.rs
+++ b/core/src/query.rs
@@ -665,13 +665,11 @@ pub fn join_table_cosine_similarity(
         .collect::<Vec<String>>()
         .join(",");
 
-    let mut bind_value_counter: i16 = 2; // Start at $2 since $1 is the vector
     let mut where_filter = "WHERE 1=1".to_string();
-    for (column, filter_value) in filters.iter() {
+    for (bind_value_counter, (column, filter_value)) in (2_i16..).zip(filters.iter()) {
         let operator = filter_value.operator.to_sql();
         let filt = format!(" AND t0.\"{column}\" {operator} ${bind_value_counter}");
         where_filter.push_str(&filt);
-        bind_value_counter += 1;
     }
 
     let inner_query = format!(
@@ -701,39 +699,34 @@ pub fn join_table_cosine_similarity(
     )
 }
 
+fn build_where_filter(filters: &BTreeMap<String, FilterValue>) -> String {
+    let mut where_filter = "WHERE 1=1".to_string();
+    for (bind_value_counter, (column, filter_value)) in (3_i16..).zip(filters.iter()) {
+        let operator = filter_value.operator.to_sql();
+        let filt = format!(" AND t0.\"{column}\" {operator} ${bind_value_counter}");
+        where_filter.push_str(&filt);
+    }
+    where_filter
+}
+
+/// Generates the core hybrid search SELECT that returns raw table rows.
+/// `$1::vector` and `$2` are sqlx bind parameter placeholders for the embedding and query text.
 #[allow(clippy::too_many_arguments)]
-pub fn hybrid_search_query(
+fn hybrid_search_rows_sql(
     job_name: &str,
     src_schema: &str,
     src_table: &str,
     join_key: &str,
-    return_columns: &[String],
+    cols: &str,
     window_size: i32,
     limit: i32,
     rrf_k: f32,
     semantic_weight: f32,
     fts_weight: f32,
-    filters: &BTreeMap<String, FilterValue>,
+    where_filter: &str,
 ) -> String {
-    let cols = &return_columns
-        .iter()
-        .map(|s| format!("t0.{s}"))
-        .collect::<Vec<String>>()
-        .join(",");
-
-    let mut bind_value_counter: i16 = 3;
-    let mut where_filter = "WHERE 1=1".to_string();
-    for (column, filter_value) in filters.iter() {
-        let operator = filter_value.operator.to_sql();
-        let filt = format!(" AND t0.\"{column}\" {operator} ${bind_value_counter}");
-        where_filter.push_str(&filt);
-        bind_value_counter += 1;
-    }
-
     format!(
         "
-    SELECT to_jsonb(t) as results
-    FROM (
     SELECT {cols}, t.rrf_score, t.semantic_rank, t.fts_rank, t.similarity_score
     FROM (
         SELECT
@@ -779,10 +772,89 @@ pub fn hybrid_search_query(
     INNER JOIN {src_schema}.{src_table} t0 ON t0.{join_key} = t.{join_key}
     {where_filter}
     ORDER BY t.rrf_score DESC
-    LIMIT {limit}
+    LIMIT {limit}"
+    )
+}
+
+/// Hybrid search returning each result row wrapped in a `results` JSONB column.
+/// Used by the HTTP server.
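+///
+/// Sketch of the generated shape (the inner rows SQL comes from
+/// `hybrid_search_rows_sql` above; `{n}` stands for the `limit` argument):
+///
+/// ```sql
+/// SELECT to_jsonb(t) as results
+/// FROM ( -- hybrid rows SQL: cols, rrf_score, ranks ... LIMIT {n}
+/// ) t
+/// ```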
+#[allow(clippy::too_many_arguments)]
+pub fn hybrid_search_query(
+    job_name: &str,
+    src_schema: &str,
+    src_table: &str,
+    join_key: &str,
+    return_columns: &[String],
+    window_size: i32,
+    limit: i32,
+    rrf_k: f32,
+    semantic_weight: f32,
+    fts_weight: f32,
+    filters: &BTreeMap<String, FilterValue>,
+) -> String {
+    let cols = return_columns
+        .iter()
+        .map(|s| format!("t0.{s}"))
+        .collect::<Vec<String>>()
+        .join(",");
+    let where_filter = build_where_filter(filters);
+    let inner = hybrid_search_rows_sql(
+        job_name,
+        src_schema,
+        src_table,
+        join_key,
+        &cols,
+        window_size,
+        limit,
+        rrf_k,
+        semantic_weight,
+        fts_weight,
+        &where_filter,
+    );
+    format!(
+        "
+    SELECT to_jsonb(t) as results
+    FROM ({inner}
+    ) t"
+    )
+}
+
+/// Hybrid search returning raw table columns (`t0.*` plus ranking scores).
+/// Used by the SQL proxy so results arrive as a normal table, not JSON.
+#[allow(clippy::too_many_arguments)]
+pub fn hybrid_search_query_rows(
+    job_name: &str,
+    src_schema: &str,
+    src_table: &str,
+    join_key: &str,
+    return_columns: &[String],
+    window_size: i32,
+    limit: i32,
+    rrf_k: f32,
+    semantic_weight: f32,
+    fts_weight: f32,
+    filters: &BTreeMap<String, FilterValue>,
+) -> String {
+    let cols = return_columns
+        .iter()
+        .map(|s| format!("t0.{s}"))
+        .collect::<Vec<String>>()
+        .join(",");
+    let where_filter = build_where_filter(filters);
+    hybrid_search_rows_sql(
+        job_name,
+        src_schema,
+        src_table,
+        join_key,
+        &cols,
+        window_size,
+        limit,
+        rrf_k,
+        semantic_weight,
+        fts_weight,
+        &where_filter,
+    )
+}
 
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/docker-compose.yml b/docker-compose.yml
index a7e0d34..bba2bd3 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -10,7 +10,7 @@ services:
     logging: *default-logging
     environment:
       POSTGRES_PASSWORD: postgres
-    image: pgvector/pgvector:0.8.1-pg18
+    image: pgvector/pgvector:0.8.2-pg18
     ports:
       - 5432:5432
     healthcheck:
diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml
index 8af8103..91fff55 100644
--- a/proxy/Cargo.toml
+++ b/proxy/Cargo.toml
@@ -3,7 +3,16 @@ name = "vectorize-proxy"
 version = "0.1.0"
 edition = "2024"
 
+[[bin]]
+name = "vectorize-proxy"
+path = "src/main.rs"
+
+[lib]
+name = "vectorize_proxy"
+path = "src/lib.rs"
+
 [dependencies]
+clap = { version = "4.0", features = ["derive", "env"] }
 vectorize_core = { package = "vectorize-core", path = "../core" }
 
 anyhow = { workspace = true }
@@ -17,4 +26,9 @@ tracing = { workspace = true }
 tracing-subscriber = { workspace = true }
 url = { workspace = true }
 
-pgwire = { version = "0.30", features = ["server-api-aws-lc-rs"] }
\ No newline at end of file
+pgwire = { version = "0.30", features = ["server-api-aws-lc-rs"] }
+
+[dev-dependencies]
+rand = "0.8"
+reqwest = { version = "0.12", features = ["json"] }
+serde = { version = "1", features = ["derive"] }
\ No newline at end of file
diff --git a/proxy/README.md b/proxy/README.md
new file mode 100644
index 0000000..e1a5fbc
--- /dev/null
+++ b/proxy/README.md
@@ -0,0 +1,61 @@
+## SQL proxy
+
+The proxy gives you a SQL interface to `vectorize.search()` without installing the Postgres extension. It sits in front of Postgres, intercepts `vectorize.search()` calls, generates embeddings, rewrites the query as a hybrid (semantic + full-text) search, and returns results — all transparently over the Postgres wire protocol. Any SQL client that works with Postgres works with the proxy.
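+
+For example, the proxy turns the query below into the underlying hybrid search statement before Postgres sees it (illustrative sketch; the embedding is inlined as a pgvector literal):
+
+```sql
+-- What the client sends:
+SELECT * FROM vectorize.search(job=>'my_job', query=>'camping backpack', num_results=>3);
+
+-- What Postgres executes, roughly: the call is replaced by a hybrid-search
+-- subquery containing something like '[0.12, -0.03, ...]'::vector.
+```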
+
+Start Postgres and the embeddings server:
+
+```bash
+docker compose up postgres vector-serve -d
+```
+
+Load the example dataset:
+
+```bash
+psql postgres://postgres:postgres@localhost:5432/postgres -f server/sql/example.sql
+```
+
+In a second terminal, start the HTTP server. This is used to manage embedding jobs and generate the initial embeddings for existing rows:
+
+```bash
+DATABASE_URL=postgres://postgres:postgres@localhost:5432/postgres \
+  EMBEDDING_SVC_URL=http://localhost:3000/v1 \
+  cargo run --bin vectorize-server
+```
+
+Initialize the table and create the embedding job:
+
+```bash
+curl -X POST http://localhost:8080/api/v1/table -d '{
+  "job_name": "my_job",
+  "src_table": "my_products",
+  "src_schema": "public",
+  "src_columns": ["product_name", "description"],
+  "primary_key": "product_id",
+  "update_time_col": "updated_at",
+  "model": "sentence-transformers/all-MiniLM-L6-v2"
+  }' -H "Content-Type: application/json"
+```
+
+In a third terminal, start the proxy. It listens on port 5433 by default:
+
+```bash
+DATABASE_URL=postgres://postgres:postgres@localhost:5432/postgres \
+  EMBEDDING_SVC_URL=http://localhost:3000/v1 \
+  cargo run --bin vectorize-proxy
+```
+
+Search using SQL by connecting `psql` to the proxy port (5433):
+
+```bash
+psql postgres://postgres:postgres@localhost:5433/postgres -c \
+  "SELECT * FROM vectorize.search(job=>'my_job', query=>'camping backpack', num_results=>3);"
+```
+
+```text
+ results
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+ {"price": 45.00, "fts_rank": 1, "rrf_score": 0.03278688524590164, "product_id": 6, "updated_at": "2026-05-12T14:37:26.610753+00:00", "description": "Storage solution for carrying personal items on ones back", "product_name": "Backpack", "semantic_rank": 1, "product_category": "accessories", "similarity_score": 0.6296013593673885}
+ {"price": 40.00, "fts_rank": null, "rrf_score": 0.016129032258064516, "product_id": 39, "updated_at": "2026-05-12T14:37:26.610753+00:00", "description": "Sling made of fabric or netting, suspended between two points for relaxation", "product_name": "Hammock", "semantic_rank": 2, "product_category": "outdoor", "similarity_score": 0.3789524291697087}
+ {"price": 10.99, "fts_rank": null, "rrf_score": 0.015873015873015872, "product_id": 12, "updated_at": "2026-05-12T14:37:26.610753+00:00", "description": "Insulated container for beverages on-the-go", "product_name": "Travel Mug", "semantic_rank": 3, "product_category": "kitchenware", "similarity_score": 0.35918538314991255}
+(3 rows)
+```
\ No newline at end of file
diff --git a/proxy/src/embeddings.rs b/proxy/src/embeddings.rs
index 0261577..b3a0715 100644
--- a/proxy/src/embeddings.rs
+++ b/proxy/src/embeddings.rs
@@ -1,9 +1,10 @@
 use anyhow::Result;
 use regex::Regex;
-use std::collections::HashMap;
+use std::collections::{BTreeMap, HashMap};
 use std::sync::Arc;
 
 use vectorize_core::errors::VectorizeError;
+use vectorize_core::query::hybrid_search_query_rows;
 use vectorize_core::transformers::providers::{self, prepare_generic_embedding_request};
 use vectorize_core::transformers::types::Inputs;
 use vectorize_core::types::VectorizeJob;
@@ -63,6 +64,125 @@ impl JobMapEmbeddingProvider {
     }
 }
 
+/// Represents a parsed vectorize.search() named-argument function call.
+#[derive(Debug, Clone)]
+pub struct SearchCall {
+    pub job_name: String,
+    pub query: String,
+    pub num_results: i32,
+    pub full_match: String,
+    pub start_pos: usize,
+    pub end_pos: usize,
+}
+
+/// Parses `vectorize.search(job=>'...', query=>'...')` calls from SQL.
+/// Only named-argument syntax is supported.
+pub fn parse_search_calls(sql: &str) -> Result<Vec<SearchCall>> {
+    let mut calls = Vec::new();
+
+    let call_re = Regex::new(r"(?i)vectorize\.search\s*\(([^)]*)\)")?;
+    let job_re = Regex::new(r"(?i)job\s*=>\s*'((?:[^']|'')*)'")?;
+    let query_re = Regex::new(r"(?i)query\s*=>\s*'((?:[^']|'')*)'")?;
+    let num_results_re = Regex::new(r"(?i)(?:num_results|limit)\s*=>\s*(\d+)")?;
+
+    for mat in call_re.find_iter(sql) {
+        let full_match = mat.as_str().to_string();
+        let args_str = call_re
+            .captures(mat.as_str())
+            .and_then(|c| c.get(1))
+            .map(|m| m.as_str())
+            .unwrap_or("");
+
+        let job_name = job_re
+            .captures(args_str)
+            .and_then(|c| c.get(1))
+            .map(|m| m.as_str().replace("''", "'"))
+            .ok_or_else(|| anyhow::anyhow!("Missing 'job' parameter in vectorize.search()"))?;
+
+        let query = query_re
+            .captures(args_str)
+            .and_then(|c| c.get(1))
+            .map(|m| m.as_str().replace("''", "'"))
+            .ok_or_else(|| anyhow::anyhow!("Missing 'query' parameter in vectorize.search()"))?;
+
+        let num_results = num_results_re
+            .captures(args_str)
+            .and_then(|c| c.get(1))
+            .and_then(|m| m.as_str().parse().ok())
+            .unwrap_or(10i32);
+
+        calls.push(SearchCall {
+            job_name,
+            query,
+            num_results,
+            full_match,
+            start_pos: mat.start(),
+            end_pos: mat.end(),
+        });
+    }
+
+    Ok(calls)
+}
+
+/// Detects `vectorize.search()` calls in SQL and rewrites the entire query to the
+/// underlying hybrid search SQL with the embedding vector inlined.
+/// Returns `Ok(None)` if no search calls are found.
+pub async fn rewrite_search_query(
+    sql: &str,
+    provider: &JobMapEmbeddingProvider,
+) -> Result<Option<String>, VectorizeError> {
+    let search_calls = parse_search_calls(sql).map_err(|e| {
+        VectorizeError::EmbeddingGenerationFailed(format!("Failed to parse search calls: {e}"))
+    })?;
+
+    if search_calls.is_empty() {
+        return Ok(None);
+    }
+
+    // Handle the first call (the common case; multiple search calls in one query are unusual)
+    let call = &search_calls[0];
+
+    let vectorize_job = provider.jobmap.get(&call.job_name).ok_or_else(|| {
+        VectorizeError::JobNotFound(format!("Job '{}' not found in proxy cache", call.job_name))
+    })?;
+
+    let embeddings = provider
+        .generate_embeddings(&call.query, &call.job_name)
+        .await?;
+    let embedding_literal = format_embeddings_as_vector(&embeddings);
+
+    let window_size = 5 * call.num_results;
+    let template_sql = hybrid_search_query_rows(
+        &call.job_name, // vectorize_job.job_name was cleared by mem::take in cache load
+        &vectorize_job.src_schema,
+        &vectorize_job.src_table,
+        &vectorize_job.primary_key,
+        &["*".to_string()],
+        window_size,
+        call.num_results,
+        60.0,
+        1.0,
+        1.0,
+        &BTreeMap::new(),
+    );
+
+    // Inline the sqlx bind parameter placeholders with their actual values.
+    // $1::vector is the embedding; $2 is the raw text for the FTS plainto_tsquery.
+    // With no filters, these are the only two bind params in the generated SQL.
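+    // Illustration (made-up values): `$1::vector` becomes the inlined embedding,
+    // e.g. `'[0.12,-0.03, ...]'::vector`, and `$2` becomes `'camping backpack'`.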
+    let escaped_query = call.query.replace('\'', "''");
+    let query_literal = format!("'{escaped_query}'");
+    let inlined_sql = template_sql
+        .replace("$1::vector", &embedding_literal)
+        .replace("$2", &query_literal);
+
+    // Splice the subquery in place of `vectorize.search(...)`, keeping any outer
+    // SELECT column list, WHERE, ORDER BY, or LIMIT the caller wrote.
+    let subquery = format!("({inlined_sql}\n    ) AS _vectorize_search");
+    let mut rewritten = sql.to_string();
+    rewritten.replace_range(call.start_pos..call.end_pos, &subquery);
+    Ok(Some(rewritten))
+}
+
 pub fn parse_embed_calls(sql: &str) -> Result<Vec<EmbedCall>> {
     let mut calls = Vec::new();
 
@@ -241,4 +361,56 @@ mod tests {
         let calls = parse_embed_calls(sql).unwrap();
         assert!(calls.is_empty());
     }
+
+    #[test]
+    fn test_parse_search_calls_basic() {
+        let sql = "SELECT * FROM vectorize.search(job=>'my_job', query=>'camping backpack')";
+        let calls = parse_search_calls(sql).unwrap();
+        assert_eq!(calls.len(), 1);
+        assert_eq!(calls[0].job_name, "my_job");
+        assert_eq!(calls[0].query, "camping backpack");
+        assert_eq!(calls[0].num_results, 10);
+    }
+
+    #[test]
+    fn test_parse_search_calls_with_num_results() {
+        let sql = "SELECT * FROM vectorize.search(job=>'my_job', query=>'camping backpack', num_results=>5)";
+        let calls = parse_search_calls(sql).unwrap();
+        assert_eq!(calls.len(), 1);
+        assert_eq!(calls[0].num_results, 5);
+    }
+
+    #[test]
+    fn test_parse_search_calls_with_limit_alias() {
+        let sql =
+            "SELECT * FROM vectorize.search(job=>'my_job', query=>'camping backpack', limit=>3)";
+        let calls = parse_search_calls(sql).unwrap();
+        assert_eq!(calls.len(), 1);
+        assert_eq!(calls[0].num_results, 3);
+    }
+
+    #[test]
+    fn test_parse_search_calls_query_first() {
+        let sql = "SELECT * FROM vectorize.search(query=>'camping backpack', job=>'my_job')";
+        let calls = parse_search_calls(sql).unwrap();
+        assert_eq!(calls.len(), 1);
+        assert_eq!(calls[0].job_name, "my_job");
+        assert_eq!(calls[0].query, "camping backpack");
+    }
+
+    #[test]
+    fn test_parse_search_calls_none() {
+        let sql = "SELECT * FROM products WHERE id = 1";
+        let calls = parse_search_calls(sql).unwrap();
+        assert!(calls.is_empty());
+    }
+
+    #[test]
+    fn test_parse_search_calls_escaped_quotes() {
+        let sql = "SELECT * FROM vectorize.search(job=>'it''s a job', query=>'o''malley''s bar')";
+        let calls = parse_search_calls(sql).unwrap();
+        assert_eq!(calls.len(), 1);
+        assert_eq!(calls[0].job_name, "it's a job");
+        assert_eq!(calls[0].query, "o'malley's bar");
+    }
 }
diff --git a/proxy/src/main.rs b/proxy/src/main.rs
new file mode 100644
index 0000000..237b79a
--- /dev/null
+++ b/proxy/src/main.rs
@@ -0,0 +1,91 @@
+use clap::Parser;
+use std::collections::HashMap;
+use std::net::SocketAddr;
+use std::net::ToSocketAddrs;
+use std::sync::Arc;
+use std::time::Duration;
+use tokio::sync::RwLock;
+use tracing::info;
+use url::Url;
+
+use vectorize_proxy::cache::{
+    load_initial_job_cache, setup_job_change_notifications, start_cache_sync_listener,
+};
+use vectorize_proxy::protocol::ProxyConfig;
+use vectorize_proxy::proxy::run_proxy_loop;
+
+#[derive(Parser)]
+#[command(
+    name = "vectorize-proxy",
+    about = "PostgreSQL wire protocol proxy that intercepts vectorize.search() and vectorize.embed() calls"
+)]
+struct Args {
+    #[arg(
+        long,
+        env = "DATABASE_URL",
+        default_value = "postgres://postgres:postgres@localhost:5432/postgres"
+    )]
+    database_url: String,
+
+    #[arg(long, env = "VECTORIZE_PROXY_PORT", default_value_t = 5433)]
+    proxy_port: u16,
+
+    #[arg(long, env = "VECTORIZE_PROXY_TIMEOUT", default_value_t = 30)]
"VECTORIZE_PROXY_TIMEOUT", default_value_t = 30)] + timeout_secs: u64, + + #[arg(long, env = "DATABASE_POOL_MAX", default_value_t = 8)] + db_pool_max: u32, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_target(false).init(); + + let args = Args::parse(); + + let pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(args.db_pool_max) + .connect(&args.database_url) + .await?; + + setup_job_change_notifications(&pool) + .await + .map_err(|e| anyhow::anyhow!("{e}")) + .map_err(|e| anyhow::anyhow!("Failed to set up job change notifications: {e}"))?; + + let initial_cache = load_initial_job_cache(&pool).await?; + info!("Loaded {} jobs into proxy cache", initial_cache.len()); + + let url = Url::parse(&args.database_url)?; + let postgres_host = url.host_str().unwrap().to_string(); + let postgres_port = url.port().unwrap_or(5432); + let postgres_addr: SocketAddr = format!("{postgres_host}:{postgres_port}") + .to_socket_addrs()? + .next() + .ok_or_else(|| anyhow::anyhow!("Failed to resolve PostgreSQL host address"))?; + + let config = Arc::new(ProxyConfig { + postgres_addr, + timeout: Duration::from_secs(args.timeout_secs), + jobmap: Arc::new(RwLock::new(initial_cache)), + db_pool: pool, + prepared_statements: Arc::new(RwLock::new(HashMap::new())), + }); + + let listen_addr: SocketAddr = format!("0.0.0.0:{}", args.proxy_port).parse()?; + + info!("vectorize-proxy listening on {listen_addr}"); + info!("Forwarding to PostgreSQL at {postgres_addr}"); + + // Keep the job cache in sync with database changes via pg_notify. + let config_for_listener = Arc::clone(&config); + tokio::spawn(async move { + if let Err(e) = start_cache_sync_listener(config_for_listener).await { + tracing::error!("Cache sync listener failed: {e}"); + } + }); + + run_proxy_loop(config, listen_addr) + .await + .map_err(|e| anyhow::anyhow!("{e}")) +} diff --git a/proxy/src/message_parser.rs b/proxy/src/message_parser.rs index 71a6064..f154235 100644 --- a/proxy/src/message_parser.rs +++ b/proxy/src/message_parser.rs @@ -1,6 +1,6 @@ use crate::embeddings::{ - JobMapEmbeddingProvider, parse_embed_calls, resolve_prepared_embed_calls, - rewrite_query_with_embeddings, + JobMapEmbeddingProvider, parse_embed_calls, parse_search_calls, resolve_prepared_embed_calls, + rewrite_query_with_embeddings, rewrite_search_query, }; use log::info; use std::sync::Arc; @@ -133,6 +133,33 @@ pub async fn process_simple_query_message( if let Some(null_pos) = query_bytes.iter().position(|&b| b == 0) { let sql = String::from_utf8_lossy(&query_bytes[..null_pos]).to_string(); + // Check for vectorize.search() calls first — these fully replace the query. + if let Ok(search_calls) = parse_search_calls(&sql) + && !search_calls.is_empty() + { + let jobmap_read = config.jobmap.read().await; + let embedding_provider = JobMapEmbeddingProvider::new(Arc::new(jobmap_read.clone())); + drop(jobmap_read); + + match rewrite_search_query(&sql, &embedding_provider).await { + Ok(Some(rewritten_sql)) => { + let rewritten_message = create_query_message(&rewritten_sql); + let parsed = ParsedMessage { + message_type: QUERY_MESSAGE, + sql: Some(rewritten_sql), + has_embed_calls: true, + rewritten: true, + }; + return Some((rewritten_message, parsed)); + } + Ok(None) => {} + Err(e) => { + log::warn!("Failed to rewrite vectorize.search() query: {e}"); + } + } + } + + // Check for vectorize.embed() calls — these replace only the function call inline. 
         if let Ok(embed_calls) = parse_embed_calls(&sql)
             && !embed_calls.is_empty()
         {
@@ -190,6 +217,37 @@ pub async fn process_parse_message(
     if offset > query_start {
         let sql = String::from_utf8_lossy(&data[query_start..offset]).to_string();
 
+        if let Ok(search_calls) = parse_search_calls(&sql)
+            && !search_calls.is_empty()
+        {
+            let jobmap_read = config.jobmap.read().await;
+            let embedding_provider =
+                JobMapEmbeddingProvider::new(Arc::new(jobmap_read.clone()));
+            drop(jobmap_read);
+
+            match rewrite_search_query(&sql, &embedding_provider).await {
+                Ok(Some(rewritten_sql)) => {
+                    let rewritten_message = create_parse_message_with_rewritten_query(
+                        data,
+                        query_start,
+                        offset,
+                        &rewritten_sql,
+                    );
+                    let parsed = ParsedMessage {
+                        message_type: PARSE_MESSAGE,
+                        sql: Some(rewritten_sql),
+                        has_embed_calls: true,
+                        rewritten: true,
+                    };
+                    return Some((rewritten_message, parsed));
+                }
+                Ok(None) => {}
+                Err(e) => {
+                    log::warn!("Failed to rewrite vectorize.search() in Parse: {e}");
+                }
+            }
+        }
+
         if let Ok(embed_calls) = parse_embed_calls(&sql)
             && !embed_calls.is_empty()
         {
diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs
index 9e6e405..8cc1c2f 100644
--- a/proxy/src/proxy.rs
+++ b/proxy/src/proxy.rs
@@ -137,6 +137,34 @@ where
     Ok(())
 }
 
+/// Runs the TCP accept loop, dispatching each client connection to a handler task.
+/// Takes a pre-built config so callers (e.g. the standalone binary) can share it
+/// with other subsystems like the cache sync listener.
+pub async fn run_proxy_loop(
+    config: Arc<ProxyConfig>,
+    listen_addr: SocketAddr,
+) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+    let listener = TcpListener::bind(listen_addr).await?;
+
+    loop {
+        match listener.accept().await {
+            Ok((client_stream, client_addr)) => {
+                info!("New proxy connection from: {client_addr}");
+
+                let config = Arc::clone(&config);
+                tokio::spawn(async move {
+                    if let Err(e) = handle_connection_with_timeout(client_stream, config).await {
+                        error!("Proxy connection error from {client_addr}: {e}");
+                    }
+                });
+            }
+            Err(e) => {
+                error!("Failed to accept proxy connection: {e}");
+            }
+        }
+    }
+}
+
 pub async fn start_postgres_proxy(
     proxy_port: u16,
     database_url: String,
@@ -168,23 +196,5 @@ pub async fn start_postgres_proxy(
     info!("Proxy listening on: {listen_addr}");
     info!("Forwarding to PostgreSQL at: {postgres_addr}");
 
-    let listener = TcpListener::bind(listen_addr).await?;
-
-    loop {
-        match listener.accept().await {
-            Ok((client_stream, client_addr)) => {
-                info!("New proxy connection from: {client_addr}");
-
-                let config = Arc::clone(&config);
-                tokio::spawn(async move {
-                    if let Err(e) = handle_connection_with_timeout(client_stream, config).await {
-                        error!("Proxy connection error from {client_addr}: {e}");
-                    }
-                });
-            }
-            Err(e) => {
-                error!("Failed to accept proxy connection: {e}");
-            }
-        }
-    }
+    run_proxy_loop(config, listen_addr).await
 }
diff --git a/proxy/tests/proxy.rs b/proxy/tests/proxy.rs
new file mode 100644
index 0000000..b39df4b
--- /dev/null
+++ b/proxy/tests/proxy.rs
@@ -0,0 +1,264 @@
+//! Integration tests for the vectorize-proxy.
+//!
+//! These tests assume all services are already running on their default ports:
+//! - Postgres: localhost:5432
+//! - vector-serve: localhost:3000
+//! - vectorize-server: localhost:8080
+//! - vectorize-proxy: localhost:5433
+//!
+//! Load the example dataset and create the job before running:
+//!
+//!     psql postgres://postgres:postgres@localhost:5432/postgres \
+//!         -f server/sql/example.sql
+//!
+//!     curl -s -X POST http://localhost:8080/api/v1/table \
+//!         -H "Content-Type: application/json" \
-H "Content-Type: application/json" \ +//! -d '{"job_name":"my_job","src_table":"my_products","src_schema":"public", +//! "src_columns":["product_name","description"],"primary_key":"product_id", +//! "update_time_col":"updated_at","model":"sentence-transformers/all-MiniLM-L6-v2"}' +//! +//! Run with: cargo test --test proxy + +use sqlx::{Column, Row}; + +const PROXY_URL: &str = "postgresql://postgres:postgres@localhost:5433/postgres"; + +async fn connect() -> sqlx::PgPool { + sqlx::PgPool::connect(PROXY_URL) + .await + .expect("Failed to connect to vectorize-proxy at localhost:5433 — is the proxy running?") +} + +/// Non-vectorize queries should pass through unchanged. +#[tokio::test] +async fn test_passthrough() { + let pool = connect().await; + + let row = sqlx::query("SELECT 1 + 1 AS result") + .fetch_one(&pool) + .await + .expect("simple passthrough query failed"); + + let result: i32 = row.get("result"); + assert_eq!(result, 2); +} + +/// `SELECT *` returns real table columns, not JSON. +#[tokio::test] +async fn test_search_returns_table_rows() { + let pool = connect().await; + + let rows = sqlx::query( + "SELECT * FROM vectorize.search(job=>'my_job', query=>'camping backpack', num_results=>3)", + ) + .fetch_all(&pool) + .await + .expect("SELECT * search query failed"); + + assert!(!rows.is_empty(), "expected search results, got none"); + assert!( + rows.len() <= 3, + "expected at most 3 rows, got {}", + rows.len() + ); + + let col_names: Vec<&str> = rows[0].columns().iter().map(|c| c.name()).collect(); + assert!( + col_names.contains(&"product_id"), + "expected product_id column in results, got: {col_names:?}" + ); + assert!( + col_names.contains(&"product_name"), + "expected product_name column in results, got: {col_names:?}" + ); + assert!( + col_names.contains(&"rrf_score"), + "expected rrf_score column in results, got: {col_names:?}" + ); +} + +/// `SELECT product_name FROM vectorize.search(...)` returns only the requested column. +#[tokio::test] +async fn test_search_column_projection() { + let pool = connect().await; + + let rows = sqlx::query( + "SELECT product_name FROM vectorize.search(job=>'my_job', query=>'camping backpack', num_results=>3)", + ) + .fetch_all(&pool) + .await + .expect("column-projected search query failed"); + + assert!(!rows.is_empty(), "expected at least one result"); + + let col_names: Vec<&str> = rows[0].columns().iter().map(|c| c.name()).collect(); + assert!( + col_names.contains(&"product_name"), + "expected product_name column" + ); + assert!( + !col_names.contains(&"product_id"), + "product_id should not appear when only product_name is selected, got: {col_names:?}" + ); + assert!( + !col_names.contains(&"rrf_score"), + "rrf_score should not appear when only product_name is selected" + ); +} + +/// `num_results` limits the number of returned rows. 
+#[tokio::test]
+async fn test_search_num_results_limit() {
+    let pool = connect().await;
+
+    let rows_1 = sqlx::query(
+        "SELECT * FROM vectorize.search(job=>'my_job', query=>'backpack', num_results=>1)",
+    )
+    .fetch_all(&pool)
+    .await
+    .expect("search with num_results=>1 failed");
+
+    assert_eq!(rows_1.len(), 1, "expected exactly 1 result");
+
+    let rows_5 = sqlx::query(
+        "SELECT * FROM vectorize.search(job=>'my_job', query=>'backpack', num_results=>5)",
+    )
+    .fetch_all(&pool)
+    .await
+    .expect("search with num_results=>5 failed");
+
+    assert!(
+        rows_5.len() <= 5,
+        "expected at most 5 results, got {}",
+        rows_5.len()
+    );
+    assert!(
+        rows_5.len() > 1,
+        "expected more than 1 result with num_results=>5"
+    );
+}
+
+/// The `limit` alias for `num_results` works the same way.
+#[tokio::test]
+async fn test_search_limit_alias() {
+    let pool = connect().await;
+
+    let rows =
+        sqlx::query("SELECT * FROM vectorize.search(job=>'my_job', query=>'backpack', limit=>2)")
+            .fetch_all(&pool)
+            .await
+            .expect("search with limit=>2 failed");
+
+    assert!(
+        rows.len() <= 2,
+        "expected at most 2 results, got {}",
+        rows.len()
+    );
+}
+
+/// Semantic relevance: "writing utensil" should rank "Pencil" first.
+#[tokio::test]
+async fn test_search_relevance_ordering() {
+    let pool = connect().await;
+
+    let rows = sqlx::query(
+        "SELECT product_name FROM vectorize.search(job=>'my_job', query=>'writing utensil', num_results=>5)",
+    )
+    .fetch_all(&pool)
+    .await
+    .expect("relevance ordering search failed");
+
+    assert!(!rows.is_empty(), "expected at least one result");
+
+    let top_name: String = rows[0].get("product_name");
+    assert_eq!(
+        top_name.to_lowercase(),
+        "pencil",
+        "expected 'Pencil' as the top result for 'writing utensil', got '{top_name}'"
+    );
+}
+
+/// Named arguments may appear in any order.
+#[tokio::test]
+async fn test_search_argument_order_independence() {
+    let pool = connect().await;
+
+    let rows_a = sqlx::query(
+        "SELECT product_name FROM vectorize.search(job=>'my_job', query=>'backpack', num_results=>3)",
+    )
+    .fetch_all(&pool)
+    .await
+    .expect("job-first ordering failed");
+
+    let rows_b = sqlx::query(
+        "SELECT product_name FROM vectorize.search(query=>'backpack', job=>'my_job', num_results=>3)",
+    )
+    .fetch_all(&pool)
+    .await
+    .expect("query-first ordering failed");
+
+    assert_eq!(
+        rows_a.len(),
+        rows_b.len(),
+        "result count should be the same regardless of argument order"
+    );
+
+    let names_a: Vec<String> = rows_a.iter().map(|r| r.get("product_name")).collect();
+    let names_b: Vec<String> = rows_b.iter().map(|r| r.get("product_name")).collect();
+    assert_eq!(
+        names_a, names_b,
+        "results should be identical regardless of named-argument order"
+    );
+}
+
+/// An outer WHERE clause filters the search subquery results.
+#[tokio::test]
+async fn test_search_outer_where_clause() {
+    let pool = connect().await;
+
+    // Search for "speaker" but filter to only electronics
+    let rows = sqlx::query(
+        "SELECT product_name, product_category \
+         FROM vectorize.search(job=>'my_job', query=>'audio speaker', num_results=>5) \
+         WHERE product_category = 'electronics'",
+    )
+    .fetch_all(&pool)
+    .await
+    .expect("search with outer WHERE failed");
+
+    for row in &rows {
+        let category: String = row.get("product_category");
+        assert_eq!(
+            category, "electronics",
+            "expected only electronics, got '{category}'"
+        );
+    }
+}
+
+/// An outer ORDER BY can override the default relevance ordering.
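+/// (The rewritten search is spliced in as a subquery, so outer WHERE, ORDER BY,
+/// and LIMIT clauses evaluate on top of the already RRF-ordered rows.)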
+#[tokio::test]
+async fn test_search_outer_order_by_price() {
+    let pool = connect().await;
+
+    let rows = sqlx::query(
+        "SELECT product_name, price::float8 AS price \
+         FROM vectorize.search(job=>'my_job', query=>'electronics gadget', num_results=>5) \
+         ORDER BY price ASC",
+    )
+    .fetch_all(&pool)
+    .await
+    .expect("search with outer ORDER BY failed");
+
+    assert!(rows.len() >= 2, "need at least 2 rows to verify ordering");
+
+    let prices: Vec<f64> = rows.iter().map(|r| r.get::<f64, _>("price")).collect();
+
+    for window in prices.windows(2) {
+        assert!(
+            window[0] <= window[1],
+            "prices should be ascending: {} > {}",
+            window[0],
+            window[1]
+        );
+    }
+}