Skip to content

Commit

Permalink
Add parser.SplitWithScanner() and parser.SplitWithParser() (#97)
Browse files Browse the repository at this point in the history
  • Loading branch information
francoislarochelle authored Dec 2, 2023
1 parent 7972fca commit 3c8cb1b
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 0 deletions.
50 changes: 50 additions & 0 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ uint64_t pg_query_hash_xxh3_64(void *data, size_t len, size_t seed) {
import "C"

import (
"strings"
"unsafe"
)

Expand Down Expand Up @@ -180,6 +181,55 @@ func Normalize(input string) (result string, err error) {
return
}

// SplitWithScanner splits input into individual statements using the C
// function pg_query_split_with_scanner. Each returned element is a substring
// of input; when trimSpace is true, strings.TrimSpace is applied to it.
// If the C result carries an error, it is converted with newPgQueryError and
// returned with a nil result.
func SplitWithScanner(input string, trimSpace bool) (result []string, err error) {
	// The C copy of the input and the C split result are both released when
	// this function returns; the returned strings reference the Go input only.
	cInput := C.CString(input)
	defer C.free(unsafe.Pointer(cInput))

	split := C.pg_query_split_with_scanner(cInput)
	defer C.pg_query_free_split_result(split)

	if split.error == nil {
		return handleSplitResult(input, trimSpace, split), nil
	}
	return nil, newPgQueryError(split.error)
}

// SplitWithParser splits input into individual statements using the C
// function pg_query_split_with_parser. Each returned element is a substring
// of input; when trimSpace is true, strings.TrimSpace is applied to it.
// If the C result carries an error, it is converted with newPgQueryError and
// returned with a nil result.
func SplitWithParser(input string, trimSpace bool) (result []string, err error) {
	// The C copy of the input and the C split result are both released when
	// this function returns; the returned strings reference the Go input only.
	cInput := C.CString(input)
	defer C.free(unsafe.Pointer(cInput))

	split := C.pg_query_split_with_parser(cInput)
	defer C.pg_query_free_split_result(split)

	if split.error == nil {
		return handleSplitResult(input, trimSpace, split), nil
	}
	return nil, newPgQueryError(split.error)
}

// handleSplitResult converts a C split result into Go strings. For each of
// the resultC.n_stmts entries it slices input from stmt_location to
// stmt_location+stmt_len, optionally applies strings.TrimSpace, and appends
// the piece to result. The result stays nil when there are no statements.
func handleSplitResult(input string, trimSpace bool, resultC C.PgQuerySplitResult) (result []string) {
	// resultC.stmts is a C array of pointers, so consecutive entries are one
	// pointer width apart.
	const stride = unsafe.Sizeof((*C.PgQuerySplitStmt)(nil))
	base := unsafe.Pointer(resultC.stmts)
	count := int(resultC.n_stmts)

	for idx := 0; idx < count; idx++ {
		// Keep the uintptr round-trip inside a single expression, as the
		// unsafe.Pointer conversion rules require.
		entry := *(**C.PgQuerySplitStmt)(unsafe.Pointer(uintptr(base) + uintptr(idx)*stride))

		start := entry.stmt_location
		piece := input[start : start+entry.stmt_len]
		if trimSpace {
			piece = strings.TrimSpace(piece)
		}
		result = append(result, piece)
	}
	return result
}

// FingerprintToUInt64 - Fingerprint the passed SQL statement using the C extension and returns result as uint64
func FingerprintToUInt64(input string) (result uint64, err error) {
inputC := C.CString(input)
Expand Down
8 changes: 8 additions & 0 deletions pg_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,11 @@ func FingerprintToUInt64(input string) (result uint64, err error) {
func HashXXH3_64(input []byte, seed uint64) (result uint64) {
// Thin public wrapper: forwards input and seed unchanged to the parser package.
return parser.HashXXH3_64(input, seed)
}

// SplitWithScanner splits input into its individual statements by delegating
// to parser.SplitWithScanner. The trimSpace flag is passed through unchanged;
// the result and error are returned exactly as the parser package produced them.
func SplitWithScanner(input string, trimSpace bool) (result []string, err error) {
return parser.SplitWithScanner(input, trimSpace)
}

// SplitWithParser splits input into its individual statements by delegating
// to parser.SplitWithParser. The trimSpace flag is passed through unchanged;
// the result and error are returned exactly as the parser package produced them.
func SplitWithParser(input string, trimSpace bool) (result []string, err error) {
return parser.SplitWithParser(input, trimSpace)
}
115 changes: 115 additions & 0 deletions split_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
//go:build cgo
// +build cgo

package pg_query_test

import (
"testing"

pg_query "github.com/pganalyze/pg_query_go/v4"
)

// splitTests is the table of cases run by TestSplit. The same three scenarios
// (basic split, procedure body, basic split without trimming) are listed once
// for pg_query.SplitWithParser and once for pg_query.SplitWithScanner.
var splitTests = []struct {
// name is the subtest name passed to t.Run.
name string
// splitFunc is the split implementation under test.
splitFunc func(string, bool) ([]string, error)
// input is the SQL text handed to splitFunc.
input string
// trimSpace is forwarded to splitFunc as its second argument.
trimSpace bool
// expected lists the statements splitFunc should return, in order.
expected []string
}{
{
name: "splitWithParser - basic split",
splitFunc: pg_query.SplitWithParser,
input: "select * from a;select * from b;",
trimSpace: true,
expected: []string{
"select * from a",
"select * from b",
},
},
{
name: "splitWithParser - procedure",
splitFunc: pg_query.SplitWithParser,
input: splitTestInput,
trimSpace: true,
expected: []string{
splitExpected1,
splitExpected2,
},
},
{
name: "splitWithParser - basic split, no trim",
splitFunc: pg_query.SplitWithParser,
input: " select * from a;select * from b;",
trimSpace: false,
expected: []string{
// Leading whitespace is expected to survive when trimSpace is false.
" select * from a",
"select * from b",
},
},
{
name: "splitWithScanner - basic split",
splitFunc: pg_query.SplitWithScanner,
input: "select * from a;select * from b;",
trimSpace: true,
expected: []string{
"select * from a",
"select * from b",
},
},
{
name: "splitWithScanner - procedure",
splitFunc: pg_query.SplitWithScanner,
input: splitTestInput,
trimSpace: true,
expected: []string{
splitExpected1,
splitExpected2,
},
},
{
name: "splitWithScanner - basic split, no trim",
splitFunc: pg_query.SplitWithScanner,
input: " select * from a;select * from b;",
trimSpace: false,
expected: []string{
// Leading whitespace is expected to survive when trimSpace is false.
" select * from a",
"select * from b",
},
},
}

// Fixtures for the "procedure" cases in splitTests.
var (
// splitTestInput holds two statements; the second is a function definition
// whose dollar-quoted body contains semicolons of its own.
splitTestInput = `UPDATE client SET name = "John Doe" WHERE id = 1;
CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$
BEGIN
RETURN i + 1;
END;
$$ LANGUAGE plpgsql;
`
// splitExpected1 is the first statement of splitTestInput without its
// terminating semicolon.
splitExpected1 = `UPDATE client SET name = "John Doe" WHERE id = 1`
// splitExpected2 is the whole function definition, inner semicolons
// included, without the final terminating semicolon.
splitExpected2 = `CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$
BEGIN
RETURN i + 1;
END;
$$ LANGUAGE plpgsql`
)

// TestSplit runs every case in splitTests as a subtest: it calls the case's
// splitFunc with the case's input and trimSpace flag and compares the
// returned statements with the expected ones, element by element.
func TestSplit(t *testing.T) {
	for _, test := range splitTests {
		t.Run(test.name, func(t *testing.T) {
			actuals, err := test.splitFunc(test.input, test.trimSpace)
			if err != nil {
				// Nothing meaningful can be compared after a failed split.
				t.Fatalf("unexpected error: %v", err)
			}
			if len(actuals) != len(test.expected) {
				// Stop here: indexing test.expected below would panic if more
				// statements were returned than expected.
				t.Fatalf("unexpected number of results: expected %d, actual %d: %q",
					len(test.expected), len(actuals), actuals)
			}
			for i, actual := range actuals {
				if actual != test.expected[i] {
					t.Errorf("statement %d: expected: [%s], actual: [%s]", i, test.expected[i], actual)
				}
			}
		})
	}
}

0 comments on commit 3c8cb1b

Please sign in to comment.