diff --git a/ext/pg_query/pg_query_ruby.c b/ext/pg_query/pg_query_ruby.c index 72ac6fed..ff45fbde 100644 --- a/ext/pg_query/pg_query_ruby.c +++ b/ext/pg_query/pg_query_ruby.c @@ -6,11 +6,13 @@ void raise_ruby_parse_error(PgQueryProtobufParseResult result); void raise_ruby_normalize_error(PgQueryNormalizeResult result); void raise_ruby_fingerprint_error(PgQueryFingerprintResult result); void raise_ruby_scan_error(PgQueryScanResult result); +void raise_ruby_plpgsql_parse_error(PgQueryPlpgsqlParseResult result); VALUE pg_query_ruby_parse_protobuf(VALUE self, VALUE input); VALUE pg_query_ruby_deparse_protobuf(VALUE self, VALUE input); VALUE pg_query_ruby_normalize(VALUE self, VALUE input); VALUE pg_query_ruby_fingerprint(VALUE self, VALUE input); +VALUE pg_query_ruby_parse_plpgsql(VALUE self, VALUE input); VALUE pg_query_ruby_scan(VALUE self, VALUE input); VALUE pg_query_ruby_hash_xxh3_64(VALUE self, VALUE input, VALUE seed); @@ -24,6 +26,7 @@ __attribute__((visibility ("default"))) void Init_pg_query(void) rb_define_singleton_method(cPgQuery, "deparse_protobuf", pg_query_ruby_deparse_protobuf, 1); rb_define_singleton_method(cPgQuery, "normalize", pg_query_ruby_normalize, 1); rb_define_singleton_method(cPgQuery, "fingerprint", pg_query_ruby_fingerprint, 1); + rb_define_singleton_method(cPgQuery, "_raw_parse_plpgsql", pg_query_ruby_parse_plpgsql, 1); rb_define_singleton_method(cPgQuery, "_raw_scan", pg_query_ruby_scan, 1); rb_define_singleton_method(cPgQuery, "hash_xxh3_64", pg_query_ruby_hash_xxh3_64, 2); rb_define_const(cPgQuery, "PG_VERSION", rb_str_new2(PG_VERSION)); @@ -121,6 +124,24 @@ void raise_ruby_scan_error(PgQueryScanResult result) rb_exc_raise(rb_class_new_instance(4, args, cScanError)); } +void raise_ruby_plpgsql_parse_error(PgQueryPlpgsqlParseResult result) +{ + VALUE cPgQuery, cPlpgsqlParseError; + VALUE args[4]; + + cPgQuery = rb_const_get(rb_cObject, rb_intern("PgQuery")); + cPlpgsqlParseError = rb_const_get_at(cPgQuery, rb_intern("PlpgsqlParseError")); + + args[0] = rb_str_new2(result.error->message); + args[1] = rb_str_new2(result.error->filename); + args[2] = INT2NUM(result.error->lineno); + args[3] = INT2NUM(result.error->cursorpos); + + pg_query_free_plpgsql_parse_result(result); + + rb_exc_raise(rb_class_new_instance(4, args, cPlpgsqlParseError)); +} + VALUE pg_query_ruby_parse_protobuf(VALUE self, VALUE input) { Check_Type(input, T_STRING); @@ -197,6 +218,22 @@ VALUE pg_query_ruby_fingerprint(VALUE self, VALUE input) return output; } +VALUE pg_query_ruby_parse_plpgsql(VALUE self, VALUE input) +{ + Check_Type(input, T_STRING); + + VALUE output; + PgQueryPlpgsqlParseResult result = pg_query_parse_plpgsql(StringValueCStr(input)); + + if (result.error) raise_ruby_plpgsql_parse_error(result); + + output = rb_str_new2(result.plpgsql_funcs); + + pg_query_free_plpgsql_parse_result(result); + + return output; +} + VALUE pg_query_ruby_scan(VALUE self, VALUE input) { Check_Type(input, T_STRING); diff --git a/lib/pg_query.rb b/lib/pg_query.rb index 4d94b58b..4e568903 100644 --- a/lib/pg_query.rb +++ b/lib/pg_query.rb @@ -15,4 +15,6 @@ require 'pg_query/deparse' require 'pg_query/truncate' +require 'pg_query/parse_plpgsql' + require 'pg_query/scan' diff --git a/lib/pg_query/parse_plpgsql.rb b/lib/pg_query/parse_plpgsql.rb new file mode 100644 index 00000000..bfe73baa --- /dev/null +++ b/lib/pg_query/parse_plpgsql.rb @@ -0,0 +1,43 @@ +require 'json' +module PgQuery + class PlpgsqlParseError < ArgumentError + attr_reader :location + def initialize(message, source_file, source_line, location) + super("#{message} (#{source_file}:#{source_line})") + @location = location + end + end + + def self.parse_plpgsql(input) + PlpgsqlParserResult.new(input, JSON.parse(_raw_parse_plpgsql(input))) + end + + class PlpgsqlParserResult + attr_reader :input + attr_reader :tree + + def initialize(input, tree) + @input = input + @tree = tree + end + + def walk! + nodes = [tree.dup] + loop do + parent_node = nodes.shift + if parent_node.is_a?(Array) + parent_node.each do |node| + yield(node) + nodes << node + end + elsif parent_node.is_a?(Hash) + parent_node.each do |k, node| + yield(node) + nodes << node + end + end + break if nodes.empty? + end + end + end +end