diff --git a/func.go b/func.go index e080599..75421d9 100644 --- a/func.go +++ b/func.go @@ -72,6 +72,9 @@ func (ctx Context) ResultFloat(v float64) { C.sqlite3_result_double(ctx.ptr, C. func (ctx Context) ResultNull() { C.sqlite3_result_null(ctx.ptr) } func (ctx Context) ResultValue(v Value) { C.sqlite3_result_value(ctx.ptr, v.ptr) } func (ctx Context) ResultZeroBlob(n int64) { C.sqlite3_result_zeroblob64(ctx.ptr, C.sqlite3_uint64(n)) } +func (ctx Context) ResultBlob(v []byte) { + C.sqlite3_result_blob(ctx.ptr, C.CBytes(v), C.int(len(v)), (*[0]byte)(C.cfree)) +} func (ctx Context) ResultText(v string) { var cv *C.char if len(v) != 0 { diff --git a/func_test.go b/func_test.go index 3a7fb20..ddbf711 100644 --- a/func_test.go +++ b/func_test.go @@ -15,6 +15,9 @@ package sqlite_test import ( + "bytes" + "errors" + "strings" "testing" "crawshaw.io/sqlite" @@ -123,3 +126,114 @@ func TestAggFunc(t *testing.T) { } stmt.Finalize() } + +func TestBlobFunc(t *testing.T) { + c, err := sqlite.OpenConn(":memory:", 0) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := c.Close(); err != nil { + t.Error(err) + } + }() + + xFunc := func(ctx sqlite.Context, values ...sqlite.Value) { + var buf bytes.Buffer + for _, v := range values { + buf.Write(v.Blob()) + } + ctx.ResultBlob(buf.Bytes()) + } + if err := c.CreateFunction("blobcat", true, -1, xFunc, nil, nil); err != nil { + t.Fatal(err) + } + + stmt, _, err := c.PrepareTransient("SELECT blobcat(x'ff00',x'00ba');") + if err != nil { + t.Fatal(err) + } + if _, err := stmt.Step(); err != nil { + t.Fatal(err) + } + got := make([]byte, 4) + want := []byte{0xFF, 0x00, 0x00, 0xBA} + if stmt.ColumnBytes(0, got) != len(want) || !bytes.Equal(got, want) { + t.Errorf("blobcat(x'ff00',x'00ba')='%x', want '%x'", got, want) + } + stmt.Finalize() +} + +func TestStringFunc(t *testing.T) { + c, err := sqlite.OpenConn(":memory:", 0) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := c.Close(); err != nil { + t.Error(err) + } + }() + + xFunc := func(ctx sqlite.Context, values ...sqlite.Value) { + var buf strings.Builder + for _, v := range values { + buf.WriteString(v.Text()) + } + ctx.ResultText(buf.String()) + } + if err := c.CreateFunction("strcat", true, -1, xFunc, nil, nil); err != nil { + t.Fatal(err) + } + + stmt, _, err := c.PrepareTransient("SELECT strcat('str','','cat');") + if err != nil { + t.Fatal(err) + } + if _, err := stmt.Step(); err != nil { + t.Fatal(err) + } + if got, want := stmt.ColumnText(0), "strcat"; got != want { + t.Errorf("strcat('str','','cat')='%s', want '%s'", got, want) + } + stmt.Finalize() +} + +func TestErrorFunc(t *testing.T) { + c, err := sqlite.OpenConn(":memory:", 0) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := c.Close(); err != nil { + t.Error(err) + } + }() + + nilValueError := errors.New("nil value encountered") + xFunc := func(ctx sqlite.Context, values ...sqlite.Value) { + if values[0].Type() == sqlite.SQLITE_NULL { + ctx.ResultError(nilValueError) + } else { + ctx.ResultValue(values[0]) + } + } + + if err := c.CreateFunction("rejectnull", true, 1, xFunc, nil, nil); err != nil { + t.Fatal(err) + } + stmt, _, err := c.PrepareTransient("SELECT rejectnull(NULL);") + if err != nil { + t.Fatal(err) + } + + _, err = stmt.Step() + if err == nil { + t.Fatal("rejectnull(NULL) failed to produce an error") + } + if sqlErr, ok := err.(sqlite.Error); !ok || sqlErr.Msg != nilValueError.Error() { + t.Fatal("Error does not match expected description") + } + + stmt.Finalize() +}