diff --git a/src/query.rs b/src/query.rs
index daa7ffe..9c18edc 100644
--- a/src/query.rs
+++ b/src/query.rs
@@ -280,3 +280,51 @@ pub fn split_with_scanner(query: &str) -> Result> {
     unsafe { pg_query_free_split_result(result) };
     split_result
 }
+
+/// Extracts the SQL statement wrapped in an immediately-invoked, self-dropping function.
+///
+/// Recognizes a query made of exactly three statements: a `CREATE FUNCTION`, a `SELECT` whose
+/// only `FROM` item is a single call to that function, and a `DROP` naming the same function.
+/// On a match, returns the text of the single SQL statement in the function's PL/pgSQL body,
+/// since the wrapper function makes the real query harder to read.
+///
+/// Returns `None` when the query cannot be parsed (as SQL or as PL/pgSQL) or when it does not
+/// have exactly this shape.
+pub fn unwrap_immediately_invoked_function(query: &str) -> Option<String> {
+    use crate::NodeEnum::*;
+
+    let parsed = parse(query).ok()?;
+    let stmts: Vec<_> = parsed.protobuf.stmts.iter().filter_map(|s| s.stmt.as_ref()?.node.as_ref()).collect();
+    let [CreateFunctionStmt(create_fn), SelectStmt(select), DropStmt(drop_stmt)] = stmts.as_slice() else {
+        return None;
+    };
+    let ObjectWithArgs(drop_name) = drop_stmt.objects.first()?.node.as_ref()? else { return None };
+
+    // The SELECT must read from exactly one FROM item holding exactly one function call.
+    let from_items: Vec<_> = select.from_clause.iter().filter_map(|n| n.node.as_ref()).collect();
+    let [RangeFunction(range_fn)] = from_items.as_slice() else { return None };
+    let [fn_entry] = range_fn.functions.as_slice() else { return None };
+    let List(list) = fn_entry.node.as_ref()? else { return None };
+    // The list may carry nodes with no content next to the call, so match on the non-empty
+    // nodes only instead of indexing the raw list.
+    let items: Vec<_> = list.items.iter().filter_map(|n| n.node.as_ref()).collect();
+    let [FuncCall(call_fn)] = items.as_slice() else { return None };
+
+    // All three statements must refer to the same function name.
+    if create_fn.funcname != call_fn.funcname || create_fn.funcname != drop_name.objname {
+        return None;
+    }
+
+    // Now that we know the query uses an immediately-invoked function that is later dropped, return only
+    // the query text inside the function (since the function makes it harder to read the real query).
+    // A PL/pgSQL parse failure means "not unwrappable", so it yields `None` rather than a panic.
+    let plpgsql = parse_plpgsql(query).ok()?;
+    let [function] = plpgsql.as_array()?.as_slice() else { return None };
+    let body = function.get("PLpgSQL_function")?.get("action")?.get("PLpgSQL_stmt_block")?.get("body")?.as_array()?;
+    // Exactly one nested block is expected in the outer body, containing exactly one statement.
+    let mut nested_blocks = body.iter().filter_map(|stmt| stmt.get("PLpgSQL_stmt_block"));
+    let (Some(block), None) = (nested_blocks.next(), nested_blocks.next()) else { return None };
+    let [stmt] = block.get("body")?.as_array()?.as_slice() else { return None };
+    let sql = stmt.get("PLpgSQL_stmt_execsql")?.get("sqlstmt")?.get("PLpgSQL_expr")?.get("query")?.as_str()?;
+    Some(sql.to_string())
+}
diff --git a/tests/normalize_tests.rs b/tests/normalize_tests.rs
index 89f70d6..4a65c6b 100644
--- a/tests/normalize_tests.rs
+++ b/tests/normalize_tests.rs
@@ -79,7 +79,34 @@ fn it_normalizes_EXPLAIN() {
 }
 
 #[test]
-fn it_normalizes_DECLARE_CURSON() {
+fn it_normalizes_DECLARE_CURSOR() {
     let result = normalize("DECLARE cursor_b CURSOR FOR SELECT * FROM databases WHERE id = 23").unwrap();
     assert_eq!(result, "DECLARE cursor_b CURSOR FOR SELECT * FROM databases WHERE id = $1");
 }
+
+#[test]
+fn it_unwraps_and_normalizes_immediately_invoked_functions() {
+    let query = "
+        CREATE OR REPLACE FUNCTION pg_temp.testfunc(OUT response t, OUT sequelize_caught_exception text)
+        RETURNS RECORD AS $func_08a0ae3001ba4697bd3a1a677c6dab12$
+        BEGIN
+          INSERT INTO t (columns)
+          VALUES ('non-normalized-values-here')
+          RETURNING * INTO response;
+        EXCEPTION WHEN unique_violation THEN GET STACKED DIAGNOSTICS sequelize_caught_exception = PG_EXCEPTION_DETAIL;
+        END $func_08a0ae3001ba4697bd3a1a677c6dab12$ LANGUAGE plpgsql;
+
+        SELECT (testfunc.response).*, testfunc.sequelize_caught_exception FROM pg_temp.testfunc();
+
+        DROP FUNCTION IF EXISTS pg_temp.testfunc()
+    ";
+    let result = pg_query::unwrap_immediately_invoked_function(query).unwrap();
+    let normalized_result = pg_query::normalize(&result).unwrap();
+    assert_eq!(
+        normalized_result,
+        "INSERT INTO t (columns)
+          VALUES ($1)
+          RETURNING *"
+    );
+    pg_query::parse(&normalized_result).unwrap();
+}