Skip to content

Commit

Permalink
Merge 'Fix scalar API in extensions, add documentation and error hand…
Browse files Browse the repository at this point in the history
…ling' from Preston Thorpe

Closes #728
Changes the API to one macro/annotation on the relevant function
```rust
#[scalar(name = "uuid4_str", alias = "gen_random_uuid")]
fn uuid4_str(_args: &[Value]) -> Value {
    let uuid = uuid::Uuid::new_v4().to_string();
    Value::from_text(uuid)
}

register_extension! {
    scalars: { uuid4_str, uuid4 }
}
```
The only downside of this, is that for functions that use their
arguments, because this is not a trait, there is not really a way of
enforcing the function signature like there is with the other way.
Documentation has been added for this in the `scalar` macro, so
hopefully will not be an issue.
Also this PR cleans up the Aggregate API by changing the `args` and
`name` functions to constant associated types, as well as adds some
error handling and documentation.
```rust
impl AggFunc for Median {
    type State = Vec<f64>;
    const NAME: &'static str = "median";
    const ARGS: i32 = 1;

    fn step(state: &mut Self::State, args: &[Value]) {
        if let Some(val) = args.first().and_then(Value::to_float) {
            state.push(val);
        }
    }
//.. etc
```

Closes #735
  • Loading branch information
penberg committed Jan 19, 2025
2 parents 3e28541 + bcd3ae2 commit 0561ff1
Show file tree
Hide file tree
Showing 11 changed files with 706 additions and 604 deletions.
16 changes: 7 additions & 9 deletions core/ext/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use crate::{function::ExternalFunc, Database};
use limbo_ext::{
ExtensionApi, InitAggFunction, ResultCode, ScalarFunction, RESULT_ERROR, RESULT_OK,
};
use limbo_ext::{ExtensionApi, InitAggFunction, ResultCode, ScalarFunction};
pub use limbo_ext::{FinalizeFunction, StepFunction, Value as ExtValue, ValueType as ExtValueType};
use std::{
ffi::{c_char, c_void, CStr},
Expand All @@ -17,10 +15,10 @@ unsafe extern "C" fn register_scalar_function(
let c_str = unsafe { CStr::from_ptr(name) };
let name_str = match c_str.to_str() {
Ok(s) => s.to_string(),
Err(_) => return RESULT_ERROR,
Err(_) => return ResultCode::InvalidArgs,
};
if ctx.is_null() {
return RESULT_ERROR;
return ResultCode::Error;
}
let db = unsafe { &*(ctx as *const Database) };
db.register_scalar_function_impl(&name_str, func)
Expand All @@ -37,10 +35,10 @@ unsafe extern "C" fn register_aggregate_function(
let c_str = unsafe { CStr::from_ptr(name) };
let name_str = match c_str.to_str() {
Ok(s) => s.to_string(),
Err(_) => return RESULT_ERROR,
Err(_) => return ResultCode::InvalidArgs,
};
if ctx.is_null() {
return RESULT_ERROR;
return ResultCode::Error;
}
let db = unsafe { &*(ctx as *const Database) };
db.register_aggregate_function_impl(&name_str, args, (init_func, step_func, finalize_func))
Expand All @@ -52,7 +50,7 @@ impl Database {
name.to_string(),
Rc::new(ExternalFunc::new_scalar(name.to_string(), func)),
);
RESULT_OK
ResultCode::OK
}

fn register_aggregate_function_impl(
Expand All @@ -65,7 +63,7 @@ impl Database {
name.to_string(),
Rc::new(ExternalFunc::new_aggregate(name.to_string(), args, func)),
);
RESULT_OK
ResultCode::OK
}

pub fn build_limbo_ext(&self) -> ExtensionApi {
Expand Down
4 changes: 2 additions & 2 deletions core/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use fallible_iterator::FallibleIterator;
#[cfg(not(target_family = "wasm"))]
use libloading::{Library, Symbol};
#[cfg(not(target_family = "wasm"))]
use limbo_ext::{ExtensionApi, ExtensionEntryPoint, RESULT_OK};
use limbo_ext::{ExtensionApi, ExtensionEntryPoint};
use log::trace;
use schema::Schema;
use sqlite3_parser::ast;
Expand Down Expand Up @@ -179,7 +179,7 @@ impl Database {
};
let api_ptr: *const ExtensionApi = Box::into_raw(api);
let result_code = unsafe { entry(api_ptr) };
if result_code == RESULT_OK {
if result_code.is_ok() {
self.syms.borrow_mut().extensions.push((lib, api_ptr));
Ok(())
} else {
Expand Down
2 changes: 1 addition & 1 deletion core/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ impl OwnedValue {
OwnedValue::Blob(std::rc::Rc::new(blob))
}
ExtValueType::Error => {
let Some(err) = v.to_text() else {
let Some(err) = v.to_error() else {
return OwnedValue::Null;
};
OwnedValue::Text(LimboText::new(Rc::new(err)))
Expand Down
59 changes: 25 additions & 34 deletions extensions/core/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ like traditional `sqlite3` extensions, but are able to be written in much more e

## Currently supported features

- [ x ] **Scalar Functions**: Create scalar functions using the `ScalarDerive` derive macro and `Scalar` trait.
- [ x ] **Scalar Functions**: Create scalar functions using the `scalar` macro.
- [ x ] **Aggregate Functions**: Define aggregate functions with `AggregateDerive` macro and `AggFunc` trait.
- [] **Virtual tables**: TODO
---
Expand Down Expand Up @@ -37,41 +37,35 @@ Extensions can be registered with the `register_extension!` macro:
```rust

register_extension!{
scalars: { Double },
scalars: { double }, // name of your function, if different from attribute name
aggregates: { Percentile },
}
```

### Scalar Example:
```rust
use limbo_ext::{register_extension, Value, ScalarDerive, Scalar};

/// Annotate each with the ScalarDerive macro, and implement the Scalar trait on your struct
#[derive(ScalarDerive)]
struct Double;

impl Scalar for Double {
fn name(&self) -> &'static str { "double" }
fn call(&self, args: &[Value]) -> Value {
if let Some(arg) = args.first() {
match arg.value_type() {
ValueType::Float => {
let val = arg.to_float().unwrap();
Value::from_float(val * 2.0)
}
ValueType::Integer => {
let val = arg.to_integer().unwrap();
Value::from_integer(val * 2)
}
use limbo_ext::{register_extension, Value, scalar};

/// Annotate each with the scalar macro, specifying the name you would like to call it with
/// and optionally, an alias.. e.g. SELECT double(4); or SELECT twice(4);
#[scalar(name = "double", alias = "twice")]
fn double(&self, args: &[Value]) -> Value {
if let Some(arg) = args.first() {
match arg.value_type() {
ValueType::Float => {
let val = arg.to_float().unwrap();
Value::from_float(val * 2.0)
}
ValueType::Integer => {
let val = arg.to_integer().unwrap();
Value::from_integer(val * 2)
}
} else {
Value::null()
}
} else {
Value::null()
}
/// OPTIONAL: 'alias' if you would like to provide an additional name
fn alias(&self) -> &'static str { "twice" }
}

```

### Aggregates Example:

Expand All @@ -88,14 +82,11 @@ impl AggFunc for Percentile {

/// Define the name you wish to call your function by.
/// e.g. SELECT percentile(value, 40);
fn name(&self) -> &'static str {
"percentile"
}
const NAME: &str = "percentile";

/// Define the number of expected arguments for your function.
const ARGS: i32 = 2;

/// Define the number of arguments your function takes
fn args(&self) -> i32 {
2
}
/// Define a function called on each row/value in a relevant group/column
fn step(state: &mut Self::State, args: &[Value]) {
let (values, p_value, error) = state;
Expand Down Expand Up @@ -127,7 +118,7 @@ impl AggFunc for Percentile {
let (mut values, p_value, error) = state;

if let Some(error) = error {
return Value::error(error);
return Value::custom_error(error);
}

if values.is_empty() {
Expand Down
Loading

0 comments on commit 0561ff1

Please sign in to comment.