Skip to content

Commit

Permalink
Add .parse_plpgsql method to parse PL/pgSQL function definitions
Browse files Browse the repository at this point in the history
This uses Postgres' PL/pgSQL parser (as extracted in libpg_query)
to parse a PL/pgSQL CREATE FUNCTION statement into the AST.
  • Loading branch information
lfittl committed May 17, 2024
1 parent f23f1df commit 661c815
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 0 deletions.
37 changes: 37 additions & 0 deletions ext/pg_query/pg_query_ruby.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ void raise_ruby_parse_error(PgQueryProtobufParseResult result);
void raise_ruby_normalize_error(PgQueryNormalizeResult result);
void raise_ruby_fingerprint_error(PgQueryFingerprintResult result);
void raise_ruby_scan_error(PgQueryScanResult result);
void raise_ruby_plpgsql_parse_error(PgQueryPlpgsqlParseResult result);

VALUE pg_query_ruby_parse_protobuf(VALUE self, VALUE input);
VALUE pg_query_ruby_deparse_protobuf(VALUE self, VALUE input);
VALUE pg_query_ruby_normalize(VALUE self, VALUE input);
VALUE pg_query_ruby_fingerprint(VALUE self, VALUE input);
VALUE pg_query_ruby_parse_plpgsql(VALUE self, VALUE input);
VALUE pg_query_ruby_scan(VALUE self, VALUE input);
VALUE pg_query_ruby_hash_xxh3_64(VALUE self, VALUE input, VALUE seed);

Expand All @@ -24,6 +26,7 @@ __attribute__((visibility ("default"))) void Init_pg_query(void)
rb_define_singleton_method(cPgQuery, "deparse_protobuf", pg_query_ruby_deparse_protobuf, 1);
rb_define_singleton_method(cPgQuery, "normalize", pg_query_ruby_normalize, 1);
rb_define_singleton_method(cPgQuery, "fingerprint", pg_query_ruby_fingerprint, 1);
rb_define_singleton_method(cPgQuery, "_raw_parse_plpgsql", pg_query_ruby_parse_plpgsql, 1);
rb_define_singleton_method(cPgQuery, "_raw_scan", pg_query_ruby_scan, 1);
rb_define_singleton_method(cPgQuery, "hash_xxh3_64", pg_query_ruby_hash_xxh3_64, 2);
rb_define_const(cPgQuery, "PG_VERSION", rb_str_new2(PG_VERSION));
Expand Down Expand Up @@ -121,6 +124,24 @@ void raise_ruby_scan_error(PgQueryScanResult result)
rb_exc_raise(rb_class_new_instance(4, args, cScanError));
}

void raise_ruby_plpgsql_parse_error(PgQueryPlpgsqlParseResult result)
{
VALUE cPgQuery, cPlpgsqlParseError;
VALUE args[4];

cPgQuery = rb_const_get(rb_cObject, rb_intern("PgQuery"));
cPlpgsqlParseError = rb_const_get_at(cPgQuery, rb_intern("PlpgsqlParseError"));

args[0] = rb_str_new2(result.error->message);
args[1] = rb_str_new2(result.error->filename);
args[2] = INT2NUM(result.error->lineno);
args[3] = INT2NUM(result.error->cursorpos);

pg_query_free_plpgsql_parse_result(result);

rb_exc_raise(rb_class_new_instance(4, args, cPlpgsqlParseError));
}

VALUE pg_query_ruby_parse_protobuf(VALUE self, VALUE input)
{
Check_Type(input, T_STRING);
Expand Down Expand Up @@ -197,6 +218,22 @@ VALUE pg_query_ruby_fingerprint(VALUE self, VALUE input)
return output;
}

VALUE pg_query_ruby_parse_plpgsql(VALUE self, VALUE input)
{
Check_Type(input, T_STRING);

VALUE output;
PgQueryPlpgsqlParseResult result = pg_query_parse_plpgsql(StringValueCStr(input));

if (result.error) raise_ruby_plpgsql_parse_error(result);

output = rb_str_new2(result.plpgsql_funcs);

pg_query_free_plpgsql_parse_result(result);

return output;
}

VALUE pg_query_ruby_scan(VALUE self, VALUE input)
{
Check_Type(input, T_STRING);
Expand Down
2 changes: 2 additions & 0 deletions lib/pg_query.rb
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@
require 'pg_query/deparse'
require 'pg_query/truncate'

require 'pg_query/parse_plpgsql'

require 'pg_query/scan'
43 changes: 43 additions & 0 deletions lib/pg_query/parse_plpgsql.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
require 'json'
module PgQuery
class PlpgsqlParseError < ArgumentError
attr_reader :location
def initialize(message, source_file, source_line, location)
super("#{message} (#{source_file}:#{source_line})")
@location = location
end
end

def self.parse_plpgsql(input)
PlpgsqlParserResult.new(input, JSON.parse(_raw_parse_plpgsql(input)))
end

class PlpgsqlParserResult
attr_reader :input
attr_reader :tree

def initialize(input, tree)
@input = input
@tree = tree
end

def walk!
nodes = [tree.dup]
loop do
parent_node = nodes.shift
if parent_node.is_a?(Array)
parent_node.each do |node|
yield(node)
nodes << node
end
elsif parent_node.is_a?(Hash)
parent_node.each do |k, node|
yield(node)
nodes << node
end
end
break if nodes.empty?
end
end
end
end

0 comments on commit 661c815

Please sign in to comment.