Skip to content

Commit

Permalink
Skip PreAggregateCaseAggregations it aggregations are not reduced
Browse files Browse the repository at this point in the history
If the number of pre-aggregations is equal to number of case aggregations,
then there is no performance gain in rule execution. Additionally, firing rule
in such case could lead to infinite rule execution loop, since new pre-aggregations
could be eligable for further optimization.
  • Loading branch information
sopel39 authored and martint committed Jul 25, 2024
1 parent 849d995 commit 401669e
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context
}

Map<PreAggregationKey, PreAggregation> preAggregations = getPreAggregations(aggregations, context);
if (preAggregations.size() == aggregations.size()) {
// Prevent rule execution if number of pre-aggregations is equal to number of case aggregations.
// In such case there is no gain in performance, and it could lead to infinite rule execution loop.
return Result.empty();
}

Assignments.Builder preGroupingExpressionsBuilder = Assignments.builder();
preGroupingExpressionsBuilder.putIdentities(extraGroupingKeys);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ public void testPreAggregatesSumAggregationsWithZeroDefault()
"SELECT " +
"col_varchar, " +
"sum(CASE WHEN col_bigint = 1 THEN col_bigint ELSE BIGINT '0' END), " +
"sum(CASE WHEN col_bigint = 2 THEN col_bigint ELSE BIGINT '0' END), " +
"sum(CASE WHEN col_bigint = 2 THEN col_tinyint ELSE TINYINT '0' END), " +
"sum(CASE WHEN col_bigint = 3 THEN col_double ELSE DOUBLE '0' END), " +
"sum(CASE WHEN col_bigint = 4 THEN col_decimal ELSE DECIMAL '0.0' END), " +
Expand Down Expand Up @@ -424,6 +425,19 @@ public void testDoesNotFireForNonCaseAggregation()
"GROUP BY col_varchar");
}

@Test
public void testDoesNotFireIfAggregationsAreNotReduced()
{
assertThatDoesNotFire("""
SELECT
SUM(IF(col_varchar != 'V', col_bigint + col_decimal)),
SUM(IF(col_varchar != 'V', col_decimal + col_tinyint)),
SUM(IF(col_varchar != 'V', col_tinyint + col_double)),
SUM(IF(col_varchar != 'V', col_double + col_bigint))
FROM t
""");
}

private void assertFires(@Language("SQL") String query)
{
assertThat(countOfMatchingNodes(plan(query), AggregationNode.class::isInstance)).isEqualTo(2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,36 +70,39 @@ public void testPreAggregate()
"SELECT " +
"key, " +
"sum(CASE WHEN sequence = 0 THEN value END), " +
"sum(CASE WHEN sequence = 2 THEN value END), " +
"min(CASE WHEN sequence = 1 THEN value ELSE null END), " +
"max(CASE WHEN sequence = 0 THEN value END), " +
"sum(CASE WHEN sequence = 1 THEN cast(value * 2 as real) ELSE cast(0 as real) END) " +
"FROM test_table " +
"GROUP BY key",
"VALUES ('a', 1, 2, 1, 10), ('b', 21, 13, 11, 54)",
"VALUES ('a', 1, null, 2, 1, 10), ('b', 21, null, 13, 11, 54)",
plan -> assertAggregationNodeCount(plan, 4));

assertQuery(
memorySession,
"SELECT " +
"sum(CASE WHEN sequence = 0 THEN value END), " +
"sum(CASE WHEN sequence = 2 THEN value END), " +
"min(CASE WHEN sequence = 1 THEN value ELSE null END), " +
"max(CASE WHEN sequence = 0 THEN value END), " +
"sum(CASE WHEN sequence = 1 THEN value * 2 ELSE 0 END) " +
"FROM test_table",
"VALUES (22, 2, 11, 64)",
"VALUES (22, null, 2, 11, 64)",
plan -> assertAggregationNodeCount(plan, 4));

assertQuery(
memorySession,
"SELECT " +
"key, " +
"sum(CASE WHEN sequence = 0 THEN value END), " +
"sum(CASE WHEN sequence = 2 THEN value END), " +
"min(CASE WHEN sequence = 1 THEN value ELSE null END), " +
"max(CASE WHEN sequence = 0 THEN value END), " +
"sum(CASE WHEN sequence = 1 THEN value * 2 ELSE 1 END) " +
"FROM test_table " +
"GROUP BY key",
"VALUES ('a', 1, 2, 1, 12), ('b', 21, 13, 11, 56)",
"VALUES ('a', 1, null, 2, 1, 12), ('b', 21, null, 13, 11, 56)",
plan -> assertAggregationNodeCount(plan, 2));

// non null default value on max aggregation
Expand All @@ -108,12 +111,13 @@ public void testPreAggregate()
"SELECT " +
"key, " +
"sum(CASE WHEN sequence = 0 THEN value END), " +
"sum(CASE WHEN sequence = 2 THEN value END), " +
"min(CASE WHEN sequence = 1 THEN value ELSE null END), " +
"max(CASE WHEN sequence = 0 THEN value END), " +
"max(CASE WHEN sequence = 1 THEN value * 2 ELSE 100 END) " +
"FROM test_table " +
"GROUP BY key",
"VALUES ('a', 1, 2, 1, 100), ('b', 21, 13, 11, 100)",
"VALUES ('a', 1, null, 2, 1, 100), ('b', 21, null, 13, 11, 100)",
plan -> assertAggregationNodeCount(plan, 2));

// no rows matching sequence number
Expand Down Expand Up @@ -149,12 +153,13 @@ public void testPreAggregateWithFilter()
memorySession,
"SELECT " +
"sum(CASE WHEN sequence = 0 THEN value END), " +
"sum(CASE WHEN sequence = 2 THEN value END), " +
"min(CASE WHEN sequence = 1 THEN value ELSE null END), " +
"max(CASE WHEN sequence = 0 THEN value END), " +
"sum(CASE WHEN sequence = 1 THEN value * 2 ELSE 0 END) " +
"FROM test_table " +
"WHERE sequence = 42",
"VALUES (null, null, null, null)",
"VALUES (null, null, null, null, null)",
plan -> assertAggregationNodeCount(plan, 4));
}

Expand Down

0 comments on commit 401669e

Please sign in to comment.