diff --git a/sqlite3.go b/sqlite3.go index 6d6a30d6..91e428c8 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -30,7 +30,6 @@ package sqlite3 #endif #include #include -#include #ifdef __CYGWIN__ # include @@ -91,16 +90,6 @@ _sqlite3_exec(sqlite3* db, const char* pcmd, long long* rowid, long long* change return rv; } -static const char * -_trim_leading_spaces(const char *str) { - if (str) { - while (isspace(*str)) { - str++; - } - } - return str; -} - #ifdef SQLITE_ENABLE_UNLOCK_NOTIFY extern int _sqlite3_step_blocking(sqlite3_stmt *stmt); extern int _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* changes); @@ -121,11 +110,7 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan static int _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail) { - int rv = _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail); - if (pzTail) { - *pzTail = _trim_leading_spaces(*pzTail); - } - return rv; + return _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail); } #else @@ -148,12 +133,9 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan static int _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail) { - int rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail); - if (pzTail) { - *pzTail = _trim_leading_spaces(*pzTail); - } - return rv; + return sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail); } + #endif void _sqlite3_result_text(sqlite3_context* ctx, const char* s) { @@ -951,46 +933,44 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name op := pquery // original pointer defer C.free(unsafe.Pointer(op)) - var stmtArgs []driver.NamedValue var tail *C.char - s := new(SQLiteStmt) // escapes to the heap so reuse it - start := 0 - for { - *s = SQLiteStmt{c: c, cls: true} // reset - rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail) - if rv != C.SQLITE_OK { - return nil, c.lastError() + s := &SQLiteStmt{c: c, cls: true} + rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail) + if rv != C.SQLITE_OK { + return nil, c.lastError() + } + if s.s == nil { + return &SQLiteRows{s: &SQLiteStmt{closed: true}, closed: true}, nil + } + na := s.NumInput() + if n := len(args); n != na { + s.finalize() + if n < na { + return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)) } + return nil, fmt.Errorf("too many args to execute query: want %d got %d", na, len(args)) + } + rows, err := s.query(ctx, args) + if err != nil && err != driver.ErrSkip { + s.finalize() + return rows, err + } - na := s.NumInput() - if len(args)-start < na { - return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start) - } - // consume the number of arguments used in the current - // statement and append all named arguments not contained - // therein - stmtArgs = append(stmtArgs[:0], args[start:start+na]...) - for i := range args { - if (i < start || i >= na) && args[i].Name != "" { - stmtArgs = append(stmtArgs, args[i]) - } - } - for i := range stmtArgs { - stmtArgs[i].Ordinal = i + 1 - } - rows, err := s.query(ctx, stmtArgs) - if err != nil && err != driver.ErrSkip { - s.finalize() - return rows, err + // Consume the rest of the query + for pquery = tail; pquery != nil && *pquery != 0; pquery = tail { + var stmt *C.sqlite3_stmt + rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &stmt, &tail) + if rv != C.SQLITE_OK { + rows.Close() + return nil, c.lastError() } - start += na - if tail == nil || *tail == '\000' { - return rows, nil + if stmt != nil { + rows.Close() + return nil, errors.New("cannot insert multiple commands into a prepared statement") // WARN } - rows.Close() - s.finalize() - pquery = tail } + + return rows, nil } // Begin transaction. @@ -2044,7 +2024,7 @@ func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) { return s.query(context.Background(), list) } -func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { +func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (*SQLiteRows, error) { if err := s.bind(args); err != nil { return nil, err } diff --git a/sqlite3_test.go b/sqlite3_test.go index 2211ad1c..089cf21a 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -18,6 +18,7 @@ import ( "math/rand" "net/url" "os" + "path/filepath" "reflect" "regexp" "runtime" @@ -1080,67 +1081,158 @@ func TestExecer(t *testing.T) { defer db.Close() _, err = db.Exec(` - create table foo (id integer); -- one comment - insert into foo(id) values(?); - insert into foo(id) values(?); - insert into foo(id) values(?); -- another comment + CREATE TABLE foo (id INTEGER); -- one comment + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?); -- another comment `, 1, 2, 3) if err != nil { t.Error("Failed to call db.Exec:", err) } } -func TestQueryer(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) +func testQuery(t *testing.T, seed bool, test func(t *testing.T, db *sql.DB)) { + db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.sqlite3")) if err != nil { t.Fatal("Failed to open database:", err) } defer db.Close() - _, err = db.Exec(` - create table foo (id integer); - `) - if err != nil { - t.Error("Failed to call db.Query:", err) + if seed { + if _, err := db.Exec(`create table foo (id integer);`); err != nil { + t.Fatal(err) + } + _, err := db.Exec(` + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?); + `, 3, 2, 1) + if err != nil { + t.Fatal(err) + } } - _, err = db.Exec(` - insert into foo(id) values(?); - insert into foo(id) values(?); - insert into foo(id) values(?); - `, 3, 2, 1) - if err != nil { - t.Error("Failed to call db.Exec:", err) - } - rows, err := db.Query(` - select id from foo order by id; - `) - if err != nil { - t.Error("Failed to call db.Query:", err) - } - defer rows.Close() - n := 0 - for rows.Next() { - var id int - err = rows.Scan(&id) + // Capture panic so tests can continue + defer func() { + if e := recover(); e != nil { + buf := make([]byte, 32*1024) + n := runtime.Stack(buf, false) + t.Fatalf("\npanic: %v\n\n%s\n", e, buf[:n]) + } + }() + test(t, db) +} + +func testQueryValues(t *testing.T, query string, args ...interface{}) []interface{} { + var values []interface{} + testQuery(t, true, func(t *testing.T, db *sql.DB) { + rows, err := db.Query(query, args...) if err != nil { - t.Error("Failed to db.Query:", err) + t.Fatal(err) } - if id != n+1 { - t.Error("Failed to db.Query: not matched results") + if rows == nil { + t.Fatal("nil rows") } - n = n + 1 + for i := 0; rows.Next(); i++ { + if i > 1_000 { + t.Fatal("To many iterations of rows.Next():", i) + } + var v interface{} + if err := rows.Scan(&v); err != nil { + t.Fatal(err) + } + values = append(values, v) + } + if err := rows.Err(); err != nil { + t.Fatal(err) + } + if err := rows.Close(); err != nil { + t.Fatal(err) + } + }) + return values +} + +func TestQuery(t *testing.T) { + queries := []struct { + query string + args []interface{} + }{ + {"SELECT id FROM foo ORDER BY id;", nil}, + {"SELECT id FROM foo WHERE id != ? ORDER BY id;", []interface{}{4}}, + {"SELECT id FROM foo WHERE id IN (?, ?, ?) ORDER BY id;", []interface{}{1, 2, 3}}, + + // Comments + {"SELECT id FROM foo ORDER BY id; -- comment", nil}, + {"SELECT id FROM foo ORDER BY id;\n -- comment\n", nil}, + { + `-- FOO + SELECT id FROM foo ORDER BY id; -- BAR + /* BAZ */`, + nil, + }, } - if err := rows.Err(); err != nil { - t.Errorf("Post-scan failed: %v\n", err) + want := []interface{}{ + int64(1), + int64(2), + int64(3), + } + for _, q := range queries { + t.Run("", func(t *testing.T) { + got := testQueryValues(t, q.query, q.args...) + if !reflect.DeepEqual(got, want) { + t.Fatalf("Query(%q, %v) = %v; want: %v", q.query, q.args, got, want) + } + }) } - if n != 3 { - t.Errorf("Expected 3 rows but retrieved %v", n) +} + +func TestQueryNoSQL(t *testing.T) { + got := testQueryValues(t, "") + if got != nil { + t.Fatalf("Query(%q, %v) = %v; want: %v", "", nil, got, nil) } } +func testQueryError(t *testing.T, query string, args ...interface{}) { + testQuery(t, true, func(t *testing.T, db *sql.DB) { + rows, err := db.Query(query, args...) + if err == nil { + t.Error("Expected an error got:", err) + } + if rows != nil { + t.Error("Returned rows should be nil on error!") + // Attempt to iterate over rows to make sure they don't panic. + for i := 0; rows.Next(); i++ { + if i > 1_000 { + t.Fatal("To many iterations of rows.Next():", i) + } + } + if err := rows.Err(); err != nil { + t.Error(err) + } + rows.Close() + } + }) +} + +func TestQueryNotEnoughArgs(t *testing.T) { + testQueryError(t, "SELECT FROM foo WHERE id = ? OR id = ?);", 1) +} + +func TestQueryTooManyArgs(t *testing.T) { + // TODO: test error message / kind + testQueryError(t, "SELECT FROM foo WHERE id = ?);", 1, 2) +} + +func TestQueryMultipleStatements(t *testing.T) { + testQueryError(t, "SELECT 1; SELECT 2;") +} + +func TestQueryInvalidTable(t *testing.T) { + testQueryError(t, "SELECT COUNT(*) FROM does_not_exist;") +} + func TestStress(t *testing.T) { tempFilename := TempFilename(t) defer os.Remove(tempFilename) @@ -2112,7 +2204,6 @@ var benchmarks = []testing.InternalBenchmark{ {Name: "BenchmarkRows", F: benchmarkRows}, {Name: "BenchmarkStmtRows", F: benchmarkStmtRows}, {Name: "BenchmarkExecStep", F: benchmarkExecStep}, - {Name: "BenchmarkQueryStep", F: benchmarkQueryStep}, } func (db *TestDB) mustExec(sql string, args ...any) sql.Result { @@ -2580,12 +2671,3 @@ func benchmarkExecStep(b *testing.B) { } } } - -func benchmarkQueryStep(b *testing.B) { - var i int - for n := 0; n < b.N; n++ { - if err := db.QueryRow(largeSelectStmt).Scan(&i); err != nil { - b.Fatal(err) - } - } -}