From 0b5251f5ac0768027d900a6d5f3eff61398b843b Mon Sep 17 00:00:00 2001 From: Mark Abspoel Date: Fri, 5 Jul 2024 11:00:30 +0200 Subject: [PATCH] Add iter() on ParseResult and NodeRef This allows walking over all nodes in the AST, instead of just a limited subset as in the current `nodes()` function. See https://github.com/pganalyze/pg_query.rs/issues/31 The implementation uses static code generation in `build.rs`. The protobuf definitions are parsed, and a graph of all Message types is constructed. All NodeRef types are given an `unpack()` function, that recursively calls `unpack()` on all relevant fields (i.e., the fields that have a Node type, or that have a type that eventually has a Node type as a field). The result is guaranteed to visit all nodes. The code generation mechanism is maybe also useful to replace parts of the codebase that currently need to be manually hardcoded. Adds prost, prost-types and heck to the build dependencies, and updates the prost dependency version. --- Cargo.lock | 60 +++++++-- Cargo.toml | 5 +- build.rs | 128 +++++++++++++++++++ src/bindings.rs | 1 + src/node_enum.rs | 1 + src/node_ref.rs | 297 +++++++++++++++++++++++++++++++++++++++++++ tests/parse_tests.rs | 49 ++++++- 7 files changed, 521 insertions(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0a3d1ca..cae3e6b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -54,7 +54,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.31", + "syn 2.0.32", "which", ] @@ -368,9 +368,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "peeking_take_while" @@ -398,10 +398,12 @@ dependencies = [ "easy-parallel", "fs_extra", "glob", + "heck", "itertools", "pretty_assertions", - "prost", + "prost 0.13.1", "prost-build", + "prost-types 0.13.1", "regex", "serde", "serde_json", @@ -425,7 +427,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8832c0f9be7e3cae60727e6256cfd2cd3c3e2b6cd5dad4190ecb2fd658c9030b" dependencies = [ "proc-macro2", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -444,7 +446,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "71adf41db68aa0daaefc69bb30bcd68ded9b9abaad5d1fbb6304c4fb390e083e" dependencies = [ "bytes", - "prost-derive", + "prost-derive 0.10.1", +] + +[[package]] +name = "prost" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13db3d3fde688c61e2446b4d843bc27a7e8af269a69440c0308021dc92333cc" +dependencies = [ + "bytes", + "prost-derive 0.13.1", ] [[package]] @@ -462,8 +474,8 @@ dependencies = [ "log", "multimap", "petgraph", - "prost", - "prost-types", + "prost 0.10.4", + "prost-types 0.10.1", "regex", "tempfile", "which", @@ -482,6 +494,19 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "prost-derive" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18bec9b0adc4eba778b33684b7ba3e7137789434769ee3ce3930463ef904cfca" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn 2.0.32", +] + [[package]] name = "prost-types" version = "0.10.1" @@ -489,7 +514,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d0a014229361011dc8e69c8a1ec6c2e8d0f2af7c91e3ea3f5b2170298461e68" dependencies = [ "bytes", - "prost", + "prost 0.10.4", +] + +[[package]] +name = "prost-types" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cee5168b05f49d4b0ca581206eb14a7b22fafd963efe729ac48eb03266e25cc2" +dependencies = [ + "prost 0.13.1", ] [[package]] @@ -610,7 +644,7 @@ checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -643,9 +677,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.31" +version = "2.0.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "718fa2415bcb8d8bd775917a1bf12a7931b6dfa890753378538118181e0cb398" +checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2" dependencies = [ "proc-macro2", "quote", @@ -693,7 +727,7 @@ checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 23ae4d9..55eaabe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ repository = "https://github.com/pganalyze/pg_query.rs" [dependencies] itertools = "0.10.3" -prost = "0.10.4" +prost = "0.13.0" serde = { version = "1.0.139", features = ["derive"] } serde_json = "1.0.82" thiserror = "1.0.31" @@ -21,10 +21,13 @@ thiserror = "1.0.31" [build-dependencies] bindgen = "0.66.1" clippy = { version = "0.0.302", optional = true } +prost = "0.13.0" prost-build = "0.10.4" +prost-types = "0.13.0" fs_extra = "1.2.0" cc = "1.0.83" glob = "0.3.1" +heck = "0.4" [dev-dependencies] easy-parallel = "3.2.0" diff --git a/build.rs b/build.rs index 57d3664..d09fa35 100644 --- a/build.rs +++ b/build.rs @@ -3,12 +3,128 @@ use fs_extra::dir::CopyOptions; use glob::glob; +use heck::ToUpperCamelCase; +use prost::Message; +use prost_types::field_descriptor_proto::Type; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::env; use std::path::{Path, PathBuf}; static SOURCE_DIRECTORY: &str = "libpg_query"; static LIBRARY_NAME: &str = "pg_query"; +type Cardinality = prost_types::field_descriptor_proto::Label; + +struct Edge { + field: String, + message: usize, + cardinality: Cardinality, +} + +/// Represents a directed labeled multigraph of Message types. Each vertex represents a message +/// type. An edge A->B is a tuple (field_name: String, type: FieldType), that states that +/// Message A has a field (with name equal to `field_name`) of Message type B. +struct MessageGraph { + messages: HashMap, + + /// For each vertex A, the list of edges from A to other vertices, and a set of vertices B such that there is at least one edge B->A + edges: Vec<(String, Vec, BTreeSet)>, +} + +impl MessageGraph { + fn new() -> Self { + Self { messages: HashMap::new(), edges: Vec::new() } + } + + /// Get the ID for a given `type_name` if it exists, or generate a new one if it doesn't + fn id_for(&mut self, type_name: &str) -> usize { + if let Some(id) = self.messages.get(type_name) { + *id + } else { + let id = self.edges.len(); + self.edges.push((type_name.to_string(), Vec::new(), BTreeSet::new())); + self.messages.insert(type_name.to_string(), id); + id + } + } + + /// Parse protobuf files and populate the graph with its Messages and corresponding edges + fn make(&mut self, fds: prost_types::FileDescriptorSet) { + for fd in fds.file { + let package = fd.package().to_string(); + for msg in fd.message_type { + let full_name = format!(".{}.{}", package, msg.name()); + let id = self.id_for(&full_name); + + // We use this to check for duplicate fields + let mut fields: HashSet = HashSet::new(); + + if msg.name() != "Node" && msg.name() != "A_Const" { + for field in &msg.field { + if field.r#type() != Type::Message { + continue; + } + + if field.oneof_index.is_some() { + panic!("No support for enums: field {} of message {}", field.name(), msg.name()); + } + + if !fields.insert(field.name().to_string()) { + panic!("Duplicate field: {}", field.name()); + } + + let message_id = self.id_for(field.type_name()); + self.edges[id].1.push(Edge { field: field.name().to_string(), message: message_id, cardinality: field.label() }); + self.edges[message_id].2.insert(id); + } + } + } + } + } + + /// Set `filter[x] = true` for all vertices `x` with a path to vertex `id` + fn filter_incoming(&self, id: usize, filter: &mut Vec) { + if !filter[id] { + filter[id] = true; + for nb in self.edges[id].2.iter() { + self.filter_incoming(*nb, filter); + } + } + } + + /// Generate code for `unpack` impls for all Message types + fn write(&self, buf: &mut String) { + let mut filter = vec![false; self.messages.len()]; + self.filter_incoming(*self.messages.get(".pg_query.Node").unwrap(), &mut filter); + for (id, (name, edges, _incoming)) in self.edges.iter().enumerate() { + let filtered = filter[id]; + let short_name = &name[name.rfind(".").unwrap() + 1..].to_upper_camel_case(); + if short_name == "Node" || short_name == "ParseResult" || short_name == "ScanResult" || short_name == "ScanToken" { + continue; + } + + buf.push_str(&format!("impl<'a> Unpack<'a> for protobuf::{} {{\n", short_name)); + if filtered && edges.iter().any(|e| filter[e.message]) { + buf.push_str(" fn unpack(&'a self, vec: &mut VecDeque>) {\n"); + for edge in edges.iter() { + if filter[edge.message] { + match edge.cardinality { + Cardinality::Repeated => buf.push_str(&format!(" self.{}.iter().for_each(|n| n.unpack(vec));\n", edge.field)), + Cardinality::Required => buf.push_str(&format!(" vec.push_back(self.{});\n", edge.field)), + Cardinality::Optional => { + buf.push_str(&format!(" if let Some(ref e) = self.{} {{ e.unpack(vec); }}\n", edge.field)) + } + } + } + } + buf.push_str(" }\n}\n\n"); + } else { + buf.push_str(" fn unpack(&'a self, _vec: &mut VecDeque>) { }\n}\n"); + } + } + } +} + fn main() -> Result<(), Box> { let out_dir = PathBuf::from(env::var("OUT_DIR")?); let build_path = Path::new(".").join(SOURCE_DIRECTORY); @@ -65,6 +181,18 @@ fn main() -> Result<(), Box> { .write_to_file(out_dir.join("bindings.rs"))?; // Generate the protobuf definition + let mut config = prost_build::Config::new(); + let fds_path = out_dir.join("./file_descriptor_set.bin"); + config.file_descriptor_set_path(fds_path.clone()); + config.compile_protos(&[&out_protobuf_path.join(LIBRARY_NAME).with_extension("proto")], &[&out_protobuf_path])?; + + let mut buf = String::new(); + let fds = prost_types::FileDescriptorSet::decode(std::fs::read(fds_path)?.as_slice())?; + let mut graph = MessageGraph::new(); + graph.make(fds); + graph.write(&mut buf); + std::fs::write(out_dir.join("./unpack.rs"), buf)?; + prost_build::compile_protos(&[&out_protobuf_path.join(LIBRARY_NAME).with_extension("proto")], &[&out_protobuf_path])?; Ok(()) diff --git a/src/bindings.rs b/src/bindings.rs index 4f98483..6672e8a 100644 --- a/src/bindings.rs +++ b/src/bindings.rs @@ -3,4 +3,5 @@ #![allow(non_snake_case)] #![allow(unused)] #![allow(clippy::all)] + include!(concat!(env!("OUT_DIR"), "/bindings.rs")); diff --git a/src/node_enum.rs b/src/node_enum.rs index 665308c..2178182 100644 --- a/src/node_enum.rs +++ b/src/node_enum.rs @@ -25,6 +25,7 @@ impl NodeEnum { while !iter.is_empty() { let (node, depth, context, has_filter_columns) = iter.remove(0); let depth = depth + 1; + match node { // // The following statement types do not modify tables diff --git a/src/node_ref.rs b/src/node_ref.rs index 41a97d6..460068d 100644 --- a/src/node_ref.rs +++ b/src/node_ref.rs @@ -1,4 +1,5 @@ use crate::*; +use std::collections::VecDeque; #[derive(Debug, Copy, Clone)] pub enum NodeRef<'a> { @@ -256,6 +257,298 @@ pub enum NodeRef<'a> { RtepermissionInfo(&'a protobuf::RtePermissionInfo), } +/// Allows to walk through a tree of NodeRefs. +struct NodesIter<'a> { + /// A deque of NodeRefs that have still to be expanded + vec: std::collections::VecDeque>, +} + +impl<'a> Iterator for NodesIter<'a> { + type Item = NodeRef<'a>; + + fn next(&mut self) -> Option { + if let Some(node) = self.vec.pop_front() { + match node { + NodeRef::Alias(n) => n.unpack(&mut self.vec), + NodeRef::RangeVar(n) => n.unpack(&mut self.vec), + NodeRef::TableFunc(n) => n.unpack(&mut self.vec), + NodeRef::Var(n) => n.unpack(&mut self.vec), + NodeRef::Param(n) => n.unpack(&mut self.vec), + NodeRef::Aggref(n) => n.unpack(&mut self.vec), + NodeRef::GroupingFunc(n) => n.unpack(&mut self.vec), + NodeRef::WindowFunc(n) => n.unpack(&mut self.vec), + NodeRef::SubscriptingRef(n) => n.unpack(&mut self.vec), + NodeRef::FuncExpr(n) => n.unpack(&mut self.vec), + NodeRef::NamedArgExpr(n) => n.unpack(&mut self.vec), + NodeRef::OpExpr(n) => n.unpack(&mut self.vec), + NodeRef::DistinctExpr(n) => n.unpack(&mut self.vec), + NodeRef::NullIfExpr(n) => n.unpack(&mut self.vec), + NodeRef::ScalarArrayOpExpr(n) => n.unpack(&mut self.vec), + NodeRef::BoolExpr(n) => n.unpack(&mut self.vec), + NodeRef::SubLink(n) => n.unpack(&mut self.vec), + NodeRef::SubPlan(n) => n.unpack(&mut self.vec), + NodeRef::AlternativeSubPlan(n) => n.unpack(&mut self.vec), + NodeRef::FieldSelect(n) => n.unpack(&mut self.vec), + NodeRef::FieldStore(n) => n.unpack(&mut self.vec), + NodeRef::RelabelType(n) => n.unpack(&mut self.vec), + NodeRef::CoerceViaIo(n) => n.unpack(&mut self.vec), + NodeRef::ArrayCoerceExpr(n) => n.unpack(&mut self.vec), + NodeRef::ConvertRowtypeExpr(n) => n.unpack(&mut self.vec), + NodeRef::CollateExpr(n) => n.unpack(&mut self.vec), + NodeRef::CaseExpr(n) => n.unpack(&mut self.vec), + NodeRef::CaseWhen(n) => n.unpack(&mut self.vec), + NodeRef::CaseTestExpr(n) => n.unpack(&mut self.vec), + NodeRef::ArrayExpr(n) => n.unpack(&mut self.vec), + NodeRef::RowExpr(n) => n.unpack(&mut self.vec), + NodeRef::RowCompareExpr(n) => n.unpack(&mut self.vec), + NodeRef::CoalesceExpr(n) => n.unpack(&mut self.vec), + NodeRef::MinMaxExpr(n) => n.unpack(&mut self.vec), + NodeRef::SqlvalueFunction(n) => n.unpack(&mut self.vec), + NodeRef::XmlExpr(n) => n.unpack(&mut self.vec), + NodeRef::NullTest(n) => n.unpack(&mut self.vec), + NodeRef::BooleanTest(n) => n.unpack(&mut self.vec), + NodeRef::CoerceToDomain(n) => n.unpack(&mut self.vec), + NodeRef::CoerceToDomainValue(n) => n.unpack(&mut self.vec), + NodeRef::SetToDefault(n) => n.unpack(&mut self.vec), + NodeRef::CurrentOfExpr(n) => n.unpack(&mut self.vec), + NodeRef::NextValueExpr(n) => n.unpack(&mut self.vec), + NodeRef::InferenceElem(n) => n.unpack(&mut self.vec), + NodeRef::TargetEntry(n) => n.unpack(&mut self.vec), + NodeRef::RangeTblRef(n) => n.unpack(&mut self.vec), + NodeRef::JoinExpr(n) => n.unpack(&mut self.vec), + NodeRef::FromExpr(n) => n.unpack(&mut self.vec), + NodeRef::OnConflictExpr(n) => n.unpack(&mut self.vec), + NodeRef::IntoClause(n) => n.unpack(&mut self.vec), + NodeRef::RawStmt(n) => n.unpack(&mut self.vec), + NodeRef::Query(n) => n.unpack(&mut self.vec), + NodeRef::InsertStmt(n) => n.unpack(&mut self.vec), + NodeRef::DeleteStmt(n) => n.unpack(&mut self.vec), + NodeRef::UpdateStmt(n) => n.unpack(&mut self.vec), + NodeRef::SelectStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterTableStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterTableCmd(n) => n.unpack(&mut self.vec), + NodeRef::AlterDomainStmt(n) => n.unpack(&mut self.vec), + NodeRef::SetOperationStmt(n) => n.unpack(&mut self.vec), + NodeRef::GrantStmt(n) => n.unpack(&mut self.vec), + NodeRef::GrantRoleStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterDefaultPrivilegesStmt(n) => n.unpack(&mut self.vec), + NodeRef::ClosePortalStmt(n) => n.unpack(&mut self.vec), + NodeRef::ClusterStmt(n) => n.unpack(&mut self.vec), + NodeRef::CopyStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateStmt(n) => n.unpack(&mut self.vec), + NodeRef::DefineStmt(n) => n.unpack(&mut self.vec), + NodeRef::DropStmt(n) => n.unpack(&mut self.vec), + NodeRef::TruncateStmt(n) => n.unpack(&mut self.vec), + NodeRef::CommentStmt(n) => n.unpack(&mut self.vec), + NodeRef::FetchStmt(n) => n.unpack(&mut self.vec), + NodeRef::IndexStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateFunctionStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterFunctionStmt(n) => n.unpack(&mut self.vec), + NodeRef::DoStmt(n) => n.unpack(&mut self.vec), + NodeRef::RenameStmt(n) => n.unpack(&mut self.vec), + NodeRef::RuleStmt(n) => n.unpack(&mut self.vec), + NodeRef::NotifyStmt(n) => n.unpack(&mut self.vec), + NodeRef::ListenStmt(n) => n.unpack(&mut self.vec), + NodeRef::UnlistenStmt(n) => n.unpack(&mut self.vec), + NodeRef::TransactionStmt(n) => n.unpack(&mut self.vec), + NodeRef::ViewStmt(n) => n.unpack(&mut self.vec), + NodeRef::LoadStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateDomainStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreatedbStmt(n) => n.unpack(&mut self.vec), + NodeRef::DropdbStmt(n) => n.unpack(&mut self.vec), + NodeRef::VacuumStmt(n) => n.unpack(&mut self.vec), + NodeRef::ExplainStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateTableAsStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateSeqStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterSeqStmt(n) => n.unpack(&mut self.vec), + NodeRef::VariableSetStmt(n) => n.unpack(&mut self.vec), + NodeRef::VariableShowStmt(n) => n.unpack(&mut self.vec), + NodeRef::DiscardStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateTrigStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreatePlangStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateRoleStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterRoleStmt(n) => n.unpack(&mut self.vec), + NodeRef::DropRoleStmt(n) => n.unpack(&mut self.vec), + NodeRef::LockStmt(n) => n.unpack(&mut self.vec), + NodeRef::ConstraintsSetStmt(n) => n.unpack(&mut self.vec), + NodeRef::ReindexStmt(n) => n.unpack(&mut self.vec), + NodeRef::CheckPointStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateSchemaStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterDatabaseStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterDatabaseSetStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterRoleSetStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateConversionStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateCastStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateOpClassStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateOpFamilyStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterOpFamilyStmt(n) => n.unpack(&mut self.vec), + NodeRef::PrepareStmt(n) => n.unpack(&mut self.vec), + NodeRef::ExecuteStmt(n) => n.unpack(&mut self.vec), + NodeRef::DeallocateStmt(n) => n.unpack(&mut self.vec), + NodeRef::DeclareCursorStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateTableSpaceStmt(n) => n.unpack(&mut self.vec), + NodeRef::DropTableSpaceStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterObjectDependsStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterObjectSchemaStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterOwnerStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterOperatorStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterTypeStmt(n) => n.unpack(&mut self.vec), + NodeRef::DropOwnedStmt(n) => n.unpack(&mut self.vec), + NodeRef::ReassignOwnedStmt(n) => n.unpack(&mut self.vec), + NodeRef::CompositeTypeStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateEnumStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateRangeStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterEnumStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterTsdictionaryStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterTsconfigurationStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateFdwStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterFdwStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateForeignServerStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterForeignServerStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateUserMappingStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterUserMappingStmt(n) => n.unpack(&mut self.vec), + NodeRef::DropUserMappingStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterTableSpaceOptionsStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterTableMoveAllStmt(n) => n.unpack(&mut self.vec), + NodeRef::SecLabelStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateForeignTableStmt(n) => n.unpack(&mut self.vec), + NodeRef::ImportForeignSchemaStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateExtensionStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterExtensionStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterExtensionContentsStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateEventTrigStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterEventTrigStmt(n) => n.unpack(&mut self.vec), + NodeRef::RefreshMatViewStmt(n) => n.unpack(&mut self.vec), + NodeRef::ReplicaIdentityStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterSystemStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreatePolicyStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterPolicyStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateTransformStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateAmStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreatePublicationStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterPublicationStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateSubscriptionStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterSubscriptionStmt(n) => n.unpack(&mut self.vec), + NodeRef::DropSubscriptionStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateStatsStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterCollationStmt(n) => n.unpack(&mut self.vec), + NodeRef::CallStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterStatsStmt(n) => n.unpack(&mut self.vec), + NodeRef::AExpr(n) => n.unpack(&mut self.vec), + NodeRef::ColumnRef(n) => n.unpack(&mut self.vec), + NodeRef::ParamRef(n) => n.unpack(&mut self.vec), + NodeRef::AConst(n) => n.unpack(&mut self.vec), + NodeRef::FuncCall(n) => n.unpack(&mut self.vec), + NodeRef::AStar(n) => n.unpack(&mut self.vec), + NodeRef::AIndices(n) => n.unpack(&mut self.vec), + NodeRef::AIndirection(n) => n.unpack(&mut self.vec), + NodeRef::AArrayExpr(n) => n.unpack(&mut self.vec), + NodeRef::ResTarget(n) => n.unpack(&mut self.vec), + NodeRef::MultiAssignRef(n) => n.unpack(&mut self.vec), + NodeRef::TypeCast(n) => n.unpack(&mut self.vec), + NodeRef::CollateClause(n) => n.unpack(&mut self.vec), + NodeRef::SortBy(n) => n.unpack(&mut self.vec), + NodeRef::WindowDef(n) => n.unpack(&mut self.vec), + NodeRef::RangeSubselect(n) => n.unpack(&mut self.vec), + NodeRef::RangeFunction(n) => n.unpack(&mut self.vec), + NodeRef::RangeTableSample(n) => n.unpack(&mut self.vec), + NodeRef::RangeTableFunc(n) => n.unpack(&mut self.vec), + NodeRef::RangeTableFuncCol(n) => n.unpack(&mut self.vec), + NodeRef::TypeName(n) => n.unpack(&mut self.vec), + NodeRef::ColumnDef(n) => n.unpack(&mut self.vec), + NodeRef::IndexElem(n) => n.unpack(&mut self.vec), + NodeRef::Constraint(n) => n.unpack(&mut self.vec), + NodeRef::DefElem(n) => n.unpack(&mut self.vec), + NodeRef::RangeTblEntry(n) => n.unpack(&mut self.vec), + NodeRef::RangeTblFunction(n) => n.unpack(&mut self.vec), + NodeRef::TableSampleClause(n) => n.unpack(&mut self.vec), + NodeRef::WithCheckOption(n) => n.unpack(&mut self.vec), + NodeRef::SortGroupClause(n) => n.unpack(&mut self.vec), + NodeRef::GroupingSet(n) => n.unpack(&mut self.vec), + NodeRef::WindowClause(n) => n.unpack(&mut self.vec), + NodeRef::ObjectWithArgs(n) => n.unpack(&mut self.vec), + NodeRef::AccessPriv(n) => n.unpack(&mut self.vec), + NodeRef::CreateOpClassItem(n) => n.unpack(&mut self.vec), + NodeRef::TableLikeClause(n) => n.unpack(&mut self.vec), + NodeRef::FunctionParameter(n) => n.unpack(&mut self.vec), + NodeRef::LockingClause(n) => n.unpack(&mut self.vec), + NodeRef::RowMarkClause(n) => n.unpack(&mut self.vec), + NodeRef::XmlSerialize(n) => n.unpack(&mut self.vec), + NodeRef::WithClause(n) => n.unpack(&mut self.vec), + NodeRef::InferClause(n) => n.unpack(&mut self.vec), + NodeRef::OnConflictClause(n) => n.unpack(&mut self.vec), + NodeRef::CommonTableExpr(n) => n.unpack(&mut self.vec), + NodeRef::RoleSpec(n) => n.unpack(&mut self.vec), + NodeRef::TriggerTransition(n) => n.unpack(&mut self.vec), + NodeRef::PartitionElem(n) => n.unpack(&mut self.vec), + NodeRef::PartitionSpec(n) => n.unpack(&mut self.vec), + NodeRef::PartitionBoundSpec(n) => n.unpack(&mut self.vec), + NodeRef::PartitionRangeDatum(n) => n.unpack(&mut self.vec), + NodeRef::PartitionCmd(n) => n.unpack(&mut self.vec), + NodeRef::VacuumRelation(n) => n.unpack(&mut self.vec), + NodeRef::InlineCodeBlock(n) => n.unpack(&mut self.vec), + NodeRef::CallContext(n) => n.unpack(&mut self.vec), + NodeRef::Integer(n) => n.unpack(&mut self.vec), + NodeRef::Float(n) => n.unpack(&mut self.vec), + NodeRef::Boolean(n) => n.unpack(&mut self.vec), + NodeRef::String(n) => n.unpack(&mut self.vec), + NodeRef::BitString(n) => n.unpack(&mut self.vec), + NodeRef::List(n) => n.unpack(&mut self.vec), + NodeRef::IntList(n) => n.unpack(&mut self.vec), + NodeRef::OidList(n) => n.unpack(&mut self.vec), + NodeRef::MergeStmt(n) => n.unpack(&mut self.vec), + NodeRef::MergeAction(n) => n.unpack(&mut self.vec), + NodeRef::AlterDatabaseRefreshCollStmt(n) => n.unpack(&mut self.vec), + NodeRef::ReturnStmt(n) => n.unpack(&mut self.vec), + NodeRef::PlassignStmt(n) => n.unpack(&mut self.vec), + NodeRef::StatsElem(n) => n.unpack(&mut self.vec), + NodeRef::CtesearchClause(n) => n.unpack(&mut self.vec), + NodeRef::CtecycleClause(n) => n.unpack(&mut self.vec), + NodeRef::MergeWhenClause(n) => n.unpack(&mut self.vec), + NodeRef::PublicationObjSpec(n) => n.unpack(&mut self.vec), + NodeRef::PublicationTable(n) => n.unpack(&mut self.vec), + NodeRef::JsonFormat(n) => n.unpack(&mut self.vec), + NodeRef::JsonReturning(n) => n.unpack(&mut self.vec), + NodeRef::JsonValueExpr(n) => n.unpack(&mut self.vec), + NodeRef::JsonConstructorExpr(n) => n.unpack(&mut self.vec), + NodeRef::JsonIsPredicate(n) => n.unpack(&mut self.vec), + NodeRef::JsonOutput(n) => n.unpack(&mut self.vec), + NodeRef::JsonKeyValue(n) => n.unpack(&mut self.vec), + NodeRef::JsonObjectConstructor(n) => n.unpack(&mut self.vec), + NodeRef::JsonArrayConstructor(n) => n.unpack(&mut self.vec), + NodeRef::JsonArrayQueryConstructor(n) => n.unpack(&mut self.vec), + NodeRef::JsonAggConstructor(n) => n.unpack(&mut self.vec), + NodeRef::JsonObjectAgg(n) => n.unpack(&mut self.vec), + NodeRef::JsonArrayAgg(n) => n.unpack(&mut self.vec), + NodeRef::RtepermissionInfo(n) => n.unpack(&mut self.vec), + }; + Some(node) + } else { + None + } + } +} + +trait Unpack<'a> { + fn unpack(&'a self, vec: &mut VecDeque>); +} + +impl<'a> Unpack<'a> for protobuf::Node { + fn unpack(&'a self, vec: &mut VecDeque>) { + if let Some(ref e) = self.node { + vec.push_back(e.to_ref()); + } + } +} + +include!(concat!(env!("OUT_DIR"), "/unpack.rs")); + +impl<'a> protobuf::ParseResult { + pub fn iter(&'a self) -> impl Iterator> { + NodesIter { vec: self.stmts.iter().map(|stmt| stmt.to_ref()).collect() } + } +} + impl<'a> NodeRef<'a> { // TODO: `deparseStmt` in pg_query_deparse.c panics on unexpected nodes. instead, return a Rust error pub fn deparse(&self) -> Result { @@ -265,6 +558,10 @@ impl<'a> NodeRef<'a> { }) } + pub fn iter(&self) -> impl Iterator> { + NodesIter { vec: VecDeque::from([*self]) } + } + pub fn to_enum(&self) -> NodeEnum { match self { NodeRef::Alias(n) => NodeEnum::Alias((*n).clone()), diff --git a/tests/parse_tests.rs b/tests/parse_tests.rs index 57e5284..9abbb1e 100644 --- a/tests/parse_tests.rs +++ b/tests/parse_tests.rs @@ -73,6 +73,7 @@ fn it_parses_real_queries() { fn it_parses_empty_queries() { let result = parse("-- nothing").unwrap(); assert_eq!(result.protobuf.nodes().len(), 0); + assert_eq!(result.protobuf.iter().collect::>().len(), 0); assert_eq!(result.tables().len(), 0); assert_eq!(result.warnings.len(), 0); assert_eq!(result.statement_types().len(), 0); @@ -81,12 +82,26 @@ fn it_parses_empty_queries() { #[test] fn it_parses_floats_with_leading_dot() { let result = parse("SELECT .1").unwrap(); - let select = cast!(result.protobuf.nodes()[0].0, NodeRef::SelectStmt); - let target = cast!(select.target_list[0].node.as_ref().unwrap(), NodeEnum::ResTarget); - let a_const = cast!(target.val.as_ref().unwrap().node.as_ref().unwrap(), NodeEnum::AConst); - let float = cast!(a_const.val.as_ref().unwrap(), Val::Fval); - assert_eq!(float.fval, ".1"); - assert_eq!(a_const.location, 7); + { + let select = cast!(result.protobuf.nodes()[0].0, NodeRef::SelectStmt); + let target = cast!(select.target_list[0].node.as_ref().unwrap(), NodeEnum::ResTarget); + let a_const = cast!(target.val.as_ref().unwrap().node.as_ref().unwrap(), NodeEnum::AConst); + let float = cast!(a_const.val.as_ref().unwrap(), Val::Fval); + assert_eq!(float.fval, ".1"); + assert_eq!(a_const.location, 7); + } + + { + let all_nodes = result.protobuf.iter().collect::>(); + assert_eq!(all_nodes.len(), 4); + println!("{:?}", all_nodes); + let _select = cast!(all_nodes[1], NodeRef::SelectStmt); + let _target = cast!(all_nodes[2], NodeRef::ResTarget); + let a_const = cast!(all_nodes[3], NodeRef::AConst); + let float = cast!(a_const.val.as_ref().unwrap(), Val::Fval); + assert_eq!(float.fval, ".1"); + assert_eq!(a_const.location, 7); + } } #[test] @@ -1010,6 +1025,28 @@ fn it_parses_WITH() { assert_eq!(result.statement_types(), ["SelectStmt"]); } +#[test] +fn it_parses_INSERT_SELECT() { + let result = parse("INSERT INTO a SELECT x FROM b").unwrap(); + assert_eq!(sorted(result.tables()).collect::>(), ["a", "b"]); + let all_nodes = result.protobuf.iter().collect::>(); + assert_eq!(all_nodes.len(), 7); +} + +#[test] +fn get_params() { + let result = parse( + "WITH a AS (SELECT x FROM b WHERE y = $1 AND z = (SELECT u FROM c WHERE w = $2)) INSERT INTO d (e, f) SELECT $3, x FROM a WHERE x > $4", + ) + .unwrap(); + let params: Vec<_> = sorted(result.protobuf.iter().filter_map(|node| match node { + NodeRef::ParamRef(x) => Some(x.number), + _ => None, + })) + .collect(); + assert_eq!(params, [1, 2, 3, 4]); +} + #[test] fn it_parses_multi_line_functions() { let sql = "CREATE OR REPLACE FUNCTION thing(parameter_thing text)