diff --git a/Cargo.lock b/Cargo.lock index 0a3d1ca..cae3e6b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -54,7 +54,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.31", + "syn 2.0.32", "which", ] @@ -368,9 +368,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "peeking_take_while" @@ -398,10 +398,12 @@ dependencies = [ "easy-parallel", "fs_extra", "glob", + "heck", "itertools", "pretty_assertions", - "prost", + "prost 0.13.1", "prost-build", + "prost-types 0.13.1", "regex", "serde", "serde_json", @@ -425,7 +427,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8832c0f9be7e3cae60727e6256cfd2cd3c3e2b6cd5dad4190ecb2fd658c9030b" dependencies = [ "proc-macro2", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -444,7 +446,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "71adf41db68aa0daaefc69bb30bcd68ded9b9abaad5d1fbb6304c4fb390e083e" dependencies = [ "bytes", - "prost-derive", + "prost-derive 0.10.1", +] + +[[package]] +name = "prost" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13db3d3fde688c61e2446b4d843bc27a7e8af269a69440c0308021dc92333cc" +dependencies = [ + "bytes", + "prost-derive 0.13.1", ] [[package]] @@ -462,8 +474,8 @@ dependencies = [ "log", "multimap", "petgraph", - "prost", - "prost-types", + "prost 0.10.4", + "prost-types 0.10.1", "regex", "tempfile", "which", @@ -482,6 +494,19 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "prost-derive" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18bec9b0adc4eba778b33684b7ba3e7137789434769ee3ce3930463ef904cfca" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn 2.0.32", +] + [[package]] name = "prost-types" version = "0.10.1" @@ -489,7 +514,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d0a014229361011dc8e69c8a1ec6c2e8d0f2af7c91e3ea3f5b2170298461e68" dependencies = [ "bytes", - "prost", + "prost 0.10.4", +] + +[[package]] +name = "prost-types" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cee5168b05f49d4b0ca581206eb14a7b22fafd963efe729ac48eb03266e25cc2" +dependencies = [ + "prost 0.13.1", ] [[package]] @@ -610,7 +644,7 @@ checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -643,9 +677,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.31" +version = "2.0.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "718fa2415bcb8d8bd775917a1bf12a7931b6dfa890753378538118181e0cb398" +checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2" dependencies = [ "proc-macro2", "quote", @@ -693,7 +727,7 @@ checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 23ae4d9..55eaabe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ repository = "https://github.com/pganalyze/pg_query.rs" [dependencies] itertools = "0.10.3" -prost = "0.10.4" +prost = "0.13.0" serde = { version = "1.0.139", features = ["derive"] } serde_json = "1.0.82" thiserror = "1.0.31" @@ -21,10 +21,13 @@ thiserror = "1.0.31" [build-dependencies] bindgen = "0.66.1" clippy = { version = "0.0.302", optional = true } +prost = "0.13.0" prost-build = "0.10.4" +prost-types = "0.13.0" fs_extra = "1.2.0" cc = "1.0.83" glob = "0.3.1" +heck = "0.4" [dev-dependencies] easy-parallel = "3.2.0" diff --git a/build.rs b/build.rs index 57d3664..353c9f3 100644 --- a/build.rs +++ b/build.rs @@ -3,12 +3,127 @@ use fs_extra::dir::CopyOptions; use glob::glob; +use prost::Message; +use prost_types::field_descriptor_proto::Type; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::env; use std::path::{Path, PathBuf}; +//use prost_build::ident::to_snake; +use heck::ToUpperCamelCase; static SOURCE_DIRECTORY: &str = "libpg_query"; static LIBRARY_NAME: &str = "pg_query"; +type Cardinality = prost_types::field_descriptor_proto::Label; + +struct Edge { + field: String, + message: usize, + cardinality: Cardinality, +} + +/// Represents a directed labeled multigraph of Message types. Each vertex represents a message +/// type. An edge A->B is a tuple (field_name: String, type: FieldType), that states that +/// Message A has a field (with name equal to `field_name`) of Message type B. +struct MessageGraph { + messages: HashMap, + + /// For each vertex A, the list of edges from A to other vertices, and a set of vertices B such that there is at least one edge B->A + edges: Vec<(String, Vec, BTreeSet)>, +} + +impl MessageGraph { + fn new() -> Self { + Self { messages: HashMap::new(), edges: Vec::new() } + } + + fn id_for(&mut self, type_name: &str) -> usize { + if let Some(id) = self.messages.get(type_name) { + *id + } else { + let id = self.edges.len(); + self.edges.push((type_name.to_string(), Vec::new(), BTreeSet::new())); + self.messages.insert(type_name.to_string(), id); + id + } + } + + fn make(&mut self, fds: prost_types::FileDescriptorSet) { + for fd in fds.file { + let package = fd.package().to_string(); + for msg in fd.message_type { + let full_name = format!(".{}.{}", package, msg.name()); + let id = self.id_for(&full_name); + + // buf.push_str(&format!("Message: {}\n", msg.name())); + // let my_edges: Vec<(String, FieldType)> = Vec::new(); + // + let mut fields: HashSet = HashSet::new(); + + if msg.name() != "Node" && msg.name() != "A_Const" { + for field in &msg.field { + if field.r#type() != Type::Message { + continue; + } + + if field.oneof_index.is_some() { + panic!(" No support for enums: field {} of message {}", field.name(), msg.name()); + } + + if !fields.insert(field.name().to_string()) { + panic!(" duplicate field: {}", field.name()); + } + + let message_id = self.id_for(field.type_name()); + self.edges[id].1.push(Edge { field: field.name().to_string(), message: message_id, cardinality: field.label() }); + self.edges[message_id].2.insert(id); + } + } + } + } + } + + fn filter_incoming(&self, id: usize, filter: &mut Vec) { + if !filter[id] { + filter[id] = true; + for nb in self.edges[id].2.iter() { + self.filter_incoming(*nb, filter); + } + } + } + + fn write(&self, buf: &mut String) { + let mut filter = vec![false; self.messages.len()]; + self.filter_incoming(*self.messages.get(".pg_query.Node").unwrap(), &mut filter); + for (id, (name, edges, _incoming)) in self.edges.iter().enumerate() { + let filtered = filter[id]; + let short_name = &name[name.rfind(".").unwrap() + 1..].to_upper_camel_case(); + if short_name == "Node" || short_name == "ParseResult" || short_name == "ScanResult" || short_name == "ScanToken" { + continue; + } + + buf.push_str(&format!("impl<'a> protobuf::{} {{\n", short_name)); + if filtered && edges.iter().any(|e| filter[e.message]) { + buf.push_str(" fn unpack(&'a self, vec: &mut VecDeque>) {\n"); + for edge in edges.iter() { + if filter[edge.message] { + match edge.cardinality { + Cardinality::Repeated => buf.push_str(&format!(" self.{}.iter().for_each(|n| n.unpack(vec));\n", edge.field)), + Cardinality::Required => buf.push_str(&format!(" vec.push_back(self.{});\n", edge.field)), + Cardinality::Optional => { + buf.push_str(&format!(" if let Some(ref e) = self.{} {{ e.unpack(vec); }}\n", edge.field)) + } + } + } + } + buf.push_str(" }\n}\n\n"); + } else { + buf.push_str(" fn unpack(&'a self, _vec: &mut VecDeque>) { }\n}\n"); + } + } + } +} + fn main() -> Result<(), Box> { let out_dir = PathBuf::from(env::var("OUT_DIR")?); let build_path = Path::new(".").join(SOURCE_DIRECTORY); @@ -65,6 +180,18 @@ fn main() -> Result<(), Box> { .write_to_file(out_dir.join("bindings.rs"))?; // Generate the protobuf definition + let mut config = prost_build::Config::new(); + let fds_path = out_dir.join("./file_descriptor_set.bin"); + config.file_descriptor_set_path(fds_path.clone()); + config.compile_protos(&[&out_protobuf_path.join(LIBRARY_NAME).with_extension("proto")], &[&out_protobuf_path])?; + + let mut buf = String::new(); + let fds = prost_types::FileDescriptorSet::decode(std::fs::read(fds_path)?.as_slice())?; + let mut graph = MessageGraph::new(); + graph.make(fds); + graph.write(&mut buf); + std::fs::write(out_dir.join("./unpack.rs"), buf)?; + prost_build::compile_protos(&[&out_protobuf_path.join(LIBRARY_NAME).with_extension("proto")], &[&out_protobuf_path])?; Ok(()) diff --git a/src/bindings.rs b/src/bindings.rs index 4f98483..74701a4 100644 --- a/src/bindings.rs +++ b/src/bindings.rs @@ -3,4 +3,7 @@ #![allow(non_snake_case)] #![allow(unused)] #![allow(clippy::all)] + +use prost::Message; + include!(concat!(env!("OUT_DIR"), "/bindings.rs")); diff --git a/src/node_enum.rs b/src/node_enum.rs index 665308c..2178182 100644 --- a/src/node_enum.rs +++ b/src/node_enum.rs @@ -25,6 +25,7 @@ impl NodeEnum { while !iter.is_empty() { let (node, depth, context, has_filter_columns) = iter.remove(0); let depth = depth + 1; + match node { // // The following statement types do not modify tables diff --git a/src/node_ref.rs b/src/node_ref.rs index 41a97d6..9f0b704 100644 --- a/src/node_ref.rs +++ b/src/node_ref.rs @@ -1,4 +1,5 @@ use crate::*; +use std::collections::VecDeque; #[derive(Debug, Copy, Clone)] pub enum NodeRef<'a> { @@ -256,6 +257,294 @@ pub enum NodeRef<'a> { RtepermissionInfo(&'a protobuf::RtePermissionInfo), } +/// Allows to walk through a tree of NodeRefs. +struct NodesIter<'a> { + /// A deque of NodeRefs that have still to be expanded + vec: std::collections::VecDeque>, +} + +impl<'a> Iterator for NodesIter<'a> { + type Item = NodeRef<'a>; + + fn next(&mut self) -> Option { + if let Some(node) = self.vec.pop_front() { + match node { + NodeRef::Alias(n) => n.unpack(&mut self.vec), + NodeRef::RangeVar(n) => n.unpack(&mut self.vec), + NodeRef::TableFunc(n) => n.unpack(&mut self.vec), + NodeRef::Var(n) => n.unpack(&mut self.vec), + NodeRef::Param(n) => n.unpack(&mut self.vec), + NodeRef::Aggref(n) => n.unpack(&mut self.vec), + NodeRef::GroupingFunc(n) => n.unpack(&mut self.vec), + NodeRef::WindowFunc(n) => n.unpack(&mut self.vec), + NodeRef::SubscriptingRef(n) => n.unpack(&mut self.vec), + NodeRef::FuncExpr(n) => n.unpack(&mut self.vec), + NodeRef::NamedArgExpr(n) => n.unpack(&mut self.vec), + NodeRef::OpExpr(n) => n.unpack(&mut self.vec), + NodeRef::DistinctExpr(n) => n.unpack(&mut self.vec), + NodeRef::NullIfExpr(n) => n.unpack(&mut self.vec), + NodeRef::ScalarArrayOpExpr(n) => n.unpack(&mut self.vec), + NodeRef::BoolExpr(n) => n.unpack(&mut self.vec), + NodeRef::SubLink(n) => n.unpack(&mut self.vec), + NodeRef::SubPlan(n) => n.unpack(&mut self.vec), + NodeRef::AlternativeSubPlan(n) => n.unpack(&mut self.vec), + NodeRef::FieldSelect(n) => n.unpack(&mut self.vec), + NodeRef::FieldStore(n) => n.unpack(&mut self.vec), + NodeRef::RelabelType(n) => n.unpack(&mut self.vec), + NodeRef::CoerceViaIo(n) => n.unpack(&mut self.vec), + NodeRef::ArrayCoerceExpr(n) => n.unpack(&mut self.vec), + NodeRef::ConvertRowtypeExpr(n) => n.unpack(&mut self.vec), + NodeRef::CollateExpr(n) => n.unpack(&mut self.vec), + NodeRef::CaseExpr(n) => n.unpack(&mut self.vec), + NodeRef::CaseWhen(n) => n.unpack(&mut self.vec), + NodeRef::CaseTestExpr(n) => n.unpack(&mut self.vec), + NodeRef::ArrayExpr(n) => n.unpack(&mut self.vec), + NodeRef::RowExpr(n) => n.unpack(&mut self.vec), + NodeRef::RowCompareExpr(n) => n.unpack(&mut self.vec), + NodeRef::CoalesceExpr(n) => n.unpack(&mut self.vec), + NodeRef::MinMaxExpr(n) => n.unpack(&mut self.vec), + NodeRef::SqlvalueFunction(n) => n.unpack(&mut self.vec), + NodeRef::XmlExpr(n) => n.unpack(&mut self.vec), + NodeRef::NullTest(n) => n.unpack(&mut self.vec), + NodeRef::BooleanTest(n) => n.unpack(&mut self.vec), + NodeRef::CoerceToDomain(n) => n.unpack(&mut self.vec), + NodeRef::CoerceToDomainValue(n) => n.unpack(&mut self.vec), + NodeRef::SetToDefault(n) => n.unpack(&mut self.vec), + NodeRef::CurrentOfExpr(n) => n.unpack(&mut self.vec), + NodeRef::NextValueExpr(n) => n.unpack(&mut self.vec), + NodeRef::InferenceElem(n) => n.unpack(&mut self.vec), + NodeRef::TargetEntry(n) => n.unpack(&mut self.vec), + NodeRef::RangeTblRef(n) => n.unpack(&mut self.vec), + NodeRef::JoinExpr(n) => n.unpack(&mut self.vec), + NodeRef::FromExpr(n) => n.unpack(&mut self.vec), + NodeRef::OnConflictExpr(n) => n.unpack(&mut self.vec), + NodeRef::IntoClause(n) => n.unpack(&mut self.vec), + NodeRef::RawStmt(n) => n.unpack(&mut self.vec), + NodeRef::Query(n) => n.unpack(&mut self.vec), + NodeRef::InsertStmt(n) => n.unpack(&mut self.vec), + NodeRef::DeleteStmt(n) => n.unpack(&mut self.vec), + NodeRef::UpdateStmt(n) => n.unpack(&mut self.vec), + NodeRef::SelectStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterTableStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterTableCmd(n) => n.unpack(&mut self.vec), + NodeRef::AlterDomainStmt(n) => n.unpack(&mut self.vec), + NodeRef::SetOperationStmt(n) => n.unpack(&mut self.vec), + NodeRef::GrantStmt(n) => n.unpack(&mut self.vec), + NodeRef::GrantRoleStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterDefaultPrivilegesStmt(n) => n.unpack(&mut self.vec), + NodeRef::ClosePortalStmt(n) => n.unpack(&mut self.vec), + NodeRef::ClusterStmt(n) => n.unpack(&mut self.vec), + NodeRef::CopyStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateStmt(n) => n.unpack(&mut self.vec), + NodeRef::DefineStmt(n) => n.unpack(&mut self.vec), + NodeRef::DropStmt(n) => n.unpack(&mut self.vec), + NodeRef::TruncateStmt(n) => n.unpack(&mut self.vec), + NodeRef::CommentStmt(n) => n.unpack(&mut self.vec), + NodeRef::FetchStmt(n) => n.unpack(&mut self.vec), + NodeRef::IndexStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateFunctionStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterFunctionStmt(n) => n.unpack(&mut self.vec), + NodeRef::DoStmt(n) => n.unpack(&mut self.vec), + NodeRef::RenameStmt(n) => n.unpack(&mut self.vec), + NodeRef::RuleStmt(n) => n.unpack(&mut self.vec), + NodeRef::NotifyStmt(n) => n.unpack(&mut self.vec), + NodeRef::ListenStmt(n) => n.unpack(&mut self.vec), + NodeRef::UnlistenStmt(n) => n.unpack(&mut self.vec), + NodeRef::TransactionStmt(n) => n.unpack(&mut self.vec), + NodeRef::ViewStmt(n) => n.unpack(&mut self.vec), + NodeRef::LoadStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateDomainStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreatedbStmt(n) => n.unpack(&mut self.vec), + NodeRef::DropdbStmt(n) => n.unpack(&mut self.vec), + NodeRef::VacuumStmt(n) => n.unpack(&mut self.vec), + NodeRef::ExplainStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateTableAsStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateSeqStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterSeqStmt(n) => n.unpack(&mut self.vec), + NodeRef::VariableSetStmt(n) => n.unpack(&mut self.vec), + NodeRef::VariableShowStmt(n) => n.unpack(&mut self.vec), + NodeRef::DiscardStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateTrigStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreatePlangStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateRoleStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterRoleStmt(n) => n.unpack(&mut self.vec), + NodeRef::DropRoleStmt(n) => n.unpack(&mut self.vec), + NodeRef::LockStmt(n) => n.unpack(&mut self.vec), + NodeRef::ConstraintsSetStmt(n) => n.unpack(&mut self.vec), + NodeRef::ReindexStmt(n) => n.unpack(&mut self.vec), + NodeRef::CheckPointStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateSchemaStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterDatabaseStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterDatabaseSetStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterRoleSetStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateConversionStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateCastStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateOpClassStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateOpFamilyStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterOpFamilyStmt(n) => n.unpack(&mut self.vec), + NodeRef::PrepareStmt(n) => n.unpack(&mut self.vec), + NodeRef::ExecuteStmt(n) => n.unpack(&mut self.vec), + NodeRef::DeallocateStmt(n) => n.unpack(&mut self.vec), + NodeRef::DeclareCursorStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateTableSpaceStmt(n) => n.unpack(&mut self.vec), + NodeRef::DropTableSpaceStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterObjectDependsStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterObjectSchemaStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterOwnerStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterOperatorStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterTypeStmt(n) => n.unpack(&mut self.vec), + NodeRef::DropOwnedStmt(n) => n.unpack(&mut self.vec), + NodeRef::ReassignOwnedStmt(n) => n.unpack(&mut self.vec), + NodeRef::CompositeTypeStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateEnumStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateRangeStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterEnumStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterTsdictionaryStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterTsconfigurationStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateFdwStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterFdwStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateForeignServerStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterForeignServerStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateUserMappingStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterUserMappingStmt(n) => n.unpack(&mut self.vec), + NodeRef::DropUserMappingStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterTableSpaceOptionsStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterTableMoveAllStmt(n) => n.unpack(&mut self.vec), + NodeRef::SecLabelStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateForeignTableStmt(n) => n.unpack(&mut self.vec), + NodeRef::ImportForeignSchemaStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateExtensionStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterExtensionStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterExtensionContentsStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateEventTrigStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterEventTrigStmt(n) => n.unpack(&mut self.vec), + NodeRef::RefreshMatViewStmt(n) => n.unpack(&mut self.vec), + NodeRef::ReplicaIdentityStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterSystemStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreatePolicyStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterPolicyStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateTransformStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateAmStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreatePublicationStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterPublicationStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateSubscriptionStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterSubscriptionStmt(n) => n.unpack(&mut self.vec), + NodeRef::DropSubscriptionStmt(n) => n.unpack(&mut self.vec), + NodeRef::CreateStatsStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterCollationStmt(n) => n.unpack(&mut self.vec), + NodeRef::CallStmt(n) => n.unpack(&mut self.vec), + NodeRef::AlterStatsStmt(n) => n.unpack(&mut self.vec), + NodeRef::AExpr(n) => n.unpack(&mut self.vec), + NodeRef::ColumnRef(n) => n.unpack(&mut self.vec), + NodeRef::ParamRef(n) => n.unpack(&mut self.vec), + NodeRef::AConst(n) => n.unpack(&mut self.vec), + NodeRef::FuncCall(n) => n.unpack(&mut self.vec), + NodeRef::AStar(n) => n.unpack(&mut self.vec), + NodeRef::AIndices(n) => n.unpack(&mut self.vec), + NodeRef::AIndirection(n) => n.unpack(&mut self.vec), + NodeRef::AArrayExpr(n) => n.unpack(&mut self.vec), + NodeRef::ResTarget(n) => n.unpack(&mut self.vec), + NodeRef::MultiAssignRef(n) => n.unpack(&mut self.vec), + NodeRef::TypeCast(n) => n.unpack(&mut self.vec), + NodeRef::CollateClause(n) => n.unpack(&mut self.vec), + NodeRef::SortBy(n) => n.unpack(&mut self.vec), + NodeRef::WindowDef(n) => n.unpack(&mut self.vec), + NodeRef::RangeSubselect(n) => n.unpack(&mut self.vec), + NodeRef::RangeFunction(n) => n.unpack(&mut self.vec), + NodeRef::RangeTableSample(n) => n.unpack(&mut self.vec), + NodeRef::RangeTableFunc(n) => n.unpack(&mut self.vec), + NodeRef::RangeTableFuncCol(n) => n.unpack(&mut self.vec), + NodeRef::TypeName(n) => n.unpack(&mut self.vec), + NodeRef::ColumnDef(n) => n.unpack(&mut self.vec), + NodeRef::IndexElem(n) => n.unpack(&mut self.vec), + NodeRef::Constraint(n) => n.unpack(&mut self.vec), + NodeRef::DefElem(n) => n.unpack(&mut self.vec), + NodeRef::RangeTblEntry(n) => n.unpack(&mut self.vec), + NodeRef::RangeTblFunction(n) => n.unpack(&mut self.vec), + NodeRef::TableSampleClause(n) => n.unpack(&mut self.vec), + NodeRef::WithCheckOption(n) => n.unpack(&mut self.vec), + NodeRef::SortGroupClause(n) => n.unpack(&mut self.vec), + NodeRef::GroupingSet(n) => n.unpack(&mut self.vec), + NodeRef::WindowClause(n) => n.unpack(&mut self.vec), + NodeRef::ObjectWithArgs(n) => n.unpack(&mut self.vec), + NodeRef::AccessPriv(n) => n.unpack(&mut self.vec), + NodeRef::CreateOpClassItem(n) => n.unpack(&mut self.vec), + NodeRef::TableLikeClause(n) => n.unpack(&mut self.vec), + NodeRef::FunctionParameter(n) => n.unpack(&mut self.vec), + NodeRef::LockingClause(n) => n.unpack(&mut self.vec), + NodeRef::RowMarkClause(n) => n.unpack(&mut self.vec), + NodeRef::XmlSerialize(n) => n.unpack(&mut self.vec), + NodeRef::WithClause(n) => n.unpack(&mut self.vec), + NodeRef::InferClause(n) => n.unpack(&mut self.vec), + NodeRef::OnConflictClause(n) => n.unpack(&mut self.vec), + NodeRef::CommonTableExpr(n) => n.unpack(&mut self.vec), + NodeRef::RoleSpec(n) => n.unpack(&mut self.vec), + NodeRef::TriggerTransition(n) => n.unpack(&mut self.vec), + NodeRef::PartitionElem(n) => n.unpack(&mut self.vec), + NodeRef::PartitionSpec(n) => n.unpack(&mut self.vec), + NodeRef::PartitionBoundSpec(n) => n.unpack(&mut self.vec), + NodeRef::PartitionRangeDatum(n) => n.unpack(&mut self.vec), + NodeRef::PartitionCmd(n) => n.unpack(&mut self.vec), + NodeRef::VacuumRelation(n) => n.unpack(&mut self.vec), + NodeRef::InlineCodeBlock(n) => n.unpack(&mut self.vec), + NodeRef::CallContext(n) => n.unpack(&mut self.vec), + NodeRef::Integer(n) => n.unpack(&mut self.vec), + NodeRef::Float(n) => n.unpack(&mut self.vec), + NodeRef::Boolean(n) => n.unpack(&mut self.vec), + NodeRef::String(n) => n.unpack(&mut self.vec), + NodeRef::BitString(n) => n.unpack(&mut self.vec), + NodeRef::List(n) => n.unpack(&mut self.vec), + NodeRef::IntList(n) => n.unpack(&mut self.vec), + NodeRef::OidList(n) => n.unpack(&mut self.vec), + NodeRef::MergeStmt(n) => n.unpack(&mut self.vec), + NodeRef::MergeAction(n) => n.unpack(&mut self.vec), + NodeRef::AlterDatabaseRefreshCollStmt(n) => n.unpack(&mut self.vec), + NodeRef::ReturnStmt(n) => n.unpack(&mut self.vec), + NodeRef::PlassignStmt(n) => n.unpack(&mut self.vec), + NodeRef::StatsElem(n) => n.unpack(&mut self.vec), + NodeRef::CtesearchClause(n) => n.unpack(&mut self.vec), + NodeRef::CtecycleClause(n) => n.unpack(&mut self.vec), + NodeRef::MergeWhenClause(n) => n.unpack(&mut self.vec), + NodeRef::PublicationObjSpec(n) => n.unpack(&mut self.vec), + NodeRef::PublicationTable(n) => n.unpack(&mut self.vec), + NodeRef::JsonFormat(n) => n.unpack(&mut self.vec), + NodeRef::JsonReturning(n) => n.unpack(&mut self.vec), + NodeRef::JsonValueExpr(n) => n.unpack(&mut self.vec), + NodeRef::JsonConstructorExpr(n) => n.unpack(&mut self.vec), + NodeRef::JsonIsPredicate(n) => n.unpack(&mut self.vec), + NodeRef::JsonOutput(n) => n.unpack(&mut self.vec), + NodeRef::JsonKeyValue(n) => n.unpack(&mut self.vec), + NodeRef::JsonObjectConstructor(n) => n.unpack(&mut self.vec), + NodeRef::JsonArrayConstructor(n) => n.unpack(&mut self.vec), + NodeRef::JsonArrayQueryConstructor(n) => n.unpack(&mut self.vec), + NodeRef::JsonAggConstructor(n) => n.unpack(&mut self.vec), + NodeRef::JsonObjectAgg(n) => n.unpack(&mut self.vec), + NodeRef::JsonArrayAgg(n) => n.unpack(&mut self.vec), + NodeRef::RtepermissionInfo(n) => n.unpack(&mut self.vec), + }; + Some(node) + } else { + None + } + } +} + +impl<'a> protobuf::Node { + fn unpack(&'a self, vec: &mut VecDeque>) { + if let Some(ref e) = self.node { + vec.push_back(e.to_ref()); + } + } +} + +include!(concat!(env!("OUT_DIR"), "/unpack.rs")); + +impl<'a> protobuf::ParseResult { + pub fn iter(&'a self) -> impl Iterator> { + NodesIter { vec: self.stmts.iter().map(|stmt| stmt.to_ref()).collect() } + } +} + impl<'a> NodeRef<'a> { // TODO: `deparseStmt` in pg_query_deparse.c panics on unexpected nodes. instead, return a Rust error pub fn deparse(&self) -> Result { @@ -265,6 +554,10 @@ impl<'a> NodeRef<'a> { }) } + pub fn iter(&self) -> impl Iterator> { + NodesIter { vec: VecDeque::from([*self]) } + } + pub fn to_enum(&self) -> NodeEnum { match self { NodeRef::Alias(n) => NodeEnum::Alias((*n).clone()), diff --git a/tests/parse_tests.rs b/tests/parse_tests.rs index 57e5284..9abbb1e 100644 --- a/tests/parse_tests.rs +++ b/tests/parse_tests.rs @@ -73,6 +73,7 @@ fn it_parses_real_queries() { fn it_parses_empty_queries() { let result = parse("-- nothing").unwrap(); assert_eq!(result.protobuf.nodes().len(), 0); + assert_eq!(result.protobuf.iter().collect::>().len(), 0); assert_eq!(result.tables().len(), 0); assert_eq!(result.warnings.len(), 0); assert_eq!(result.statement_types().len(), 0); @@ -81,12 +82,26 @@ fn it_parses_empty_queries() { #[test] fn it_parses_floats_with_leading_dot() { let result = parse("SELECT .1").unwrap(); - let select = cast!(result.protobuf.nodes()[0].0, NodeRef::SelectStmt); - let target = cast!(select.target_list[0].node.as_ref().unwrap(), NodeEnum::ResTarget); - let a_const = cast!(target.val.as_ref().unwrap().node.as_ref().unwrap(), NodeEnum::AConst); - let float = cast!(a_const.val.as_ref().unwrap(), Val::Fval); - assert_eq!(float.fval, ".1"); - assert_eq!(a_const.location, 7); + { + let select = cast!(result.protobuf.nodes()[0].0, NodeRef::SelectStmt); + let target = cast!(select.target_list[0].node.as_ref().unwrap(), NodeEnum::ResTarget); + let a_const = cast!(target.val.as_ref().unwrap().node.as_ref().unwrap(), NodeEnum::AConst); + let float = cast!(a_const.val.as_ref().unwrap(), Val::Fval); + assert_eq!(float.fval, ".1"); + assert_eq!(a_const.location, 7); + } + + { + let all_nodes = result.protobuf.iter().collect::>(); + assert_eq!(all_nodes.len(), 4); + println!("{:?}", all_nodes); + let _select = cast!(all_nodes[1], NodeRef::SelectStmt); + let _target = cast!(all_nodes[2], NodeRef::ResTarget); + let a_const = cast!(all_nodes[3], NodeRef::AConst); + let float = cast!(a_const.val.as_ref().unwrap(), Val::Fval); + assert_eq!(float.fval, ".1"); + assert_eq!(a_const.location, 7); + } } #[test] @@ -1010,6 +1025,28 @@ fn it_parses_WITH() { assert_eq!(result.statement_types(), ["SelectStmt"]); } +#[test] +fn it_parses_INSERT_SELECT() { + let result = parse("INSERT INTO a SELECT x FROM b").unwrap(); + assert_eq!(sorted(result.tables()).collect::>(), ["a", "b"]); + let all_nodes = result.protobuf.iter().collect::>(); + assert_eq!(all_nodes.len(), 7); +} + +#[test] +fn get_params() { + let result = parse( + "WITH a AS (SELECT x FROM b WHERE y = $1 AND z = (SELECT u FROM c WHERE w = $2)) INSERT INTO d (e, f) SELECT $3, x FROM a WHERE x > $4", + ) + .unwrap(); + let params: Vec<_> = sorted(result.protobuf.iter().filter_map(|node| match node { + NodeRef::ParamRef(x) => Some(x.number), + _ => None, + })) + .collect(); + assert_eq!(params, [1, 2, 3, 4]); +} + #[test] fn it_parses_multi_line_functions() { let sql = "CREATE OR REPLACE FUNCTION thing(parameter_thing text)