diff --git a/src/parse_result.rs b/src/parse_result.rs index 3f506f0..8c84788 100644 --- a/src/parse_result.rs +++ b/src/parse_result.rs @@ -142,20 +142,21 @@ impl ParseResult { } }, NodeRef::ColumnRef(c) => { + if !has_filter_columns { + continue; + } let f: Vec<&String> = c .fields .iter() - .filter_map(|n| { - n.node.as_ref().and_then(|n| if let NodeEnum::AStar(_) = n { None } else { Some(&cast!(n, NodeEnum::String).sval) }) + .filter_map(|n| match n.node.as_ref() { + Some(NodeEnum::String(s)) => Some(&s.sval), + _ => None, }) .rev() .collect(); - if f.len() == 0 || !has_filter_columns { - continue; + if f.len() > 0 { + filter_columns.insert((f.get(1).cloned().cloned(), f[0].to_string())); } - let column = f[0]; - let table = f.get(1).map(|t| t.to_string()); - filter_columns.insert((table, column.to_string())); } _ => (), } diff --git a/tests/filter_column_tests.rs b/tests/filter_column_tests.rs index 9cc85ae..9f4e9a7 100644 --- a/tests/filter_column_tests.rs +++ b/tests/filter_column_tests.rs @@ -10,35 +10,35 @@ use pg_query::parse; fn it_finds_unqualified_names() { let result = parse("SELECT * FROM x WHERE y = $1 AND z = 1").unwrap(); let filter_columns: Vec<(Option, String)> = sorted(result.filter_columns).collect(); - assert_eq!(filter_columns, [(None, "y".to_string()), (None, "z".to_string())]); + assert_eq!(filter_columns, [(None, "y".into()), (None, "z".into())]); } #[test] fn it_finds_qualified_names() { let result = parse("SELECT * FROM x WHERE x.y = $1 AND x.z = 1").unwrap(); let filter_columns: Vec<(Option, String)> = sorted(result.filter_columns).collect(); - assert_eq!(filter_columns, [(Some("x".to_string()), "y".to_string()), (Some("x".to_string()), "z".to_string())]); + assert_eq!(filter_columns, [(Some("x".into()), "y".into()), (Some("x".into()), "z".into())]); } #[test] fn it_traverses_into_ctes() { let result = parse("WITH a AS (SELECT * FROM x WHERE x.y = $1 AND x.z = 1) SELECT * FROM a WHERE b = 5").unwrap(); let filter_columns: Vec<(Option, String)> = sorted(result.filter_columns).collect(); - assert_eq!(filter_columns, [(None, "b".to_string()), (Some("x".to_string()), "y".to_string()), (Some("x".to_string()), "z".to_string())]); + assert_eq!(filter_columns, [(None, "b".into()), (Some("x".into()), "y".into()), (Some("x".into()), "z".into())]); } #[test] fn it_recognizes_boolean_tests() { let result = parse("SELECT * FROM x WHERE x.y IS TRUE AND x.z IS NOT FALSE").unwrap(); let filter_columns: Vec<(Option, String)> = sorted(result.filter_columns).collect(); - assert_eq!(filter_columns, [(Some("x".to_string()), "y".to_string()), (Some("x".to_string()), "z".to_string())]); + assert_eq!(filter_columns, [(Some("x".into()), "y".into()), (Some("x".into()), "z".into())]); } #[test] fn it_recognizes_null_tests() { let result = parse("SELECT * FROM x WHERE x.y IS NULL AND x.z IS NOT NULL").unwrap(); let filter_columns: Vec<(Option, String)> = sorted(result.filter_columns).collect(); - assert_eq!(filter_columns, [(Some("x".to_string()), "y".to_string()), (Some("x".to_string()), "z".to_string())]); + assert_eq!(filter_columns, [(Some("x".into()), "y".into()), (Some("x".into()), "z".into())]); } #[test] @@ -47,7 +47,7 @@ fn it_finds_coalesce_argument_names() { let filter_columns: Vec<(Option, String)> = sorted(result.filter_columns).collect(); assert_eq!( filter_columns, - [(Some("x".to_string()), "y".to_string()), (Some("z".to_string()), "a".to_string()), (Some("z".to_string()), "b".to_string())] + [(Some("x".into()), "y".into()), (Some("z".into()), "a".into()), (Some("z".into()), "b".into())] ); } @@ -55,54 +55,54 @@ fn it_finds_coalesce_argument_names() { fn it_finds_unqualified_names_in_union_query() { let result = parse("SELECT * FROM x where y = $1 UNION SELECT * FROM x where z = $2").unwrap(); let filter_columns: Vec<(Option, String)> = sorted(result.filter_columns).collect(); - assert_eq!(filter_columns, [(None, "y".to_string()), (None, "z".to_string())]); + assert_eq!(filter_columns, [(None, "y".into()), (None, "z".into())]); } #[test] fn it_finds_unqualified_names_in_union_all_query() { let result = parse("SELECT * FROM x where y = $1 UNION ALL SELECT * FROM x where z = $2").unwrap(); let filter_columns: Vec<(Option, String)> = sorted(result.filter_columns).collect(); - assert_eq!(filter_columns, [(None, "y".to_string()), (None, "z".to_string())]); + assert_eq!(filter_columns, [(None, "y".into()), (None, "z".into())]); } #[test] fn it_finds_unqualified_names_in_except_query() { let result = parse("SELECT * FROM x where y = $1 EXCEPT SELECT * FROM x where z = $2").unwrap(); let filter_columns: Vec<(Option, String)> = sorted(result.filter_columns).collect(); - assert_eq!(filter_columns, [(None, "y".to_string()), (None, "z".to_string())]); + assert_eq!(filter_columns, [(None, "y".into()), (None, "z".into())]); } #[test] fn it_finds_unqualified_names_in_except_all_query() { let result = parse("SELECT * FROM x where y = $1 EXCEPT ALL SELECT * FROM x where z = $2").unwrap(); let filter_columns: Vec<(Option, String)> = sorted(result.filter_columns).collect(); - assert_eq!(filter_columns, [(None, "y".to_string()), (None, "z".to_string())]); + assert_eq!(filter_columns, [(None, "y".into()), (None, "z".into())]); } #[test] fn it_finds_unqualified_names_in_intersect_query() { let result = parse("SELECT * FROM x where y = $1 INTERSECT SELECT * FROM x where z = $2").unwrap(); let filter_columns: Vec<(Option, String)> = sorted(result.filter_columns).collect(); - assert_eq!(filter_columns, [(None, "y".to_string()), (None, "z".to_string())]); + assert_eq!(filter_columns, [(None, "y".into()), (None, "z".into())]); } #[test] fn it_finds_unqualified_names_in_intersect_all_query() { let result = parse("SELECT * FROM x where y = $1 INTERSECT ALL SELECT * FROM x where z = $2").unwrap(); let filter_columns: Vec<(Option, String)> = sorted(result.filter_columns).collect(); - assert_eq!(filter_columns, [(None, "y".to_string()), (None, "z".to_string())]); + assert_eq!(filter_columns, [(None, "y".into()), (None, "z".into())]); } #[test] fn it_ignores_target_list_columns() { let result = parse("SELECT a, y, z FROM x WHERE x.y = $1 AND x.z = 1").unwrap(); let filter_columns: Vec<(Option, String)> = sorted(result.filter_columns).collect(); - assert_eq!(filter_columns, [(Some("x".to_string()), "y".to_string()), (Some("x".to_string()), "z".to_string())]); + assert_eq!(filter_columns, [(Some("x".into()), "y".into()), (Some("x".into()), "z".into())]); } #[test] fn it_ignores_order_by_columns() { let result = parse("SELECT * FROM x WHERE x.y = $1 AND x.z = 1 ORDER BY a, b").unwrap(); let filter_columns: Vec<(Option, String)> = sorted(result.filter_columns).collect(); - assert_eq!(filter_columns, [(Some("x".to_string()), "y".to_string()), (Some("x".to_string()), "z".to_string())]); + assert_eq!(filter_columns, [(Some("x".into()), "y".into()), (Some("x".into()), "z".into())]); }