From c3a818d346a927c18469460bb18acb397f4f4301 Mon Sep 17 00:00:00 2001 From: Sean Linsley Date: Thu, 29 Aug 2024 13:22:08 -0500 Subject: [PATCH] Add function to normalize only utility statements (#116) --- Makefile | 2 +- normalize_test.go | 26 ++++++++++++++++++++++ parser/include/pg_query.h | 1 + parser/parser.go | 18 ++++++++++++++++ parser/pg_query_normalize.c | 43 ++++++++++++++++++++++++++++++++++++- parser/pg_query_scan.c | 2 +- pg_query.go | 5 +++++ 7 files changed, 94 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 7bf2582d..5ffdee51 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ benchmark: # --- Below only needed for releasing new versions -LIB_PG_QUERY_TAG = 16-5.1.0 +LIB_PG_QUERY_TAG = 43bad3cbcd1a70a30494b64f464c3f60579884ed root_dir := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) LIB_TMPDIR = $(root_dir)/tmp diff --git a/normalize_test.go b/normalize_test.go index 1bdd3731..9de758c2 100644 --- a/normalize_test.go +++ b/normalize_test.go @@ -68,3 +68,29 @@ func TestNormalizeError(t *testing.T) { } } } + +var normalizeUtilityTests = []struct { + input string + expected string +}{ + { + "SELECT 1", + "SELECT 1", + }, + { + "CREATE ROLE postgres PASSWORD 'xyz'", + "CREATE ROLE postgres PASSWORD $1", + }, +} + +func TestNormalizeUtility(t *testing.T) { + for _, test := range normalizeUtilityTests { + actual, err := pg_query.NormalizeUtility(test.input) + + if err != nil { + t.Errorf("Normalize(%s)\nerror %s\n\n", test.input, err) + } else if !reflect.DeepEqual(actual, test.expected) { + t.Errorf("Normalize(%s)\nexpected %s\nactual %s\n\n", test.input, test.expected, actual) + } + } +} diff --git a/parser/include/pg_query.h b/parser/include/pg_query.h index c7d5701b..7f34c41b 100644 --- a/parser/include/pg_query.h +++ b/parser/include/pg_query.h @@ -96,6 +96,7 @@ extern "C" { #endif PgQueryNormalizeResult pg_query_normalize(const char* input); +PgQueryNormalizeResult pg_query_normalize_utility(const char* input); 
PgQueryScanResult pg_query_scan(const char* input); PgQueryParseResult pg_query_parse(const char* input); PgQueryParseResult pg_query_parse_opts(const char* input, int parser_options); diff --git a/parser/parser.go b/parser/parser.go index c866a124..e2da4ff3 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -175,6 +175,24 @@ func Normalize(input string) (result string, err error) { return } +// Normalize the passed utility statement to replace constant values with $n parameter references +func NormalizeUtility(input string) (result string, err error) { + inputC := C.CString(input) + defer C.free(unsafe.Pointer(inputC)) + + resultC := C.pg_query_normalize_utility(inputC) + defer C.pg_query_free_normalize_result(resultC) + + if resultC.error != nil { + err = newPgQueryError(resultC.error) + return + } + + result = C.GoString(resultC.normalized_query) + + return +} + func SplitWithScanner(input string, trimSpace bool) (result []string, err error) { inputC := C.CString(input) defer C.free(unsafe.Pointer(inputC)) diff --git a/parser/pg_query_normalize.c b/parser/pg_query_normalize.c index bc15b01e..460493fd 100644 --- a/parser/pg_query_normalize.c +++ b/parser/pg_query_normalize.c @@ -48,6 +48,9 @@ typedef struct pgssConstLocations int *param_refs; int param_refs_buf_size; int param_refs_count; + + /* Should only utility statements be normalized? 
Set by pg_query_normalize_utility */ + bool normalize_utility_only; } pgssConstLocations; /* @@ -398,8 +401,10 @@ static bool const_record_walker(Node *node, pgssConstLocations *jstate) case T_RawStmt: return const_record_walker((Node *) ((RawStmt *) node)->stmt, jstate); case T_VariableSetStmt: + if (jstate->normalize_utility_only) return false; return const_record_walker((Node *) ((VariableSetStmt *) node)->args, jstate); case T_CopyStmt: + if (jstate->normalize_utility_only) return false; return const_record_walker((Node *) ((CopyStmt *) node)->query, jstate); case T_ExplainStmt: return const_record_walker((Node *) ((ExplainStmt *) node)->query, jstate); @@ -408,10 +413,13 @@ static bool const_record_walker(Node *node, pgssConstLocations *jstate) case T_AlterRoleStmt: return const_record_walker((Node *) ((AlterRoleStmt *) node)->options, jstate); case T_DeclareCursorStmt: + if (jstate->normalize_utility_only) return false; return const_record_walker((Node *) ((DeclareCursorStmt *) node)->query, jstate); case T_CreateFunctionStmt: + if (jstate->normalize_utility_only) return false; return const_record_walker((Node *) ((CreateFunctionStmt *) node)->options, jstate); case T_DoStmt: + if (jstate->normalize_utility_only) return false; return const_record_walker((Node *) ((DoStmt *) node)->args, jstate); case T_CreateSubscriptionStmt: record_matching_string(jstate, ((CreateSubscriptionStmt *) node)->conninfo); @@ -428,6 +436,7 @@ static bool const_record_walker(Node *node, pgssConstLocations *jstate) return false; case T_SelectStmt: { + if (jstate->normalize_utility_only) return false; SelectStmt *stmt = (SelectStmt *) node; ListCell *lc; List *fp_and_param_refs_list = NIL; @@ -540,6 +549,26 @@ static bool const_record_walker(Node *node, pgssConstLocations *jstate) return false; } + case T_MergeStmt: + { + if (jstate->normalize_utility_only) return false; + return raw_expression_tree_walker(node, const_record_walker, (void*) jstate); + } + case T_InsertStmt: + { + if 
(jstate->normalize_utility_only) return false; + return raw_expression_tree_walker(node, const_record_walker, (void*) jstate); + } + case T_UpdateStmt: + { + if (jstate->normalize_utility_only) return false; + return raw_expression_tree_walker(node, const_record_walker, (void*) jstate); + } + case T_DeleteStmt: + { + if (jstate->normalize_utility_only) return false; + return raw_expression_tree_walker(node, const_record_walker, (void*) jstate); + } default: { PG_TRY(); @@ -558,7 +587,7 @@ static bool const_record_walker(Node *node, pgssConstLocations *jstate) return false; } -PgQueryNormalizeResult pg_query_normalize(const char* input) +PgQueryNormalizeResult pg_query_normalize_ext(const char* input, bool normalize_utility_only) { MemoryContext ctx = NULL; PgQueryNormalizeResult result = {0}; @@ -588,6 +617,7 @@ PgQueryNormalizeResult pg_query_normalize(const char* input) jstate.param_refs = NULL; jstate.param_refs_buf_size = 0; jstate.param_refs_count = 0; + jstate.normalize_utility_only = normalize_utility_only; /* Walk tree and record const locations */ const_record_walker((Node *) tree, &jstate); @@ -621,6 +651,17 @@ PgQueryNormalizeResult pg_query_normalize(const char* input) return result; } +PgQueryNormalizeResult pg_query_normalize(const char* input) +{ + return pg_query_normalize_ext(input, false); +} + + +PgQueryNormalizeResult pg_query_normalize_utility(const char* input) +{ + return pg_query_normalize_ext(input, true); +} + void pg_query_free_normalize_result(PgQueryNormalizeResult result) { if (result.error) { diff --git a/parser/pg_query_scan.c b/parser/pg_query_scan.c index 0a53d4f7..1d0e052e 100644 --- a/parser/pg_query_scan.c +++ b/parser/pg_query_scan.c @@ -93,7 +93,7 @@ PgQueryScanResult pg_query_scan(const char* input) output_tokens[i] = malloc(sizeof(PgQuery__ScanToken)); pg_query__scan_token__init(output_tokens[i]); output_tokens[i]->start = yylloc; - if (tok == SCONST || tok == BCONST || tok == XCONST || tok == IDENT || tok == C_COMMENT) { + 
if (tok == SCONST || tok == USCONST || tok == BCONST || tok == XCONST || tok == IDENT || tok == UIDENT || tok == C_COMMENT) { output_tokens[i]->end = yyextra.yyllocend; } else { output_tokens[i]->end = yylloc + ((struct yyguts_t*) yyscanner)->yyleng_r; diff --git a/pg_query.go b/pg_query.go index fc9e6372..996ea047 100644 --- a/pg_query.go +++ b/pg_query.go @@ -57,6 +57,11 @@ func Normalize(input string) (result string, err error) { return parser.Normalize(input) } +// Normalize the passed utility statement to replace constant values with $n parameter references +func NormalizeUtility(input string) (result string, err error) { + return parser.NormalizeUtility(input) +} + // Fingerprint - Fingerprint the passed SQL statement to a hex string func Fingerprint(input string) (result string, err error) { return parser.FingerprintToHexStr(input)