Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function to normalize only utility statements #116

Merged
merged 2 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ benchmark:

# --- Below only needed for releasing new versions

LIB_PG_QUERY_TAG = 16-5.1.0
LIB_PG_QUERY_TAG = 43bad3cbcd1a70a30494b64f464c3f60579884ed

root_dir := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST))))
LIB_TMPDIR = $(root_dir)/tmp
Expand Down
26 changes: 26 additions & 0 deletions normalize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,29 @@ func TestNormalizeError(t *testing.T) {
}
}
}

// normalizeUtilityTests pairs each input statement with the text that
// NormalizeUtility is expected to return for it. The plain SELECT is
// expected to come back unchanged, while the password constant in the
// CREATE ROLE utility statement is expected to become $1.
var normalizeUtilityTests = []struct {
	input    string
	expected string
}{
	{input: "SELECT 1", expected: "SELECT 1"},
	{
		input:    "CREATE ROLE postgres PASSWORD 'xyz'",
		expected: "CREATE ROLE postgres PASSWORD $1",
	},
}

// TestNormalizeUtility runs pg_query.NormalizeUtility over every entry in
// normalizeUtilityTests and reports an error when the call fails or when the
// returned text differs from the expected normalized statement.
func TestNormalizeUtility(t *testing.T) {
	for _, test := range normalizeUtilityTests {
		actual, err := pg_query.NormalizeUtility(test.input)
		if err != nil {
			// Name the function actually under test so failures are not
			// confused with those from the plain Normalize test.
			t.Errorf("NormalizeUtility(%s)\nerror %s\n\n", test.input, err)
			continue
		}

		// Both values are plain strings, so direct comparison is sufficient.
		if actual != test.expected {
			t.Errorf("NormalizeUtility(%s)\nexpected %s\nactual %s\n\n", test.input, test.expected, actual)
		}
	}
}
1 change: 1 addition & 0 deletions parser/include/pg_query.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ extern "C" {
#endif

PgQueryNormalizeResult pg_query_normalize(const char* input);
PgQueryNormalizeResult pg_query_normalize_utility(const char* input);
PgQueryScanResult pg_query_scan(const char* input);
PgQueryParseResult pg_query_parse(const char* input);
PgQueryParseResult pg_query_parse_opts(const char* input, int parser_options);
Expand Down
18 changes: 18 additions & 0 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,24 @@ func Normalize(input string) (result string, err error) {
return
}

// NormalizeUtility normalizes the passed statement via the C function
// pg_query_normalize_utility, which replaces constant values in utility
// statements with $n parameter references (for example, the password in
// CREATE ROLE ... PASSWORD 'xyz' becomes $1).
//
// It returns the normalized query text. If the C library reports an error,
// result is empty and err is the value built by newPgQueryError.
func NormalizeUtility(input string) (result string, err error) {
	// The C string is allocated on the C heap and must be released explicitly.
	inputC := C.CString(input)
	defer C.free(unsafe.Pointer(inputC))

	resultC := C.pg_query_normalize_utility(inputC)
	// The result is freed only after its contents have been converted to Go
	// values below.
	defer C.pg_query_free_normalize_result(resultC)

	if resultC.error != nil {
		return "", newPgQueryError(resultC.error)
	}

	return C.GoString(resultC.normalized_query), nil
}

func SplitWithScanner(input string, trimSpace bool) (result []string, err error) {
inputC := C.CString(input)
defer C.free(unsafe.Pointer(inputC))
Expand Down
43 changes: 42 additions & 1 deletion parser/pg_query_normalize.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ typedef struct pgssConstLocations
int *param_refs;
int param_refs_buf_size;
int param_refs_count;

/* Should only utility statements be normalized? Set by pg_query_normalize_utility */
bool normalize_utility_only;
} pgssConstLocations;

/*
Expand Down Expand Up @@ -398,8 +401,10 @@ static bool const_record_walker(Node *node, pgssConstLocations *jstate)
case T_RawStmt:
return const_record_walker((Node *) ((RawStmt *) node)->stmt, jstate);
case T_VariableSetStmt:
if (jstate->normalize_utility_only) return false;
return const_record_walker((Node *) ((VariableSetStmt *) node)->args, jstate);
case T_CopyStmt:
if (jstate->normalize_utility_only) return false;
return const_record_walker((Node *) ((CopyStmt *) node)->query, jstate);
case T_ExplainStmt:
return const_record_walker((Node *) ((ExplainStmt *) node)->query, jstate);
Expand All @@ -408,10 +413,13 @@ static bool const_record_walker(Node *node, pgssConstLocations *jstate)
case T_AlterRoleStmt:
return const_record_walker((Node *) ((AlterRoleStmt *) node)->options, jstate);
case T_DeclareCursorStmt:
if (jstate->normalize_utility_only) return false;
return const_record_walker((Node *) ((DeclareCursorStmt *) node)->query, jstate);
case T_CreateFunctionStmt:
if (jstate->normalize_utility_only) return false;
return const_record_walker((Node *) ((CreateFunctionStmt *) node)->options, jstate);
case T_DoStmt:
if (jstate->normalize_utility_only) return false;
return const_record_walker((Node *) ((DoStmt *) node)->args, jstate);
case T_CreateSubscriptionStmt:
record_matching_string(jstate, ((CreateSubscriptionStmt *) node)->conninfo);
Expand All @@ -428,6 +436,7 @@ static bool const_record_walker(Node *node, pgssConstLocations *jstate)
return false;
case T_SelectStmt:
{
if (jstate->normalize_utility_only) return false;
SelectStmt *stmt = (SelectStmt *) node;
ListCell *lc;
List *fp_and_param_refs_list = NIL;
Expand Down Expand Up @@ -540,6 +549,26 @@ static bool const_record_walker(Node *node, pgssConstLocations *jstate)

return false;
}
case T_MergeStmt:
{
if (jstate->normalize_utility_only) return false;
return raw_expression_tree_walker(node, const_record_walker, (void*) jstate);
}
case T_InsertStmt:
{
if (jstate->normalize_utility_only) return false;
return raw_expression_tree_walker(node, const_record_walker, (void*) jstate);
}
case T_UpdateStmt:
{
if (jstate->normalize_utility_only) return false;
return raw_expression_tree_walker(node, const_record_walker, (void*) jstate);
}
case T_DeleteStmt:
{
if (jstate->normalize_utility_only) return false;
return raw_expression_tree_walker(node, const_record_walker, (void*) jstate);
}
default:
{
PG_TRY();
Expand All @@ -558,7 +587,7 @@ static bool const_record_walker(Node *node, pgssConstLocations *jstate)
return false;
}

PgQueryNormalizeResult pg_query_normalize(const char* input)
PgQueryNormalizeResult pg_query_normalize_ext(const char* input, bool normalize_utility_only)
{
MemoryContext ctx = NULL;
PgQueryNormalizeResult result = {0};
Expand Down Expand Up @@ -588,6 +617,7 @@ PgQueryNormalizeResult pg_query_normalize(const char* input)
jstate.param_refs = NULL;
jstate.param_refs_buf_size = 0;
jstate.param_refs_count = 0;
jstate.normalize_utility_only = normalize_utility_only;

/* Walk tree and record const locations */
const_record_walker((Node *) tree, &jstate);
Expand Down Expand Up @@ -621,6 +651,17 @@ PgQueryNormalizeResult pg_query_normalize(const char* input)
return result;
}

/*
 * Normalize the given query text.
 *
 * Delegates to pg_query_normalize_ext with normalize_utility_only set to
 * false, so the utility-only early returns in const_record_walker are not
 * taken.
 */
PgQueryNormalizeResult pg_query_normalize(const char* input)
{
return pg_query_normalize_ext(input, false);
}


/*
 * Normalize only utility statements in the given query text.
 *
 * Delegates to pg_query_normalize_ext with normalize_utility_only set to
 * true. With that flag set, const_record_walker returns early for
 * SELECT/INSERT/UPDATE/DELETE/MERGE as well as SET, COPY, DECLARE CURSOR,
 * CREATE FUNCTION and DO statements, so no constant locations are recorded
 * for them.
 */
PgQueryNormalizeResult pg_query_normalize_utility(const char* input)
{
return pg_query_normalize_ext(input, true);
}

void pg_query_free_normalize_result(PgQueryNormalizeResult result)
{
if (result.error) {
Expand Down
2 changes: 1 addition & 1 deletion parser/pg_query_scan.c
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ PgQueryScanResult pg_query_scan(const char* input)
output_tokens[i] = malloc(sizeof(PgQuery__ScanToken));
pg_query__scan_token__init(output_tokens[i]);
output_tokens[i]->start = yylloc;
if (tok == SCONST || tok == BCONST || tok == XCONST || tok == IDENT || tok == C_COMMENT) {
if (tok == SCONST || tok == USCONST || tok == BCONST || tok == XCONST || tok == IDENT || tok == UIDENT || tok == C_COMMENT) {
output_tokens[i]->end = yyextra.yyllocend;
} else {
output_tokens[i]->end = yylloc + ((struct yyguts_t*) yyscanner)->yyleng_r;
Expand Down
5 changes: 5 additions & 0 deletions pg_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ func Normalize(input string) (result string, err error) {
return parser.Normalize(input)
}

// NormalizeUtility - Normalize the passed utility statement to replace
// constant values with $n parameter references. This is a thin wrapper that
// forwards to parser.NormalizeUtility and returns its result and error as-is.
func NormalizeUtility(input string) (result string, err error) {
return parser.NormalizeUtility(input)
}

// Fingerprint - Fingerprint the passed SQL statement to a hex string
func Fingerprint(input string) (result string, err error) {
return parser.FingerprintToHexStr(input)
Expand Down
Loading