Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make possible to specify subtype of SQL function #1160

Merged
merged 10 commits into from Feb 4, 2024
11 changes: 10 additions & 1 deletion src/context.rs
@@ -1,5 +1,6 @@
//! Code related to `sqlite3_context` common to `functions` and `vtab` modules.

use libsqlite3_sys::sqlite3_value;
use std::os::raw::{c_int, c_void};
#[cfg(feature = "array")]
use std::rc::Rc;
Expand All @@ -16,7 +17,11 @@ use crate::vtab::array::{free_array, ARRAY_TYPE};
// is often known to the compiler, and thus const prop/DCE can substantially
// simplify the function.
#[inline]
pub(super) unsafe fn set_result(ctx: *mut sqlite3_context, result: &ToSqlOutput<'_>) {
pub(super) unsafe fn set_result(
ctx: *mut sqlite3_context,
args: &[*mut sqlite3_value],
result: &ToSqlOutput<'_>,
) {
let value = match *result {
ToSqlOutput::Borrowed(v) => v,
ToSqlOutput::Owned(ref v) => ValueRef::from(v),
Expand All @@ -26,6 +31,10 @@ pub(super) unsafe fn set_result(ctx: *mut sqlite3_context, result: &ToSqlOutput<
// TODO sqlite3_result_zeroblob64 // 3.8.11
return ffi::sqlite3_result_zeroblob(ctx, len);
}
#[cfg(feature = "functions")]
ToSqlOutput::Arg(i) => {
return ffi::sqlite3_result_value(ctx, args[i]);
}
#[cfg(feature = "array")]
ToSqlOutput::Array(ref a) => {
return ffi::sqlite3_result_pointer(
Expand Down
158 changes: 115 additions & 43 deletions src/functions.rs
Expand Up @@ -66,7 +66,7 @@
use crate::ffi::sqlite3_value;

use crate::context::set_result;
use crate::types::{FromSql, FromSqlError, ToSql, ValueRef};
use crate::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef};

use crate::{str_to_cstring, Connection, Error, InnerConnection, Result};

Expand Down Expand Up @@ -149,6 +149,15 @@
unsafe { ValueRef::from_value(arg) }
}

/// Returns the `idx`th argument as a `SqlFnArg`.
/// To be used when the SQL function result is one of its arguments.
#[inline]
#[must_use]
pub fn get_arg(&self, idx: usize) -> SqlFnArg {
assert!(idx < self.len());
SqlFnArg { idx }
}

/// Returns the subtype of `idx`th argument.
///
/// # Failure
Expand Down Expand Up @@ -232,11 +241,6 @@
phantom: PhantomData,
})
}

/// Set the Subtype of an SQL function
pub fn set_result_subtype(&self, sub_type: std::os::raw::c_uint) {
unsafe { ffi::sqlite3_result_subtype(self.ctx, sub_type) };
}
}

/// A reference to a connection handle with a lifetime bound to something.
Expand All @@ -258,6 +262,57 @@

type AuxInner = Arc<dyn Any + Send + Sync + 'static>;

/// Subtype of an SQL function
pub type SubType = Option<std::os::raw::c_uint>;

/// Result of an SQL function
pub trait SqlFnOutput {
/// Converts Rust value to SQLite value with an optional sub-type
fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)>;
}

impl<T: ToSql> SqlFnOutput for T {
#[inline]
fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> {
ToSql::to_sql(self).map(|o| (o, None))
}
}

impl<T: ToSql> SqlFnOutput for (T, SubType) {
fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> {
ToSql::to_sql(&self.0).map(|o| (o, self.1))
}
}

/// n-th arg of an SQL scalar function
pub struct SqlFnArg {
idx: usize,
}
impl ToSql for SqlFnArg {
fn to_sql(&self) -> Result<ToSqlOutput<'_>> {
Ok(ToSqlOutput::Arg(self.idx))
}
}

unsafe fn sql_result<T: SqlFnOutput>(
ctx: *mut sqlite3_context,
args: &[*mut sqlite3_value],
r: Result<T>,
) {
let t = r.as_ref().map(SqlFnOutput::to_sql);

match t {
Ok(Ok((ref value, sub_type))) => {
set_result(ctx, args, value);
if let Some(sub_type) = sub_type {
ffi::sqlite3_result_subtype(ctx, sub_type);
}
}
Ok(Err(err)) => report_error(ctx, &err),
Err(err) => report_error(ctx, err),

Check warning on line 312 in src/functions.rs

View check run for this annotation

Codecov / codecov/patch

src/functions.rs#L311-L312

Added lines #L311 - L312 were not covered by tests
};
}

/// Aggregate is the callback interface for user-defined
/// aggregate function.
///
Expand All @@ -266,7 +321,7 @@
pub trait Aggregate<A, T>
where
A: RefUnwindSafe + UnwindSafe,
T: ToSql,
T: SqlFnOutput,
{
/// Initializes the aggregation context. Will be called prior to the first
/// call to [`step()`](Aggregate::step) to set up the context for an
Expand Down Expand Up @@ -297,7 +352,7 @@
pub trait WindowAggregate<A, T>: Aggregate<A, T>
where
A: RefUnwindSafe + UnwindSafe,
T: ToSql,
T: SqlFnOutput,
{
/// Returns the current value of the aggregate. Unlike xFinal, the
/// implementation should not delete any context.
Expand Down Expand Up @@ -330,6 +385,8 @@
const SQLITE_SUBTYPE = 0x0000_0010_0000; // 3.30.0
/// Means that the function is unlikely to cause problems even if misused.
const SQLITE_INNOCUOUS = 0x0000_0020_0000; // 3.31.0
/// Indicates to SQLite that a function might call `sqlite3_result_subtype()` to cause a sub-type to be associated with its result.
const SQLITE_RESULT_SUBTYPE = 0x0000_0100_0000; // 3.45.0

Check warning on line 389 in src/functions.rs

View check run for this annotation

Codecov / codecov/patch

src/functions.rs#L388-L389

Added lines #L388 - L389 were not covered by tests
}
}

Expand Down Expand Up @@ -388,7 +445,7 @@
) -> Result<()>
where
F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
T: ToSql,
T: SqlFnOutput,
{
self.db
.borrow_mut()
Expand All @@ -412,7 +469,7 @@
where
A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T> + 'static,
T: ToSql,
T: SqlFnOutput,
{
self.db
.borrow_mut()
Expand All @@ -437,7 +494,7 @@
where
A: RefUnwindSafe + UnwindSafe,
W: WindowAggregate<A, T> + 'static,
T: ToSql,
T: SqlFnOutput,
{
self.db
.borrow_mut()
Expand Down Expand Up @@ -470,23 +527,21 @@
) -> Result<()>
where
F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
T: ToSql,
T: SqlFnOutput,
{
unsafe extern "C" fn call_boxed_closure<F, T>(
ctx: *mut sqlite3_context,
argc: c_int,
argv: *mut *mut sqlite3_value,
) where
F: FnMut(&Context<'_>) -> Result<T>,
T: ToSql,
T: SqlFnOutput,
{
let args = slice::from_raw_parts(argv, argc as usize);
let r = catch_unwind(|| {
let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::<F>();
assert!(!boxed_f.is_null(), "Internal error - null function pointer");
let ctx = Context {
ctx,
args: slice::from_raw_parts(argv, argc as usize),
};
let ctx = Context { ctx, args };
(*boxed_f)(&ctx)
});
let t = match r {
Expand All @@ -496,13 +551,7 @@
}
Ok(r) => r,
};
let t = t.as_ref().map(|t| ToSql::to_sql(t));

match t {
Ok(Ok(ref value)) => set_result(ctx, value),
Ok(Err(err)) => report_error(ctx, &err),
Err(err) => report_error(ctx, err),
}
sql_result(ctx, args, t);
}

let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
Expand Down Expand Up @@ -533,7 +582,7 @@
where
A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T> + 'static,
T: ToSql,
T: SqlFnOutput,
{
let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
let c_name = str_to_cstring(fn_name)?;
Expand Down Expand Up @@ -564,7 +613,7 @@
where
A: RefUnwindSafe + UnwindSafe,
W: WindowAggregate<A, T> + 'static,
T: ToSql,
T: SqlFnOutput,
{
let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
let c_name = str_to_cstring(fn_name)?;
Expand Down Expand Up @@ -619,7 +668,7 @@
) where
A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T>,
T: ToSql,
T: SqlFnOutput,
{
let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) {
pac
Expand Down Expand Up @@ -667,7 +716,7 @@
) where
A: RefUnwindSafe + UnwindSafe,
W: WindowAggregate<A, T>,
T: ToSql,
T: SqlFnOutput,
{
let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) {
pac
Expand Down Expand Up @@ -705,7 +754,7 @@
where
A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T>,
T: ToSql,
T: SqlFnOutput,
{
// Within the xFinal callback, it is customary to set N=0 in calls to
// sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
Expand Down Expand Up @@ -739,20 +788,15 @@
}
Ok(r) => r,
};
let t = t.as_ref().map(|t| ToSql::to_sql(t));
match t {
Ok(Ok(ref value)) => set_result(ctx, value),
Ok(Err(err)) => report_error(ctx, &err),
Err(err) => report_error(ctx, err),
}
sql_result(ctx, &[], t);
}

#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
where
A: RefUnwindSafe + UnwindSafe,
W: WindowAggregate<A, T>,
T: ToSql,
T: SqlFnOutput,
{
// Within the xValue callback, it is customary to set N=0 in calls to
// sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
Expand All @@ -776,12 +820,7 @@
}
Ok(r) => r,
};
let t = t.as_ref().map(|t| ToSql::to_sql(t));
match t {
Ok(Ok(ref value)) => set_result(ctx, value),
Ok(Err(err)) => report_error(ctx, &err),
Err(err) => report_error(ctx, err),
}
sql_result(ctx, &[], t);
}

#[cfg(test)]
Expand All @@ -791,7 +830,7 @@

#[cfg(feature = "window")]
use crate::functions::WindowAggregate;
use crate::functions::{Aggregate, Context, FunctionFlags};
use crate::functions::{Aggregate, Context, FunctionFlags, SqlFnArg, SubType};
use crate::{Connection, Error, Result};

fn half(ctx: &Context<'_>) -> Result<c_double> {
Expand Down Expand Up @@ -1069,4 +1108,37 @@
assert_eq!(expected, results);
Ok(())
}

#[test]
fn test_sub_type() -> Result<()> {
fn test_getsubtype(ctx: &Context<'_>) -> Result<i32> {
Ok(ctx.get_subtype(0) as i32)
}
fn test_setsubtype(ctx: &Context<'_>) -> Result<(SqlFnArg, SubType)> {
use std::os::raw::c_uint;
let value = ctx.get_arg(0);
let sub_type = ctx.get::<c_uint>(1)?;
Ok((value, Some(sub_type)))
}
let db = Connection::open_in_memory()?;
db.create_scalar_function(
"test_getsubtype",
1,
FunctionFlags::SQLITE_UTF8,
test_getsubtype,
)?;
db.create_scalar_function(
"test_setsubtype",
2,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_RESULT_SUBTYPE,
test_setsubtype,
)?;
let result: i32 = db.one_column("SELECT test_getsubtype('hello');")?;
assert_eq!(0, result);

let result: i32 = db.one_column("SELECT test_getsubtype(test_setsubtype('hello',123));")?;
assert_eq!(123, result);

Ok(())
}
}
7 changes: 7 additions & 0 deletions src/pragma.rs
Expand Up @@ -70,6 +70,13 @@
Some(format!("Unsupported value \"{value:?}\"")),
));
}
#[cfg(feature = "functions")]
ToSqlOutput::Arg(_) => {
return Err(Error::SqliteFailure(
ffi::Error::new(ffi::SQLITE_MISUSE),
Some(format!("Unsupported value \"{value:?}\"")),
));

Check warning on line 78 in src/pragma.rs

View check run for this annotation

Codecov / codecov/patch

src/pragma.rs#L75-L78

Added lines #L75 - L78 were not covered by tests
}
#[cfg(feature = "array")]
ToSqlOutput::Array(_) => {
return Err(Error::SqliteFailure(
Expand Down
7 changes: 7 additions & 0 deletions src/statement.rs
Expand Up @@ -606,6 +606,13 @@
.conn
.decode_result(unsafe { ffi::sqlite3_bind_zeroblob(ptr, col as c_int, len) });
}
#[cfg(feature = "functions")]
ToSqlOutput::Arg(_) => {
return Err(Error::SqliteFailure(
ffi::Error::new(ffi::SQLITE_MISUSE),
Some(format!("Unsupported value \"{value:?}\"")),
));

Check warning on line 614 in src/statement.rs

View check run for this annotation

Codecov / codecov/patch

src/statement.rs#L611-L614

Added lines #L611 - L614 were not covered by tests
}
#[cfg(feature = "array")]
ToSqlOutput::Array(a) => {
return self.conn.decode_result(unsafe {
Expand Down