Skip to content

Commit

Permalink
add ut
Browse files Browse the repository at this point in the history
Signed-off-by: zombee0 <[email protected]>
  • Loading branch information
zombee0 committed Jan 10, 2025
1 parent 4b3ddc2 commit 1b1325f
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.starrocks.sql.optimizer.rule.tree.lowcardinality;

import autovalue.shaded.com.google.common.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
Expand Down Expand Up @@ -59,6 +60,7 @@
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.jetbrains.annotations.TestOnly;

import java.util.Collection;
import java.util.Collections;
Expand Down Expand Up @@ -132,14 +134,14 @@ public class DecodeCollector extends OptExpressionVisitor<DecodeInfo, DecodeInfo
private final ColumnRefSet physicalOlapScanColumns = new ColumnRefSet();

// check if there is a blocking node in plan
private boolean blockingOutput = false;
private boolean canBlockingOutput = false;

public DecodeCollector(SessionVariable session) {
this.sessionVariable = session;
}

public void collect(OptExpression root, DecodeContext context) {
blockingOutput = new CheckBlockingNode().check(root);
canBlockingOutput = new CheckBlockingNode().check(root);
collectImpl(root, null);
initContext(context);
}
Expand Down Expand Up @@ -682,6 +684,11 @@ private static boolean supportLowCardinality(Type type) {
return type.isVarchar() || (type.isArrayType() && ((ArrayType) type).getItemType().isVarchar());
}

@TestOnly
public boolean canBlockingOutput() {
return canBlockingOutput;
}

// Check if an expression can be optimized using a dictionary
// If the expression only contains a string column, the expression can be optimized using a dictionary
private static class DictExpressionCollector extends ScalarOperatorVisitor<ScalarOperator, Void> {
Expand Down Expand Up @@ -876,7 +883,7 @@ public ScalarOperator visitMatchExprOperator(MatchExprOperator operator, Void co
}
}

private static class CheckBlockingNode extends OptExpressionVisitor<Boolean, Void> {
public static class CheckBlockingNode extends OptExpressionVisitor<Boolean, Void> {
private boolean visitChild(OptExpression optExpression, Void context) {
if (optExpression.getInputs().size() != 1) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
import com.starrocks.catalog.ColumnId;
import com.starrocks.common.FeConstants;
import com.starrocks.planner.OlapScanNode;
import com.starrocks.sql.optimizer.rule.tree.lowcardinality.DecodeCollector;
import com.starrocks.sql.optimizer.statistics.IDictManager;
import com.starrocks.thrift.TExplainLevel;
import com.starrocks.utframe.StarRocksAssert;
import com.starrocks.utframe.UtFrameUtils;
import mockit.Expectations;
import org.junit.AfterClass;
import org.junit.Assert;
Expand Down Expand Up @@ -358,6 +360,43 @@ public void testDecodeNodeRewriteMultiAgg()
}
}

@Test
public void testIdentifyBlocking() throws Exception {
String sql = "select * from low_card_t1 order by 1";
boolean hasBlockingNode = new DecodeCollector.CheckBlockingNode().check(
UtFrameUtils.getPlanAndFragment(connectContext, sql).second.getPhysicalPlan());
Assert.assertTrue(hasBlockingNode);
sql = "select sum(cpc) from low_card_t1";
hasBlockingNode = new DecodeCollector.CheckBlockingNode().check(
UtFrameUtils.getPlanAndFragment(connectContext, sql).second.getPhysicalPlan());
Assert.assertTrue(hasBlockingNode);
sql = "select sum(cpc) from low_card_t1 union select sum(cpc) from low_card_t2";
hasBlockingNode = new DecodeCollector.CheckBlockingNode().check(
UtFrameUtils.getPlanAndFragment(connectContext, sql).second.getPhysicalPlan());
Assert.assertFalse(hasBlockingNode);
sql = "(select sum(cpc) from low_card_t1) union (select sum(cpc) from low_card_t2) order by 1";
hasBlockingNode = new DecodeCollector.CheckBlockingNode().check(
UtFrameUtils.getPlanAndFragment(connectContext, sql).second.getPhysicalPlan());
Assert.assertTrue(hasBlockingNode);
sql = "select sum(ss) from (" +
"(select sum(cpc) as ss from low_card_1) union (select sum(cpc) as ss from low_card_t2)) x";
hasBlockingNode = new DecodeCollector.CheckBlockingNode().check(
UtFrameUtils.getPlanAndFragment(connectContext, sql).second.getPhysicalPlan());
Assert.assertTrue(hasBlockingNode);
sql = "select * from low_card_t1 a join low_card_t2 b on a.cpc = b.cpc";
hasBlockingNode = new DecodeCollector.CheckBlockingNode().check(
UtFrameUtils.getPlanAndFragment(connectContext, sql).second.getPhysicalPlan());
Assert.assertFalse(hasBlockingNode);
sql = "select * from low_card_t1 a join low_card_t2 b on a.cpc = b.cpc order by 1";
hasBlockingNode = new DecodeCollector.CheckBlockingNode().check(
UtFrameUtils.getPlanAndFragment(connectContext, sql).second.getPhysicalPlan());
Assert.assertTrue(hasBlockingNode);
sql = "select sum(a.cpc) from low_card_t1 a join low_card_t2 b on a.cpc = b.cpc order by 1";
hasBlockingNode = new DecodeCollector.CheckBlockingNode().check(
UtFrameUtils.getPlanAndFragment(connectContext, sql).second.getPhysicalPlan());
Assert.assertTrue(hasBlockingNode);
}

@Test
public void testDecodeNodeRewrite7() throws Exception {
String sql = "select S_ADDRESS, count(S_ADDRESS) from supplier group by S_ADDRESS";
Expand Down

0 comments on commit 1b1325f

Please sign in to comment.