Skip to content

Commit

Permalink
Merge pull request #1160 from gwenn/sub_type
Browse files Browse the repository at this point in the history
Make possible to specify subtype of SQL function
  • Loading branch information
gwenn committed Feb 4, 2024
2 parents 26eb784 + f48c578 commit 2812546
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 45 deletions.
11 changes: 10 additions & 1 deletion src/context.rs
@@ -1,5 +1,6 @@
//! Code related to `sqlite3_context` common to `functions` and `vtab` modules.

use libsqlite3_sys::sqlite3_value;
use std::os::raw::{c_int, c_void};
#[cfg(feature = "array")]
use std::rc::Rc;
Expand All @@ -16,7 +17,11 @@ use crate::vtab::array::{free_array, ARRAY_TYPE};
// is often known to the compiler, and thus const prop/DCE can substantially
// simplify the function.
#[inline]
pub(super) unsafe fn set_result(ctx: *mut sqlite3_context, result: &ToSqlOutput<'_>) {
pub(super) unsafe fn set_result(
ctx: *mut sqlite3_context,
args: &[*mut sqlite3_value],
result: &ToSqlOutput<'_>,
) {
let value = match *result {
ToSqlOutput::Borrowed(v) => v,
ToSqlOutput::Owned(ref v) => ValueRef::from(v),
Expand All @@ -26,6 +31,10 @@ pub(super) unsafe fn set_result(ctx: *mut sqlite3_context, result: &ToSqlOutput<
// TODO sqlite3_result_zeroblob64 // 3.8.11
return ffi::sqlite3_result_zeroblob(ctx, len);
}
#[cfg(feature = "functions")]
ToSqlOutput::Arg(i) => {
return ffi::sqlite3_result_value(ctx, args[i]);
}
#[cfg(feature = "array")]
ToSqlOutput::Array(ref a) => {
return ffi::sqlite3_result_pointer(
Expand Down
158 changes: 115 additions & 43 deletions src/functions.rs
Expand Up @@ -66,7 +66,7 @@ use crate::ffi::sqlite3_context;
use crate::ffi::sqlite3_value;

use crate::context::set_result;
use crate::types::{FromSql, FromSqlError, ToSql, ValueRef};
use crate::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef};

use crate::{str_to_cstring, Connection, Error, InnerConnection, Result};

Expand Down Expand Up @@ -149,6 +149,15 @@ impl Context<'_> {
unsafe { ValueRef::from_value(arg) }
}

/// Returns the `idx`th argument as a `SqlFnArg`.
/// To be used when the SQL function result is one of its arguments.
#[inline]
#[must_use]
pub fn get_arg(&self, idx: usize) -> SqlFnArg {
assert!(idx < self.len());
SqlFnArg { idx }
}

/// Returns the subtype of `idx`th argument.
///
/// # Failure
Expand Down Expand Up @@ -232,11 +241,6 @@ impl Context<'_> {
phantom: PhantomData,
})
}

/// Set the Subtype of an SQL function
pub fn set_result_subtype(&self, sub_type: std::os::raw::c_uint) {
unsafe { ffi::sqlite3_result_subtype(self.ctx, sub_type) };
}
}

/// A reference to a connection handle with a lifetime bound to something.
Expand All @@ -258,6 +262,57 @@ impl Deref for ConnectionRef<'_> {

type AuxInner = Arc<dyn Any + Send + Sync + 'static>;

/// Subtype of an SQL function
pub type SubType = Option<std::os::raw::c_uint>;

/// Result of an SQL function
pub trait SqlFnOutput {
/// Converts Rust value to SQLite value with an optional sub-type
fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)>;
}

impl<T: ToSql> SqlFnOutput for T {
#[inline]
fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> {
ToSql::to_sql(self).map(|o| (o, None))
}
}

impl<T: ToSql> SqlFnOutput for (T, SubType) {
fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> {
ToSql::to_sql(&self.0).map(|o| (o, self.1))
}
}

/// n-th arg of an SQL scalar function
pub struct SqlFnArg {
idx: usize,
}
impl ToSql for SqlFnArg {
fn to_sql(&self) -> Result<ToSqlOutput<'_>> {
Ok(ToSqlOutput::Arg(self.idx))
}
}

unsafe fn sql_result<T: SqlFnOutput>(
ctx: *mut sqlite3_context,
args: &[*mut sqlite3_value],
r: Result<T>,
) {
let t = r.as_ref().map(SqlFnOutput::to_sql);

match t {
Ok(Ok((ref value, sub_type))) => {
set_result(ctx, args, value);
if let Some(sub_type) = sub_type {
ffi::sqlite3_result_subtype(ctx, sub_type);
}
}
Ok(Err(err)) => report_error(ctx, &err),
Err(err) => report_error(ctx, err),
};
}

/// Aggregate is the callback interface for user-defined
/// aggregate function.
///
Expand All @@ -266,7 +321,7 @@ type AuxInner = Arc<dyn Any + Send + Sync + 'static>;
pub trait Aggregate<A, T>
where
A: RefUnwindSafe + UnwindSafe,
T: ToSql,
T: SqlFnOutput,
{
/// Initializes the aggregation context. Will be called prior to the first
/// call to [`step()`](Aggregate::step) to set up the context for an
Expand Down Expand Up @@ -297,7 +352,7 @@ where
pub trait WindowAggregate<A, T>: Aggregate<A, T>
where
A: RefUnwindSafe + UnwindSafe,
T: ToSql,
T: SqlFnOutput,
{
/// Returns the current value of the aggregate. Unlike xFinal, the
/// implementation should not delete any context.
Expand Down Expand Up @@ -330,6 +385,8 @@ bitflags::bitflags! {
const SQLITE_SUBTYPE = 0x0000_0010_0000; // 3.30.0
/// Means that the function is unlikely to cause problems even if misused.
const SQLITE_INNOCUOUS = 0x0000_0020_0000; // 3.31.0
/// Indicates to SQLite that a function might call `sqlite3_result_subtype()` to cause a sub-type to be associated with its result.
const SQLITE_RESULT_SUBTYPE = 0x0000_0100_0000; // 3.45.0
}
}

Expand Down Expand Up @@ -388,7 +445,7 @@ impl Connection {
) -> Result<()>
where
F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
T: ToSql,
T: SqlFnOutput,
{
self.db
.borrow_mut()
Expand All @@ -412,7 +469,7 @@ impl Connection {
where
A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T> + 'static,
T: ToSql,
T: SqlFnOutput,
{
self.db
.borrow_mut()
Expand All @@ -437,7 +494,7 @@ impl Connection {
where
A: RefUnwindSafe + UnwindSafe,
W: WindowAggregate<A, T> + 'static,
T: ToSql,
T: SqlFnOutput,
{
self.db
.borrow_mut()
Expand Down Expand Up @@ -470,23 +527,21 @@ impl InnerConnection {
) -> Result<()>
where
F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
T: ToSql,
T: SqlFnOutput,
{
unsafe extern "C" fn call_boxed_closure<F, T>(
ctx: *mut sqlite3_context,
argc: c_int,
argv: *mut *mut sqlite3_value,
) where
F: FnMut(&Context<'_>) -> Result<T>,
T: ToSql,
T: SqlFnOutput,
{
let args = slice::from_raw_parts(argv, argc as usize);
let r = catch_unwind(|| {
let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::<F>();
assert!(!boxed_f.is_null(), "Internal error - null function pointer");
let ctx = Context {
ctx,
args: slice::from_raw_parts(argv, argc as usize),
};
let ctx = Context { ctx, args };
(*boxed_f)(&ctx)
});
let t = match r {
Expand All @@ -496,13 +551,7 @@ impl InnerConnection {
}
Ok(r) => r,
};
let t = t.as_ref().map(|t| ToSql::to_sql(t));

match t {
Ok(Ok(ref value)) => set_result(ctx, value),
Ok(Err(err)) => report_error(ctx, &err),
Err(err) => report_error(ctx, err),
}
sql_result(ctx, args, t);
}

let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
Expand Down Expand Up @@ -533,7 +582,7 @@ impl InnerConnection {
where
A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T> + 'static,
T: ToSql,
T: SqlFnOutput,
{
let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
let c_name = str_to_cstring(fn_name)?;
Expand Down Expand Up @@ -564,7 +613,7 @@ impl InnerConnection {
where
A: RefUnwindSafe + UnwindSafe,
W: WindowAggregate<A, T> + 'static,
T: ToSql,
T: SqlFnOutput,
{
let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
let c_name = str_to_cstring(fn_name)?;
Expand Down Expand Up @@ -619,7 +668,7 @@ unsafe extern "C" fn call_boxed_step<A, D, T>(
) where
A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T>,
T: ToSql,
T: SqlFnOutput,
{
let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) {
pac
Expand Down Expand Up @@ -667,7 +716,7 @@ unsafe extern "C" fn call_boxed_inverse<A, W, T>(
) where
A: RefUnwindSafe + UnwindSafe,
W: WindowAggregate<A, T>,
T: ToSql,
T: SqlFnOutput,
{
let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) {
pac
Expand Down Expand Up @@ -705,7 +754,7 @@ unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
where
A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T>,
T: ToSql,
T: SqlFnOutput,
{
// Within the xFinal callback, it is customary to set N=0 in calls to
// sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
Expand Down Expand Up @@ -739,20 +788,15 @@ where
}
Ok(r) => r,
};
let t = t.as_ref().map(|t| ToSql::to_sql(t));
match t {
Ok(Ok(ref value)) => set_result(ctx, value),
Ok(Err(err)) => report_error(ctx, &err),
Err(err) => report_error(ctx, err),
}
sql_result(ctx, &[], t);
}

#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
where
A: RefUnwindSafe + UnwindSafe,
W: WindowAggregate<A, T>,
T: ToSql,
T: SqlFnOutput,
{
// Within the xValue callback, it is customary to set N=0 in calls to
// sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
Expand All @@ -776,12 +820,7 @@ where
}
Ok(r) => r,
};
let t = t.as_ref().map(|t| ToSql::to_sql(t));
match t {
Ok(Ok(ref value)) => set_result(ctx, value),
Ok(Err(err)) => report_error(ctx, &err),
Err(err) => report_error(ctx, err),
}
sql_result(ctx, &[], t);
}

#[cfg(test)]
Expand All @@ -791,7 +830,7 @@ mod test {

#[cfg(feature = "window")]
use crate::functions::WindowAggregate;
use crate::functions::{Aggregate, Context, FunctionFlags};
use crate::functions::{Aggregate, Context, FunctionFlags, SqlFnArg, SubType};
use crate::{Connection, Error, Result};

fn half(ctx: &Context<'_>) -> Result<c_double> {
Expand Down Expand Up @@ -1069,4 +1108,37 @@ mod test {
assert_eq!(expected, results);
Ok(())
}

#[test]
fn test_sub_type() -> Result<()> {
fn test_getsubtype(ctx: &Context<'_>) -> Result<i32> {
Ok(ctx.get_subtype(0) as i32)
}
fn test_setsubtype(ctx: &Context<'_>) -> Result<(SqlFnArg, SubType)> {
use std::os::raw::c_uint;
let value = ctx.get_arg(0);
let sub_type = ctx.get::<c_uint>(1)?;
Ok((value, Some(sub_type)))
}
let db = Connection::open_in_memory()?;
db.create_scalar_function(
"test_getsubtype",
1,
FunctionFlags::SQLITE_UTF8,
test_getsubtype,
)?;
db.create_scalar_function(
"test_setsubtype",
2,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_RESULT_SUBTYPE,
test_setsubtype,
)?;
let result: i32 = db.one_column("SELECT test_getsubtype('hello');")?;
assert_eq!(0, result);

let result: i32 = db.one_column("SELECT test_getsubtype(test_setsubtype('hello',123));")?;
assert_eq!(123, result);

Ok(())
}
}
7 changes: 7 additions & 0 deletions src/pragma.rs
Expand Up @@ -70,6 +70,13 @@ impl Sql {
Some(format!("Unsupported value \"{value:?}\"")),
));
}
#[cfg(feature = "functions")]
ToSqlOutput::Arg(_) => {
return Err(Error::SqliteFailure(
ffi::Error::new(ffi::SQLITE_MISUSE),
Some(format!("Unsupported value \"{value:?}\"")),
));
}
#[cfg(feature = "array")]
ToSqlOutput::Array(_) => {
return Err(Error::SqliteFailure(
Expand Down
7 changes: 7 additions & 0 deletions src/statement.rs
Expand Up @@ -606,6 +606,13 @@ impl Statement<'_> {
.conn
.decode_result(unsafe { ffi::sqlite3_bind_zeroblob(ptr, col as c_int, len) });
}
#[cfg(feature = "functions")]
ToSqlOutput::Arg(_) => {
return Err(Error::SqliteFailure(
ffi::Error::new(ffi::SQLITE_MISUSE),
Some(format!("Unsupported value \"{value:?}\"")),
));
}
#[cfg(feature = "array")]
ToSqlOutput::Array(a) => {
return self.conn.decode_result(unsafe {
Expand Down

0 comments on commit 2812546

Please sign in to comment.