From d7b0077d6e3a45cc51669c86863671bf66036e41 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Thu, 18 Jul 2024 14:55:32 -0700 Subject: [PATCH] Fix incorrect pushdown involving aggreagations and unnest Complex conjuncts that might fail were being reordered before simpler conjuncts. --- .../optimizations/PredicatePushDown.java | 30 +++++++------ .../io/trino/sql/query/TestIssue22731.java | 45 +++++++++++++++++++ 2 files changed, 61 insertions(+), 14 deletions(-) create mode 100644 core/trino-main/src/test/java/io/trino/sql/query/TestIssue22731.java diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java index f84e2e4a252e..de15a6e88fb8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java @@ -1492,8 +1492,15 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext groupingKeys = ImmutableSet.copyOf(node.getGroupingKeys()); + + // Add the equality predicates back in + EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(groupingKeys); + pushdownConjuncts.addAll(equalityPartition.getScopeEqualities()); + postAggregationConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); + postAggregationConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); + + // Sort non-equality predicates by those that can be pushed down and those that cannot EqualityInference.nonInferrableConjuncts(inheritedPredicate).forEach(conjunct -> { if (node.getGroupIdSymbol().isPresent() && extractUnique(conjunct).contains(node.getGroupIdSymbol().get())) { // aggregation operator synthesizes outputs for group ids corresponding to the global grouping set (i.e., ()), so we @@ -1513,12 +1520,6 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext context) .forEach(postUnnestConjuncts::add); inheritedPredicate = filterDeterministicConjuncts(inheritedPredicate); - // Sort non-equality predicates by those that can be pushed down and those that cannot Set replicatedSymbols = ImmutableSet.copyOf(node.getReplicateSymbols()); + + // Add the equality predicates back in + EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(replicatedSymbols); + pushdownConjuncts.addAll(equalityPartition.getScopeEqualities()); + postUnnestConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); + postUnnestConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); + + // Sort non-equality predicates by those that can be pushed down and those that cannot EqualityInference.nonInferrableConjuncts(inheritedPredicate).forEach(conjunct -> { Expression rewrittenConjunct = equalityInference.rewrite(conjunct, replicatedSymbols); if (rewrittenConjunct != null) { @@ -1566,12 +1574,6 @@ public PlanNode visitUnnest(UnnestNode node, RewriteContext context) } }); - // Add the equality predicates back in - EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(replicatedSymbols); - pushdownConjuncts.addAll(equalityPartition.getScopeEqualities()); - postUnnestConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); - postUnnestConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); - PlanNode rewrittenSource = context.rewrite(node.getSource(), combineConjuncts(pushdownConjuncts)); PlanNode output = node; diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestIssue22731.java b/core/trino-main/src/test/java/io/trino/sql/query/TestIssue22731.java new file mode 100644 index 000000000000..319515a21863 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestIssue22731.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.query; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +public class TestIssue22731 +{ + @Test + public void test() + { + try (QueryAssertions assertions = new QueryAssertions()) { + assertThat(assertions.query( + """ + WITH t(a) as ( + VALUES ARRAY[ARRAY['a']] + ), + u as ( + SELECT + e[cardinality(e)] AS v1, + cardinality(e) AS v2 + FROM t CROSS JOIN UNNEST(t.a) AS z(e) + GROUP BY e + ) + SELECT * + FROM u + WHERE v2 = 2 AND v1 = '' + """)) + .returnsEmptyResult(); + } + } +}