diff --git a/parser/parser.go b/parser/parser.go
index 8cc4492c..8ebd85ad 100644
--- a/parser/parser.go
+++ b/parser/parser.go
@@ -30,6 +30,7 @@ uint64_t pg_query_hash_xxh3_64(void *data, size_t len, size_t seed) {
 import "C"
 
 import (
+	"strings"
 	"unsafe"
 )
 
@@ -180,6 +181,72 @@ func Normalize(input string) (result string, err error) {
 	return
 }
 
+// SplitWithScanner - Split the passed SQL input into statements using the C extension's pg_query_split_with_scanner.
+// If trimSpace is set, leading and trailing whitespace is removed from each returned statement.
+func SplitWithScanner(input string, trimSpace bool) (result []string, err error) {
+	inputC := C.CString(input)
+	defer C.free(unsafe.Pointer(inputC))
+
+	resultC := C.pg_query_split_with_scanner(inputC)
+	defer C.pg_query_free_split_result(resultC)
+
+	if resultC.error != nil {
+		err = newPgQueryError(resultC.error)
+		return
+	}
+
+	result = handleSplitResult(input, trimSpace, resultC)
+	return
+}
+
+// SplitWithParser - Split the passed SQL input into statements using the C extension's pg_query_split_with_parser.
+// If trimSpace is set, leading and trailing whitespace is removed from each returned statement.
+func SplitWithParser(input string, trimSpace bool) (result []string, err error) {
+	inputC := C.CString(input)
+	defer C.free(unsafe.Pointer(inputC))
+
+	resultC := C.pg_query_split_with_parser(inputC)
+	defer C.pg_query_free_split_result(resultC)
+
+	if resultC.error != nil {
+		err = newPgQueryError(resultC.error)
+		return
+	}
+
+	result = handleSplitResult(input, trimSpace, resultC)
+	return
+}
+
+// handleSplitResult - Convert the statements of a C split result into substrings of input.
+// Each statement's stmt_location and stmt_len are used as byte offsets into input; if trimSpace
+// is set, leading and trailing whitespace is removed. Returns nil when there are no statements.
+func handleSplitResult(input string, trimSpace bool, resultC C.PgQuerySplitResult) (result []string) {
+	n := int(resultC.n_stmts)
+	if n <= 0 {
+		return nil
+	}
+
+	// The statement count is known up front, so allocate the slice once.
+	result = make([]string, 0, n)
+
+	stmts := (**C.PgQuerySplitStmt)(unsafe.Pointer(resultC.stmts))
+	for i := 0; i < n; i++ {
+		// Step to the i-th entry of the C array of statement pointers.
+		stmtptr := (**C.PgQuerySplitStmt)(unsafe.Pointer(uintptr(unsafe.Pointer(stmts)) + uintptr(i)*unsafe.Sizeof(*stmts)))
+		stmt := **stmtptr
+
+		start := int(stmt.stmt_location)
+		end := start + int(stmt.stmt_len)
+		stmtStr := input[start:end]
+		if trimSpace {
+			stmtStr = strings.TrimSpace(stmtStr)
+		}
+
+		result = append(result, stmtStr)
+	}
+	return
+}
+
 // FingerprintToUInt64 - Fingerprint the passed SQL statement using the C extension and returns result as uint64
 func FingerprintToUInt64(input string) (result uint64, err error) {
 	inputC := C.CString(input)
diff --git a/pg_query.go b/pg_query.go
index d787a8a4..e5ad6524 100644
--- a/pg_query.go
+++ b/pg_query.go
@@ -71,3 +71,15 @@ func FingerprintToUInt64(input string) (result uint64, err error) {
 func HashXXH3_64(input []byte, seed uint64) (result uint64) {
 	return parser.HashXXH3_64(input, seed)
 }
+
+// SplitWithScanner - Split the passed SQL input into statements via parser.SplitWithScanner.
+// If trimSpace is set, leading and trailing whitespace is removed from each returned statement.
+func SplitWithScanner(input string, trimSpace bool) (result []string, err error) {
+	return parser.SplitWithScanner(input, trimSpace)
+}
+
+// SplitWithParser - Split the passed SQL input into statements via parser.SplitWithParser.
+// If trimSpace is set, leading and trailing whitespace is removed from each returned statement.
+func SplitWithParser(input string, trimSpace bool) (result []string, err error) {
+	return parser.SplitWithParser(input, trimSpace)
+}
diff --git a/split_test.go b/split_test.go
new file mode 100644
index 00000000..919f266c
--- /dev/null
+++ b/split_test.go
@@ -0,0 +1,124 @@
+//go:build cgo
+// +build cgo
+
+package pg_query_test
+
+import (
+	"testing"
+
+	pg_query "github.com/pganalyze/pg_query_go/v4"
+)
+
+// splitTests lists the inputs and expected statements for TestSplit, covering
+// both split functions with and without whitespace trimming.
+var splitTests = []struct {
+	name      string
+	splitFunc func(string, bool) ([]string, error)
+	input     string
+	trimSpace bool
+	expected  []string
+}{
+	{
+		name:      "splitWithParser - basic split",
+		splitFunc: pg_query.SplitWithParser,
+		input:     "select * from a;select * from b;",
+		trimSpace: true,
+		expected: []string{
+			"select * from a",
+			"select * from b",
+		},
+	},
+	{
+		name:      "splitWithParser - procedure",
+		splitFunc: pg_query.SplitWithParser,
+		input:     splitTestInput,
+		trimSpace: true,
+		expected: []string{
+			splitExpected1,
+			splitExpected2,
+		},
+	},
+	{
+		name:      "splitWithParser - basic split, no trim",
+		splitFunc: pg_query.SplitWithParser,
+		input:     " select * from a;select * from b;",
+		trimSpace: false,
+		expected: []string{
+			" select * from a",
+			"select * from b",
+		},
+	},
+	{
+		name:      "splitWithScanner - basic split",
+		splitFunc: pg_query.SplitWithScanner,
+		input:     "select * from a;select * from b;",
+		trimSpace: true,
+		expected: []string{
+			"select * from a",
+			"select * from b",
+		},
+	},
+	{
+		name:      "splitWithScanner - procedure",
+		splitFunc: pg_query.SplitWithScanner,
+		input:     splitTestInput,
+		trimSpace: true,
+		expected: []string{
+			splitExpected1,
+			splitExpected2,
+		},
+	},
+	{
+		name:      "splitWithScanner - basic split, no trim",
+		splitFunc: pg_query.SplitWithScanner,
+		input:     " select * from a;select * from b;",
+		trimSpace: false,
+		expected: []string{
+			" select * from a",
+			"select * from b",
+		},
+	},
+}
+
+// Fixtures for the "procedure" cases: an UPDATE followed by a function
+// definition whose dollar-quoted body itself contains semicolons.
+var (
+	splitTestInput = `UPDATE client SET name = "John Doe" WHERE id = 1;
+
+CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$
+	BEGIN
+		RETURN i + 1;
+	END;
+$$ LANGUAGE plpgsql;
+`
+	splitExpected1 = `UPDATE client SET name = "John Doe" WHERE id = 1`
+	splitExpected2 = `CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$
+	BEGIN
+		RETURN i + 1;
+	END;
+$$ LANGUAGE plpgsql`
+)
+
+// TestSplit runs every splitTests case through its split function and compares
+// the returned statements with the expected ones.
+func TestSplit(t *testing.T) {
+	for _, test := range splitTests {
+		t.Run(test.name, func(t *testing.T) {
+			actuals, err := test.splitFunc(test.input, test.trimSpace)
+			if err != nil {
+				// A failed split leaves nothing meaningful to compare.
+				t.Fatal(err)
+			}
+			if len(actuals) != len(test.expected) {
+				// Abort instead of continuing: indexing test.expected below would
+				// panic if more statements were returned than expected.
+				t.Fatalf("unexpected number of results: expected %d, actual %d", len(test.expected), len(actuals))
+			}
+			for i, actual := range actuals {
+				if actual != test.expected[i] {
+					t.Errorf("expected: [%s], actual: [%s]", test.expected[i], actual)
+				}
+			}
+		})
+	}
+}