diff --git a/sqlite3.go b/sqlite3.go index ce985ec8..70706e2e 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -137,6 +137,57 @@ _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_ } #endif +static int _sqlite3_prepare_v2(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, int *oBytes) { + const char *tail = NULL; + int rv = _sqlite3_prepare_v2_internal(db, zSql, nBytes, ppStmt, &tail); + if (rv != SQLITE_OK) { + return rv; + } + if (tail == NULL) { + return rv; // NB: this should not happen + } + // Set oBytes to the number of bytes consumed instead of using the **pzTail + // out param since that requires storing a Go pointer in a C pointer, which + // is not allowed by CGO and will cause runtime.cgoCheckPointer to fail. + *oBytes = tail - zSql; + return rv; +} + +// _sqlite3_exec_no_args executes all of the statements in sql. None of the +// statements are allowed to have positional arguments. +int _sqlite3_exec_no_args(sqlite3 *db, const char *zSql, int nBytes, + int64_t *rowid, int64_t *changes) { + sqlite3_stmt *stmt; + const char *tail = NULL; + while (*zSql && nBytes > 0) { + stmt = NULL; + int rv = sqlite3_prepare_v2(db, zSql, nBytes, &stmt, &tail); + if (rv != SQLITE_OK) { + return rv; + } + + do { + rv = sqlite3_step(stmt); + } while (rv == SQLITE_ROW); + + *rowid = sqlite3_last_insert_rowid(db); + // We only record the number of changes made by the last statement. + *changes = sqlite3_changes64(db); + + if (rv != SQLITE_OK && rv != SQLITE_DONE) { + sqlite3_finalize(stmt); + return rv; + } + rv = sqlite3_finalize(stmt); + if (rv != SQLITE_OK) { + return rv; + } + nBytes -= tail - zSql; + zSql = tail; + } + return SQLITE_OK; +} + void _sqlite3_result_text(sqlite3_context* ctx, const char* s) { sqlite3_result_text(ctx, s, -1, &free); } @@ -858,54 +909,122 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err } func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - start := 0 + // Trim the query. This is mostly important for getting rid + // of any trailing space. + query = strings.TrimSpace(query) + if len(args) > 0 { + return c.execArgs(ctx, query, args) + } + return c.execNoArgs(ctx, query) +} + +func (c *SQLiteConn) execArgs(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + var ( + stmtArgs []driver.NamedValue + start int + s SQLiteStmt // escapes to the heap so reuse it + sz C.int // number of query bytes consumed: escapes to the heap + ) for { - s, err := c.prepare(ctx, query) - if err != nil { - return nil, err + s = SQLiteStmt{c: c} // reset + sz = 0 + rv := C._sqlite3_prepare_v2(c.db, (*C.char)(unsafe.Pointer(stringData(query))), + C.int(len(query)), &s.s, &sz) + if rv != C.SQLITE_OK { + return nil, c.lastError() } + query = strings.TrimSpace(query[sz:]) + var res driver.Result - if s.(*SQLiteStmt).s != nil { - stmtArgs := make([]driver.NamedValue, 0, len(args)) + if s.s != nil { na := s.NumInput() if len(args)-start < na { - s.Close() + s.finalize() return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)) } // consume the number of arguments used in the current // statement and append all named arguments not // contained therein - if len(args[start:start+na]) > 0 { - stmtArgs = append(stmtArgs, args[start:start+na]...) - for i := range args { - if (i < start || i >= na) && args[i].Name != "" { - stmtArgs = append(stmtArgs, args[i]) - } - } - for i := range stmtArgs { - stmtArgs[i].Ordinal = i + 1 + if stmtArgs == nil { + stmtArgs = make([]driver.NamedValue, 0, na) + } + stmtArgs = append(stmtArgs[:0], args[start:start+na]...) + for i := range args { + if (i < start || i >= na) && args[i].Name != "" { + stmtArgs = append(stmtArgs, args[i]) } } - res, err = s.(*SQLiteStmt).exec(ctx, stmtArgs) + for i := range stmtArgs { + stmtArgs[i].Ordinal = i + 1 + } + var err error + res, err = s.exec(ctx, stmtArgs) if err != nil && err != driver.ErrSkip { - s.Close() + s.finalize() return nil, err } start += na } - tail := s.(*SQLiteStmt).t - s.Close() - if tail == "" { + s.finalize() + if len(query) == 0 { if res == nil { // https://github.com/mattn/go-sqlite3/issues/963 res = &SQLiteResult{0, 0} } return res, nil } - query = tail } } +// execNoArgsSync processes every SQL statement in query. All processing occurs +// in C code, which reduces the overhead of CGO calls. +func (c *SQLiteConn) execNoArgsSync(query string) (_ driver.Result, err error) { + var rowid, changes C.int64_t + rv := C._sqlite3_exec_no_args(c.db, (*C.char)(unsafe.Pointer(stringData(query))), + C.int(len(query)), &rowid, &changes) + if rv != C.SQLITE_OK { + err = c.lastError() + } + return &SQLiteResult{id: int64(rowid), changes: int64(changes)}, err +} + +func (c *SQLiteConn) execNoArgs(ctx context.Context, query string) (driver.Result, error) { + done := ctx.Done() + if done == nil { + return c.execNoArgsSync(query) + } + + // Fast check if the Context is cancelled + if err := ctx.Err(); err != nil { + return nil, err + } + + ch := make(chan struct{}) + defer close(ch) + go func() { + select { + case <-done: + C.sqlite3_interrupt(c.db) + // Wait until signaled. We need to ensure that this goroutine + // will not call interrupt after this method returns, which is + // why we can't check if only done is closed when waiting below. + <-ch + case <-ch: + } + }() + + res, err := c.execNoArgsSync(query) + + // Stop the goroutine and make sure we're at a point where + // sqlite3_interrupt cannot be called again. + ch <- struct{}{} + + if isInterruptErr(err) { + err = ctx.Err() + } + return res, err +} + // Query implements Queryer. func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) { list := make([]driver.NamedValue, len(args)) @@ -1914,6 +2033,13 @@ func (s *SQLiteStmt) Close() error { return nil } +func (s *SQLiteStmt) finalize() { + if s.s != nil { + C.sqlite3_finalize(s.s) + s.s = nil + } +} + // NumInput return a number of parameters. func (s *SQLiteStmt) NumInput() int { return int(C.sqlite3_bind_parameter_count(s.s)) diff --git a/unsafe_go120.go b/unsafe_go120.go new file mode 100644 index 00000000..95d673ed --- /dev/null +++ b/unsafe_go120.go @@ -0,0 +1,17 @@ +//go:build !go1.21 +// +build !go1.21 + +package sqlite3 + +import "unsafe" + +// stringData is a safe version of unsafe.StringData that handles empty strings. +func stringData(s string) *byte { + if len(s) != 0 { + b := *(*[]byte)(unsafe.Pointer(&s)) + return &b[0] + } + // The return value of unsafe.StringData + // is unspecified if the string is empty. + return &placeHolder[0] +} diff --git a/unsafe_go121.go b/unsafe_go121.go new file mode 100644 index 00000000..b9c00a12 --- /dev/null +++ b/unsafe_go121.go @@ -0,0 +1,23 @@ +//go:build go1.21 +// +build go1.21 + +// The unsafe.StringData function was made available in Go 1.20 but it +// was not until Go 1.21 that Go was changed to interpret the Go version +// in go.mod (1.19 as of writing this) as the minimum version required +// instead of the exact version. +// +// See: https://github.com/golang/go/issues/59033 + +package sqlite3 + +import "unsafe" + +// stringData is a safe version of unsafe.StringData that handles empty strings. +func stringData(s string) *byte { + if len(s) != 0 { + return unsafe.StringData(s) + } + // The return value of unsafe.StringData + // is unspecified if the string is empty. + return &placeHolder[0] +}