From 1b1325fdd375ae460c9da4ded525e15d0add6095 Mon Sep 17 00:00:00 2001 From: zombee0 Date: Fri, 10 Jan 2025 20:38:16 +0800 Subject: [PATCH] add ut Signed-off-by: zombee0 --- .../tree/lowcardinality/DecodeCollector.java | 13 +++++-- .../sql/plan/LowCardinalityTest2.java | 39 +++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/lowcardinality/DecodeCollector.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/lowcardinality/DecodeCollector.java index cd86488c50d70..4d5d38a39ba68 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/lowcardinality/DecodeCollector.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/lowcardinality/DecodeCollector.java @@ -14,6 +14,7 @@ package com.starrocks.sql.optimizer.rule.tree.lowcardinality; +import autovalue.shaded.com.google.common.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; @@ -59,6 +60,7 @@ import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.jetbrains.annotations.TestOnly; import java.util.Collection; import java.util.Collections; @@ -132,14 +134,14 @@ public class DecodeCollector extends OptExpressionVisitor { @@ -876,7 +883,7 @@ public ScalarOperator visitMatchExprOperator(MatchExprOperator operator, Void co } } - private static class CheckBlockingNode extends OptExpressionVisitor { + public static class CheckBlockingNode extends OptExpressionVisitor { private boolean visitChild(OptExpression optExpression, Void context) { if (optExpression.getInputs().size() != 1) { return false; diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest2.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest2.java index d3975bf927b6b..d757e8f0dba4e 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest2.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest2.java @@ -17,9 +17,11 @@ import com.starrocks.catalog.ColumnId; import com.starrocks.common.FeConstants; import com.starrocks.planner.OlapScanNode; +import com.starrocks.sql.optimizer.rule.tree.lowcardinality.DecodeCollector; import com.starrocks.sql.optimizer.statistics.IDictManager; import com.starrocks.thrift.TExplainLevel; import com.starrocks.utframe.StarRocksAssert; +import com.starrocks.utframe.UtFrameUtils; import mockit.Expectations; import org.junit.AfterClass; import org.junit.Assert; @@ -358,6 +360,43 @@ public void testDecodeNodeRewriteMultiAgg() } } + @Test + public void testIdentifyBlocking() throws Exception { + String sql = "select * from low_card_t1 order by 1"; + boolean hasBlockingNode = new DecodeCollector.CheckBlockingNode().check( + UtFrameUtils.getPlanAndFragment(connectContext, sql).second.getPhysicalPlan()); + Assert.assertTrue(hasBlockingNode); + sql = "select sum(cpc) from low_card_t1"; + hasBlockingNode = new DecodeCollector.CheckBlockingNode().check( + UtFrameUtils.getPlanAndFragment(connectContext, sql).second.getPhysicalPlan()); + Assert.assertTrue(hasBlockingNode); + sql = "select sum(cpc) from low_card_t1 union select sum(cpc) from low_card_t2"; + hasBlockingNode = new DecodeCollector.CheckBlockingNode().check( + UtFrameUtils.getPlanAndFragment(connectContext, sql).second.getPhysicalPlan()); + Assert.assertFalse(hasBlockingNode); + sql = "(select sum(cpc) from low_card_t1) union (select sum(cpc) from low_card_t2) order by 1"; + hasBlockingNode = new DecodeCollector.CheckBlockingNode().check( + UtFrameUtils.getPlanAndFragment(connectContext, sql).second.getPhysicalPlan()); + Assert.assertTrue(hasBlockingNode); + sql = "select sum(ss) from (" + + "(select sum(cpc) as ss from low_card_1) union (select sum(cpc) as ss from low_card_t2)) x"; + hasBlockingNode = new DecodeCollector.CheckBlockingNode().check( + UtFrameUtils.getPlanAndFragment(connectContext, sql).second.getPhysicalPlan()); + Assert.assertTrue(hasBlockingNode); + sql = "select * from low_card_t1 a join low_card_t2 b on a.cpc = b.cpc"; + hasBlockingNode = new DecodeCollector.CheckBlockingNode().check( + UtFrameUtils.getPlanAndFragment(connectContext, sql).second.getPhysicalPlan()); + Assert.assertFalse(hasBlockingNode); + sql = "select * from low_card_t1 a join low_card_t2 b on a.cpc = b.cpc order by 1"; + hasBlockingNode = new DecodeCollector.CheckBlockingNode().check( + UtFrameUtils.getPlanAndFragment(connectContext, sql).second.getPhysicalPlan()); + Assert.assertTrue(hasBlockingNode); + sql = "select sum(a.cpc) from low_card_t1 a join low_card_t2 b on a.cpc = b.cpc order by 1"; + hasBlockingNode = new DecodeCollector.CheckBlockingNode().check( + UtFrameUtils.getPlanAndFragment(connectContext, sql).second.getPhysicalPlan()); + Assert.assertTrue(hasBlockingNode); + } + @Test public void testDecodeNodeRewrite7() throws Exception { String sql = "select S_ADDRESS, count(S_ADDRESS) from supplier group by S_ADDRESS";