detect columns used in generated columns to make subscriptions work seamlessly
jeromegn committed Aug 29, 2024
1 parent 82b0462 commit 93ed989
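
The gist of the change: a subscription query that selects a generated column never matched incoming changes, because changes are keyed by the stored source columns rather than the generated one. The schema parser now records, for each generated column, the columns its expression reads, and subscription matching expands generated columns into those source columns. A minimal sketch of the idea with made-up table and column names (the real plumbing is in the diff below):

    use std::collections::{HashMap, HashSet};

    fn main() {
        // Hypothetical schema: full_name is GENERATED ALWAYS AS (first || ' ' || last),
        // so it is computed from `first` and `last`.
        let mut generated_from: HashMap<&str, Vec<&str>> = HashMap::new();
        generated_from.insert("full_name", vec!["first", "last"]);

        // A subscription like `SELECT full_name FROM users` references only the
        // generated column...
        let query_cols = ["full_name"];

        // ...so expand it into its source columns, as `insert_col` does in the
        // diff, and watch those instead.
        let mut watched: HashSet<&str> = HashSet::new();
        for col in query_cols {
            match generated_from.get(col) {
                Some(sources) => watched.extend(sources.iter().copied()),
                None => {
                    watched.insert(col);
                }
            }
        }

        // An incoming change to `first` now matches the subscription.
        assert!(watched.contains("first"));
    }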
Showing 2 changed files with 90 additions and 27 deletions.
81 changes: 57 additions & 24 deletions crates/corro-types/src/pubsub.rs
@@ -298,6 +298,7 @@ impl SubsManager {
     }
 }
 
+#[derive(Debug)]
 struct MatchableChange<'a> {
     table: &'a TableName,
     pk: &'a [u8],
@@ -527,12 +528,14 @@ impl MatcherHandle {
         candidates: &mut MatchCandidates,
         change: MatchableChange,
     ) -> bool {
+        trace!("filtering change {change:?}");
         // don't double process the same pk
         if candidates
             .get(change.table)
             .map(|pks| pks.contains(change.pk))
             .unwrap_or_default()
         {
+            trace!("already contained key");
             return false;
         }
 
@@ -545,6 +548,7 @@ impl MatcherHandle {
             .map(|cols| change.column.is_crsql_sentinel() || cols.contains(change.column.as_str()))
             .unwrap_or_default()
         {
+            trace!("could not match against parsed query table and columns");
             return false;
         }
 
@@ -767,7 +771,9 @@ impl Matcher {
             pks.get(tbl_name)
                 .cloned()
                 .ok_or(MatcherError::MissingPrimaryKeys)?
-                .to_vec()
+                .into_iter()
+                .map(|pk| format!("coalesce({pk}, \"\")"))
+                .collect::<Vec<_>>()
                 .join(","),
         );
 
@@ -793,6 +799,8 @@ impl Matcher {
 
         let sql_hash = hex::encode(seahash::hash(sql.as_bytes()).to_be_bytes());
 
+        trace!("PARSED: {parsed:?}");
+
         let handle = MatcherHandle {
             inner: Arc::new(InnerMatcherHandle {
                 id,
@@ -1128,7 +1136,7 @@ impl Matcher {
             PurgeOldChanges,
         }
 
-        trace!("looping...");
+        // trace!("looping...");
 
         let branch = tokio::select! {
             biased;
@@ -1491,7 +1499,10 @@ impl Matcher {
                 for pks in pks {
                     tx.prepare_cached(&format!(
                         "INSERT INTO {tmp_table_name} VALUES ({})",
-                        (0..pks.len()).map(|_i| "?").collect::<Vec<_>>().join(",")
+                        (0..pks.len())
+                            .map(|_i| "coalesce(?, \"\")")
+                            .collect::<Vec<_>>()
+                            .join(",")
                     ))?
                     .execute(params_from_iter(pks))?;
                 }
@@ -1892,7 +1903,7 @@ fn extract_select_columns(
                 let entry =
                     parsed.table_columns.entry(tbl_name.0.clone()).or_default();
                 for name in names.iter() {
-                    entry.insert(name.0.clone());
+                    insert_col(entry, schema, &tbl_name.0, &name.0);
                 }
             }
         }
@@ -1913,6 +1924,18 @@ fn extract_select_columns(
     Ok(parsed)
 }
 
+fn insert_col(set: &mut HashSet<String>, schema: &Schema, tbl_name: &str, name: &str) {
+    if let Some(generated) = schema
+        .tables
+        .get(tbl_name)
+        .and_then(|tbl| tbl.columns.get(name).and_then(|col| col.generated.as_ref()))
+    {
+        set.extend(generated.from.clone());
+    } else {
+        set.insert(name.to_owned());
+    }
+}
+
 fn extract_expr_columns(
     expr: &Expr,
     schema: &Schema,
@@ -1923,21 +1946,29 @@ fn extract_expr_columns(
         Expr::Qualified(tblname, colname) => {
             let resolved_name = parsed.aliases.get(&tblname.0).unwrap_or(&tblname.0);
             // println!("adding column: {resolved_name} => {colname:?}");
-            parsed
-                .table_columns
-                .entry(resolved_name.clone())
-                .or_default()
-                .insert(colname.0.clone());
+            insert_col(
+                parsed
+                    .table_columns
+                    .entry(resolved_name.clone())
+                    .or_default(),
+                schema,
+                resolved_name,
+                &colname.0,
+            );
         }
         // simplest case but also mentioning the schema
         Expr::DoublyQualified(schema_name, tblname, colname) if schema_name.0 == "main" => {
             let resolved_name = parsed.aliases.get(&tblname.0).unwrap_or(&tblname.0);
             // println!("adding column: {resolved_name} => {colname:?}");
-            parsed
-                .table_columns
-                .entry(resolved_name.clone())
-                .or_default()
-                .insert(colname.0.clone());
+            insert_col(
+                parsed
+                    .table_columns
+                    .entry(resolved_name.clone())
+                    .or_default(),
+                schema,
+                resolved_name,
+                &colname.0,
+            );
         }
 
         Expr::Name(colname) => {
@@ -1958,11 +1989,12 @@ fn extract_expr_columns(
             }
 
             if let Some(found) = found {
-                parsed
-                    .table_columns
-                    .entry(found.to_owned())
-                    .or_default()
-                    .insert(check_col_name);
+                insert_col(
+                    parsed.table_columns.entry(found.to_owned()).or_default(),
+                    schema,
+                    found,
+                    &check_col_name,
+                );
             } else {
                 return Err(MatcherError::TableForColumnNotFound {
                     col_name: check_col_name,
@@ -1988,11 +2020,12 @@ fn extract_expr_columns(
             }
 
             if let Some(found) = found {
-                parsed
-                    .table_columns
-                    .entry(found.to_owned())
-                    .or_default()
-                    .insert(colname.0.clone());
+                insert_col(
+                    parsed.table_columns.entry(found.to_owned()).or_default(),
+                    schema,
+                    found,
+                    &colname.0,
+                );
             } else {
                 if colname.0.starts_with('"') {
                     return Ok(());
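
A note on the repeated coalesce(..., "") wrapping above: in SQLite, NULL = NULL does not evaluate to true, so any lookup that compares raw primary-key values misses rows whose key contains a NULL; coalescing to an empty string is presumably what makes such keys comparable here. The underlying SQL behavior, checked with rusqlite (an assumed dependency for this sketch, using '' instead of the diff's ""):

    use rusqlite::Connection;

    fn main() -> rusqlite::Result<()> {
        let conn = Connection::open_in_memory()?;

        // NULL = NULL yields NULL (read back as None), i.e. not a match.
        let raw: Option<bool> = conn.query_row("SELECT NULL = NULL", [], |row| row.get(0))?;
        assert_eq!(raw, None);

        // Coalesced to the empty string, the comparison succeeds.
        let coalesced: bool = conn.query_row(
            "SELECT coalesce(NULL, '') = coalesce(NULL, '')",
            [],
            |row| row.get(0),
        )?;
        assert!(coalesced);
        Ok(())
    }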
36 changes: 33 additions & 3 deletions crates/corro-types/src/schema.rs
@@ -23,11 +23,17 @@ pub struct Column {
     pub sql_type: (SqliteType, Option<String>),
     pub nullable: bool,
     pub default_value: Option<String>,
-    pub generated: Option<String>,
+    pub generated: Option<Generated>,
     pub primary_key: bool,
     pub raw: ColumnDefinition,
 }
 
+#[derive(Debug, Clone, Eq, PartialEq, Hash)]
+pub struct Generated {
+    pub raw: String,
+    pub from: Vec<String>,
+}
+
 impl Column {
     pub fn sql_type(&self) -> (SqliteType, Option<&str>) {
         (self.sql_type.0, self.sql_type.1.as_deref())
@@ -830,8 +836,16 @@ fn prepare_table(
                 nullable,
                 default_value,
                 generated: def.constraints.iter().find_map(|named| {
-                    if let ColumnConstraint::Generated { ref expr, .. } = named.constraint {
-                        Some(expr.to_string())
+                    if let ColumnConstraint::Generated { ref expr, ref typ } =
+                        named.constraint
+                    {
+                        trace!("generated typ: {typ:?}, expr: {expr:?}");
+                        let mut from = vec![];
+                        extract_expr_columns(expr, &mut from);
+                        Some(Generated {
+                            raw: expr.to_string(),
+                            from,
+                        })
                     } else {
                         None
                     }
@@ -849,3 +863,19 @@ fn prepare_table(
         },
     }
 }
+
+fn extract_expr_columns(expr: &Expr, cols: &mut Vec<String>) {
+    match expr {
+        Expr::FunctionCall {
+            args: Some(args), ..
+        } => {
+            for expr in args.iter() {
+                extract_expr_columns(expr, cols);
+            }
+        }
+        Expr::Id(colname) => {
+            cols.push(unquote(&colname.0).ok().unwrap_or(colname.0.clone()));
+        }
+        _ => {}
+    }
+}
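
The traversal in extract_expr_columns only descends into function-call arguments and collects bare identifiers; every other expression kind contributes nothing. A self-contained sketch of that shape over a toy AST (stand-in types, not the real sqlite3-parser ones):

    // Toy stand-ins for the parser's expression type.
    enum Expr {
        FunctionCall { args: Option<Vec<Expr>> },
        Id(String),
    }

    fn extract_expr_columns(expr: &Expr, cols: &mut Vec<String>) {
        match expr {
            // recurse into each argument of a function call
            Expr::FunctionCall { args: Some(args) } => {
                for expr in args.iter() {
                    extract_expr_columns(expr, cols);
                }
            }
            // a bare identifier is a column reference
            Expr::Id(colname) => cols.push(colname.clone()),
            // anything else contributes no columns
            _ => {}
        }
    }

    fn main() {
        // GENERATED ALWAYS AS (coalesce(first, last)) reads from `first` and `last`.
        let expr = Expr::FunctionCall {
            args: Some(vec![Expr::Id("first".into()), Expr::Id("last".into())]),
        };
        let mut cols = Vec::new();
        extract_expr_columns(&expr, &mut cols);
        assert_eq!(cols, vec!["first", "last"]);
    }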
