Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: Allow custom window functions to be registered with the driver #1220

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 21 additions & 4 deletions callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,31 @@ func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value
//export stepTrampoline
func stepTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) {
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)]
ai := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo)
ai.Step(ctx, args)
if ai, ok := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo); ok {
ai.Step(ctx, args)
}
}

//export inverseTrampoline
func inverseTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) {
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)]
if ai, ok := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo); ok {
ai.Inverse(ctx, args)
}
}

//export valueTrampoline
func valueTrampoline(ctx *C.sqlite3_context) {
if ai, ok := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo); ok {
ai.Value(ctx)
}
}

//export doneTrampoline
func doneTrampoline(ctx *C.sqlite3_context) {
ai := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo)
ai.Done(ctx)
if ai, ok := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo); ok {
ai.Done(ctx)
}
}

//export compareTrampoline
Expand Down
239 changes: 184 additions & 55 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,25 @@ int _sqlite3_create_function(
return sqlite3_create_function(db, zFunctionName, nArg, eTextRep, (void*) pApp, xFunc, xStep, xFinal);
}

int _sqlite3_create_window_function(
sqlite3 *db,
const char *zFunctionName,
int nArg,
int eTextRep,
uintptr_t pApp,
void (*xStep)(sqlite3_context*,int,sqlite3_value**),
void (*xFinal)(sqlite3_context*),
void (*xValue)(sqlite3_context*),
void (*xInverse)(sqlite3_context*,int,sqlite3_value**)
) {
return sqlite3_create_window_function(db, zFunctionName, nArg, eTextRep, (void*) pApp, xStep, xFinal, xValue, xInverse, 0);
}


void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
void valueTrampoline(sqlite3_context*);
void inverseTrampoline(sqlite3_context*);
void doneTrampoline(sqlite3_context*);

int compareTrampoline(void*, int, char*, int, char*);
Expand Down Expand Up @@ -438,10 +455,18 @@ type aggInfo struct {
active map[int64]reflect.Value
next int64

nArgs int

stepArgConverters []callbackArgConverter
stepVariadicConverter callbackArgConverter

doneRetConverter callbackRetConverter

// Inverse and Value arg converters are used for window aggregations.
inverseArgConverters []callbackArgConverter
inverseVariadicConverter callbackArgConverter

valueRetConverter callbackRetConverter
}

func (ai *aggInfo) agg(ctx *C.sqlite3_context) (int64, reflect.Value, error) {
Expand All @@ -461,6 +486,8 @@ func (ai *aggInfo) agg(ctx *C.sqlite3_context) (int64, reflect.Value, error) {
return *aggIdx, ai.active[*aggIdx], nil
}

// Step Implements the xStep function for both aggregate and window functions
// https://www.sqlite.org/windowfunctions.html#udfwinfunc
func (ai *aggInfo) Step(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
_, agg, err := ai.agg(ctx)
if err != nil {
Expand All @@ -481,6 +508,8 @@ func (ai *aggInfo) Step(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
}
}

// Done Implements the xFinal function for both aggregate and window functions
// https://www.sqlite.org/windowfunctions.html#udfwinfunc
func (ai *aggInfo) Done(ctx *C.sqlite3_context) {
idx, agg, err := ai.agg(ctx)
if err != nil {
Expand All @@ -502,6 +531,49 @@ func (ai *aggInfo) Done(ctx *C.sqlite3_context) {
}
}

// Inverse Implements the xInverse function for window functions
// https://www.sqlite.org/windowfunctions.html#udfwinfunc
func (ai *aggInfo) Inverse(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
_, agg, err := ai.agg(ctx)
if err != nil {
callbackError(ctx, err)
return
}

args, err := callbackConvertArgs(argv, ai.inverseArgConverters, ai.inverseVariadicConverter)
if err != nil {
callbackError(ctx, err)
return
}

ret := agg.MethodByName("Inverse").Call(args)
if len(ret) == 1 && ret[0].Interface() != nil {
callbackError(ctx, ret[0].Interface().(error))
return
}
}

// Value Implements the xValue function for window functions
// https://www.sqlite.org/windowfunctions.html#udfwinfunc
func (ai *aggInfo) Value(ctx *C.sqlite3_context) {
_, agg, err := ai.agg(ctx)
if err != nil {
callbackError(ctx, err)
return
}
ret := agg.MethodByName("Value").Call(nil)
if len(ret) == 2 && ret[1].Interface() != nil {
callbackError(ctx, ret[1].Interface().(error))
return
}

err = ai.valueRetConverter(ctx, ret[0])
if err != nil {
callbackError(ctx, err)
return
}
}

// Commit transaction.
func (tx *SQLiteTx) Commit() error {
_, err := tx.c.exec(context.Background(), "COMMIT", nil)
Expand Down Expand Up @@ -684,20 +756,28 @@ func sqlite3CreateFunction(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTe
return C._sqlite3_create_function(db, zFunctionName, nArg, eTextRep, C.uintptr_t(uintptr(pApp)), (*[0]byte)(xFunc), (*[0]byte)(xStep), (*[0]byte)(xFinal))
}

// RegisterAggregator makes a Go type available as a SQLite aggregation function.
func sqlite3CreateWindowFunction(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTextRep C.int, pApp unsafe.Pointer, xStep unsafe.Pointer, xFinal unsafe.Pointer, xValue unsafe.Pointer, xInverse unsafe.Pointer) C.int {
return C._sqlite3_create_window_function(db, zFunctionName, nArg, eTextRep, C.uintptr_t(uintptr(pApp)), (*[0]byte)(xStep), (*[0]byte)(xFinal), (*[0]byte)(xValue), (*[0]byte)(xInverse))
}

// RegisterAggregator makes a Go type available as a SQLite aggregation function or window function.
//
// Because aggregation is incremental, it's implemented in Go with a
// type that has 2 methods: func Step(values) accumulates one row of
// data into the accumulator, and func Done() ret finalizes and
// returns the aggregate value. "values" and "ret" may be any type
// supported by RegisterFunc.
//
// To register a window function, the type must also contain implement
// a Value and Inverse function.
//
// RegisterAggregator takes as implementation a constructor function
// that constructs an instance of the aggregator type each time an
// aggregation begins. The constructor must return a pointer to a
// type, or an interface that implements Step() and Done().
// type, or an interface that implements Step() and Done(), and optionally
// Value() and Inverse() if the aggregator is a window function.
//
// The constructor function and the Step/Done methods may optionally
// The constructor function and the Step/Done/Value/Inverse methods may optionally
// return an error in addition to their other return values.
//
// See _example/go_custom_funcs for a detailed example.
Expand All @@ -719,93 +799,142 @@ func (c *SQLiteConn) RegisterAggregator(name string, impl any, pure bool) error
}

agg := t.Out(0)
var implReturnsPointer bool
switch agg.Kind() {
case reflect.Ptr, reflect.Interface:
case reflect.Ptr:
implReturnsPointer = true
case reflect.Interface:
implReturnsPointer = false
default:
return errors.New("SQlite aggregator constructor must return a pointer object")
return errors.New("SQLite aggregator constructor must return a pointer object")
}

stepFn, found := agg.MethodByName("Step")
if !found {
return errors.New("SQlite aggregator doesn't have a Step() function")
return errors.New("SQLite aggregator doesn't have a Step() function")
}
err := ai.setupStepInterface(stepFn, &ai.stepArgConverters, &ai.stepVariadicConverter, implReturnsPointer, "Step()")
if err != nil {
return err
}
step := stepFn.Type
if step.NumOut() != 0 && step.NumOut() != 1 {
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")

doneFn, found := agg.MethodByName("Done")
if !found {
return errors.New("SQLite aggregator doesn't have a Done() function")
}
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("type of SQlite aggregator Step() return value must be error")
err = ai.setupDoneInterface(doneFn, &ai.doneRetConverter, implReturnsPointer, "Done()")
if err != nil {
return err
}

stepNArgs := step.NumIn()
valueFn, valueFnFound := agg.MethodByName("Value")
inverseFn, inverseFnFound := agg.MethodByName("Inverse")
if (inverseFnFound && !valueFnFound) || (valueFnFound && !inverseFnFound) {
return errors.New("SQLite window aggregator must implement both Value() and Inverse() functions")
}
isWindowFunction := valueFnFound && inverseFnFound
// Validate window function interface
if isWindowFunction {
if inverseFn.Type.NumIn() != stepFn.Type.NumIn() {
return errors.New("SQLite window aggregator Inverse() function must accept the same number of arguments as Step()")
}
err := ai.setupStepInterface(inverseFn, &ai.inverseArgConverters, &ai.inverseVariadicConverter, implReturnsPointer, "Inverse()")
if err != nil {
return err
}
err = ai.setupDoneInterface(valueFn, &ai.valueRetConverter, implReturnsPointer, "Value()")
if err != nil {
return err
}
}

ai.active = make(map[int64]reflect.Value)
ai.next = 1

// ai must outlast the database connection, or we'll have dangling pointers.
c.aggregators = append(c.aggregators, &ai)

cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
opts := C.SQLITE_UTF8
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
var rv C.int
if isWindowFunction {
rv = sqlite3CreateWindowFunction(c.db, cname, C.int(ai.nArgs), C.int(opts), newHandle(c, &ai), C.stepTrampoline, C.doneTrampoline, C.valueTrampoline, C.inverseTrampoline)
} else {
rv = sqlite3CreateFunction(c.db, cname, C.int(ai.nArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
}
if rv != C.SQLITE_OK {
return c.lastError()
}
return nil
}

func (ai *aggInfo) setupStepInterface(fn reflect.Method, argConverters *[]callbackArgConverter, variadicConverter *callbackArgConverter, isImplPointer bool, name string) error {
t := fn.Type
if t.NumOut() != 0 && t.NumOut() != 1 {
return fmt.Errorf("SQLite aggregator %s function must return 0 or 1 values", name)
}
if t.NumOut() == 1 && !t.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return fmt.Errorf("type of SQLite aggregator %s return value must be error", name)
}
nArgs := t.NumIn()
start := 0
if agg.Kind() == reflect.Ptr {
if isImplPointer {
// Skip over the method receiver
stepNArgs--
nArgs--
start++
}
if step.IsVariadic() {
stepNArgs--
if t.IsVariadic() {
nArgs--
}
for i := start; i < start+stepNArgs; i++ {
conv, err := callbackArg(step.In(i))
for i := start; i < start+nArgs; i++ {
conv, err := callbackArg(t.In(i))
if err != nil {
return err
}
ai.stepArgConverters = append(ai.stepArgConverters, conv)

*argConverters = append(*argConverters, conv)
}
if step.IsVariadic() {
conv, err := callbackArg(step.In(start + stepNArgs).Elem())
if t.IsVariadic() {
conv, err := callbackArg(t.In(start + nArgs).Elem())
if err != nil {
return err
}
ai.stepVariadicConverter = conv
*variadicConverter = conv
// Pass -1 to sqlite so that it allows any number of
// arguments. The call helper verifies that the minimum number
// of arguments is present for variadic functions.
stepNArgs = -1
nArgs = -1
}
ai.nArgs = nArgs
return nil
}

doneFn, found := agg.MethodByName("Done")
if !found {
return errors.New("SQlite aggregator doesn't have a Done() function")
}
done := doneFn.Type
doneNArgs := done.NumIn()
if agg.Kind() == reflect.Ptr {
func (ai *aggInfo) setupDoneInterface(fn reflect.Method, retConverter *callbackRetConverter, implReturnsPointer bool, name string) error {
t := fn.Type
nArgs := t.NumIn()
if implReturnsPointer {
// Skip over the method receiver
doneNArgs--
nArgs--
}
if doneNArgs != 0 {
return errors.New("SQlite aggregator Done() function must have no arguments")
if nArgs != 0 {
return fmt.Errorf("SQlite aggregator %s function must have no arguments", name)
}
if done.NumOut() != 1 && done.NumOut() != 2 {
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
if t.NumOut() != 1 && t.NumOut() != 2 {
return fmt.Errorf("SQLite aggregator %s function must return 1 or 2 values", name)
}
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("second return value of SQLite aggregator Done() function must be error")
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return fmt.Errorf("second return value of SQLite aggregator %s function must be error", name)
}

conv, err := callbackRet(done.Out(0))
conv, err := callbackRet(t.Out(0))
if err != nil {
return err
}
ai.doneRetConverter = conv
ai.active = make(map[int64]reflect.Value)
ai.next = 1

// ai must outlast the database connection, or we'll have dangling pointers.
c.aggregators = append(c.aggregators, &ai)

cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
opts := C.SQLITE_UTF8
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
if rv != C.SQLITE_OK {
return c.lastError()
}
*retConverter = conv
return nil
}

Expand Down