Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Normalize and unwrap immediately invoked functions #19

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,55 @@ pub fn split_with_scanner(query: &str) -> Result<Vec<&str>> {
unsafe { pg_query_free_split_result(result) };
split_result
}

pub fn unwrap_immediately_invoked_function(query: &str) -> Option<String> {
let stmts = parse(query).ok()?.protobuf.stmts;
let stmts: Vec<_> = stmts.iter().filter_map(|s| s.stmt.as_ref().and_then(|s| s.node.as_ref())).collect();
if stmts.len() != 3 {
return None;
}
use crate::NodeEnum::*;
let CreateFunctionStmt(create_fn) = stmts.get(0)? else { return None };
let SelectStmt(select) = stmts.get(1)? else { return None };
let DropStmt(drop) = stmts.get(2)? else { return None };
let ObjectWithArgs(drop_name) = drop.objects.get(0)?.node.as_ref()? else { return None };
let call_fn = {
let from_clause: Vec<_> = select.from_clause.iter().filter_map(|s| s.node.as_ref()).collect();
if from_clause.len() != 1 {
return None;
}
let RangeFunction(function_call) = from_clause.get(0)? else { return None };
if function_call.functions.len() != 1 {
return None;
}
let List(list) = function_call.functions[0].node.as_ref()? else { return None };
let items: Vec<_> = list.items.iter().filter_map(|s| s.node.as_ref()).collect(); // there's an empty node here?
if items.len() != 1 {
return None;
}
let FuncCall(call) = list.items[0].node.as_ref()? else { return None };
call
};
if create_fn.funcname != call_fn.funcname || create_fn.funcname != drop_name.objname {
return None;
}
// Now that we know the query uses an immediately-invoked function that is later dropped, return only
// the query text inside the function (since the function makes it harder to read the real query).
let result = parse_plpgsql(query).unwrap();
let functions = result.as_array()?;
if functions.len() != 1 {
return None;
}
let function = &functions[0];
let body = function.get("PLpgSQL_function")?.get("action")?.get("PLpgSQL_stmt_block")?.get("body")?.as_array()?;
let blocks: Vec<_> = body.iter().filter(|b| b.get("PLpgSQL_stmt_block").is_some()).collect();
if blocks.len() != 1 {
return None;
}
let blocks = blocks[0].get("PLpgSQL_stmt_block")?.get("body")?.as_array()?;
if blocks.len() != 1 {
return None;
}
let query = blocks[0].get("PLpgSQL_stmt_execsql")?.get("sqlstmt")?.get("PLpgSQL_expr")?.get("query")?;
Some(query.as_str()?.to_string())
}
29 changes: 28 additions & 1 deletion tests/normalize_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,34 @@ fn it_normalizes_EXPLAIN() {
}

#[test]
fn it_normalizes_DECLARE_CURSON() {
fn it_normalizes_DECLARE_CURSOR() {
let result = normalize("DECLARE cursor_b CURSOR FOR SELECT * FROM databases WHERE id = 23").unwrap();
assert_eq!(result, "DECLARE cursor_b CURSOR FOR SELECT * FROM databases WHERE id = $1");
}

#[test]
fn it_unwraps_and_normalizes_immediately_invoked_functions() {
let query = "
CREATE OR REPLACE FUNCTION pg_temp.testfunc(OUT response t, OUT sequelize_caught_exception text)
RETURNS RECORD AS $func_08a0ae3001ba4697bd3a1a677c6dab12$
BEGIN
INSERT INTO t (columns)
VALUES ('non-normalized-values-here')
RETURNING * INTO response;
EXCEPTION WHEN unique_violation THEN GET STACKED DIAGNOSTICS sequelize_caught_exception = PG_EXCEPTION_DETAIL;
END $func_08a0ae3001ba4697bd3a1a677c6dab12$ LANGUAGE plpgsql;

SELECT (testfunc.response).*, testfunc.sequelize_caught_exception FROM pg_temp.testfunc();

DROP FUNCTION IF EXISTS pg_temp.testfunc()
";
let result = pg_query::unwrap_immediately_invoked_function(query).unwrap();
let normalized_result = pg_query::normalize(&result).unwrap();
assert_eq!(
normalized_result,
"INSERT INTO t (columns)
VALUES ($1)
RETURNING *"
);
pg_query::parse(&normalized_result).unwrap();
}