Skip to content

Commit

Permalink
Add PgTsVector type and implement deserialization
Browse files Browse the repository at this point in the history
This commit introduces the `PgTsVector` and `PgTsVectorEntry` types. It
also adds an implementation of `FromSql<TsVector, Pg>` for `PgTsVector`
to support deserialization.

Two new tests were also added to check deserializing tsvectors with and
without lexeme positions.
  • Loading branch information
encalypto committed Nov 4, 2024
1 parent 1d30bee commit 6110ae6
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 1 deletion.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ repository = "https://github.com/diesel-rs/diesel_full_text_search"
edition = "2021"

[dependencies]
byteorder = "1.5.0"
diesel = { version = "~2.2.0", features = ["postgres_backend"], default-features = false }

[features]
Expand Down
176 changes: 175 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
mod types {
use diesel::sql_types::*;
use std::io::{BufRead, Cursor};

use byteorder::{NetworkEndian, ReadBytesExt};
use diesel::{deserialize::FromSql, pg::Pg, sql_types::*, Queryable};

#[derive(Clone, Copy, SqlType)]
#[diesel(postgres_type(oid = 3615, array_oid = 3645))]
Expand All @@ -18,6 +21,70 @@ mod types {
/// SQL type marker that is mapped to the PostgreSQL type named `regconfig`
/// (looked up by name rather than by a fixed OID).
#[derive(SqlType)]
#[diesel(postgres_type(name = "regconfig"))]
pub struct RegConfig;

impl FromSql<TsVector, Pg> for PgTsVector {
    /// Decodes the binary wire representation of a `tsvector` into a
    /// [`PgTsVector`].
    ///
    /// From Postgres `tsvector.c`, the binary format is as follows:
    ///
    /// ```text
    /// uint32  number of lexemes
    ///
    /// for each lexeme:
    ///     lexeme text in client encoding, null-terminated
    ///     uint16  number of positions
    ///     for each position:
    ///         uint16 WordEntryPos
    /// ```
    ///
    /// All integers are read in network (big-endian) byte order. Position
    /// values are stored exactly as received, without any further decoding.
    ///
    /// # Errors
    ///
    /// Returns an error if the payload ends before all announced lexemes and
    /// positions have been read, if a lexeme is not null-terminated, or if a
    /// lexeme is not valid UTF-8.
    fn from_sql(
        bytes: <Pg as diesel::backend::Backend>::RawValue<'_>,
    ) -> diesel::deserialize::Result<Self> {
        // Smallest possible encoded entry: an empty lexeme (just the null
        // terminator, 1 byte) followed by the uint16 position count (2 bytes).
        const MIN_ENTRY_LEN: usize = 3;

        let raw = bytes.as_bytes();
        let mut cursor = Cursor::new(raw);

        // Number of lexemes (uint32)
        let num_lexemes = cursor.read_u32::<NetworkEndian>()?;

        // The count comes straight off the wire, so never reserve more slots
        // than the payload could possibly contain. For well-formed input this
        // bound is never smaller than `num_lexemes`.
        let capacity = (num_lexemes as usize).min(raw.len() / MIN_ENTRY_LEN);
        let mut entries = Vec::with_capacity(capacity);

        for _ in 0..num_lexemes {
            let mut lexeme = Vec::new();
            cursor.read_until(0, &mut lexeme)?;

            // `read_until` keeps the delimiter when it finds one; if the last
            // byte is not the null terminator the payload was truncated.
            if lexeme.pop() != Some(0) {
                return Err(
                    "invalid tsvector: lexeme is missing its null terminator".into(),
                );
            }

            // NOTE(review): the lexeme is sent in the client encoding; this
            // assumes that encoding is UTF-8 and errors out otherwise.
            let lexeme = String::from_utf8(lexeme)?;

            // Number of positions (uint16), followed by that many uint16 values.
            let num_positions = cursor.read_u16::<NetworkEndian>()?;
            let positions = (0..num_positions)
                .map(|_| cursor.read_u16::<NetworkEndian>())
                .collect::<std::io::Result<Vec<u16>>>()?;

            entries.push(PgTsVectorEntry { lexeme, positions });
        }

        Ok(PgTsVector { entries })
    }
}

impl Queryable<TsVector, Pg> for PgTsVector {
    /// The row type is the vector itself, so no conversion step is needed.
    type Row = PgTsVector;

    /// Hands the already-deserialized value back unchanged; never fails.
    fn build(deserialized: PgTsVector) -> diesel::deserialize::Result<PgTsVector> {
        diesel::deserialize::Result::Ok(deserialized)
    }
}

/// Rust-side representation of a PostgreSQL `tsvector` value: a list of
/// lexemes, each with its (possibly empty) list of positions.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct PgTsVector {
    /// The lexeme entries that make up the vector.
    pub entries: Vec<PgTsVectorEntry>,
}

/// A single lexeme of a [`PgTsVector`] together with its positions.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct PgTsVectorEntry {
    /// The lexeme text.
    pub lexeme: String,
    /// Position values for this lexeme; empty when the vector carries no
    /// position information for it.
    ///
    /// NOTE(review): PostgreSQL's `ts_type.h` describes each `WordEntryPos`
    /// as a packed `weight:2, pos:14` value. Confirm whether callers expect
    /// the raw value or only the lower 14 position bits.
    pub positions: Vec<u16>,
}
}

pub mod configuration {
Expand Down Expand Up @@ -219,3 +286,110 @@ mod dsl {
pub use self::dsl::*;
pub use self::functions::*;
pub use self::types::*;

#[cfg(test)]
mod tests {
use super::*;

use diesel::dsl::sql;
use diesel::pg::PgConnection;
use diesel::prelude::*;

#[test]
fn test_tsvector_from_sql_with_positions() {
    // Needs a live PostgreSQL instance reachable through `DATABASE_URL`.
    let url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
    let mut connection =
        PgConnection::establish(&url).expect("Error connecting to database");

    let result: PgTsVector = diesel::select(sql::<TsVector>(
        "to_tsvector('a fat cat sat on a mat and ate a fat rat')",
    ))
    .get_result(&mut connection)
    .expect("Error executing query");

    // Expected (lexeme, positions) pairs, in the order the entries are
    // compared against the deserialized result.
    let lexemes: [(&str, Vec<u16>); 6] = [
        ("ate", vec![9]),
        ("cat", vec![3]),
        ("fat", vec![2, 11]),
        ("mat", vec![7]),
        ("rat", vec![12]),
        ("sat", vec![4]),
    ];
    let expected = PgTsVector {
        entries: lexemes
            .into_iter()
            .map(|(lexeme, positions)| PgTsVectorEntry {
                lexeme: lexeme.to_owned(),
                positions,
            })
            .collect(),
    };

    assert_eq!(expected, result);
}

#[test]
fn test_tsvector_from_sql_without_positions() {
    // Needs a live PostgreSQL instance reachable through `DATABASE_URL`.
    let url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
    let mut connection =
        PgConnection::establish(&url).expect("Error connecting to database");

    let result: PgTsVector = diesel::select(sql::<TsVector>(
        "'a fat cat sat on a mat and ate a fat rat'::tsvector",
    ))
    .get_result(&mut connection)
    .expect("Error executing query");

    // Every lexeme is expected exactly once and without any positions.
    let lexemes = ["a", "and", "ate", "cat", "fat", "mat", "on", "rat", "sat"];
    let expected = PgTsVector {
        entries: lexemes
            .into_iter()
            .map(|lexeme| PgTsVectorEntry {
                lexeme: lexeme.to_owned(),
                positions: Vec::new(),
            })
            .collect(),
    };

    assert_eq!(expected, result);
}
}

0 comments on commit 6110ae6

Please sign in to comment.