Skip to content
This repository has been archived by the owner on May 10, 2024. It is now read-only.

Commit

Permalink
dcm: optimization pass to evaluate constant sub-query expressions onl…
Browse files Browse the repository at this point in the history
…y once

Signed-off-by: Lalith Suresh <[email protected]>
  • Loading branch information
lalithsuresh committed Oct 26, 2020
1 parent 8071541 commit 7d59066
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,26 @@

package com.vmware.dcm.backend.ortools;

import com.vmware.dcm.compiler.monoid.VoidType;
import com.vmware.dcm.compiler.monoid.ColumnIdentifier;
import com.vmware.dcm.compiler.monoid.GroupByComprehension;
import com.vmware.dcm.compiler.monoid.MonoidComprehension;
import com.vmware.dcm.compiler.monoid.SimpleVisitor;
import com.vmware.dcm.compiler.monoid.VoidType;

import java.util.LinkedHashSet;


/**
* A visitor that returns the set of accessed columns within a comprehension's scope, *without entering
* A visitor that returns the set of accessed columns within a comprehension's scope, without entering
* sub-queries.
*/
class GetColumnIdentifiers extends SimpleVisitor {
private final LinkedHashSet<ColumnIdentifier> columnIdentifiers = new LinkedHashSet<>();
private final boolean visitInnerComprehensions;

GetColumnIdentifiers(final boolean visitInnerComprehensions) {
this.visitInnerComprehensions = visitInnerComprehensions;
}

@Override
protected VoidType visitColumnIdentifier(final ColumnIdentifier node, final VoidType context) {
Expand All @@ -30,11 +35,17 @@ protected VoidType visitColumnIdentifier(final ColumnIdentifier node, final Void

@Override
protected VoidType visitMonoidComprehension(final MonoidComprehension node, final VoidType context) {
if (visitInnerComprehensions) {
super.visitMonoidComprehension(node, context);
}
return defaultReturn();
}

@Override
protected VoidType visitGroupByComprehension(final GroupByComprehension node, final VoidType context) {
if (visitInnerComprehensions) {
super.visitGroupByComprehension(node, context);
}
return defaultReturn();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright 2018-2020 VMware, Inc. All Rights Reserved.
* SPDX-License-Identifier: BSD-2
*/

package com.vmware.dcm.backend.ortools;

import com.vmware.dcm.compiler.monoid.ColumnIdentifier;
import com.vmware.dcm.compiler.monoid.GroupByComprehension;
import com.vmware.dcm.compiler.monoid.MonoidComprehension;
import com.vmware.dcm.compiler.monoid.TableRowGenerator;

import java.util.LinkedHashSet;
import java.util.Set;
import java.util.stream.Collectors;

/*
* Evaluates whether a sub-query (and its inner sub-queries etc.) can be treated as a constant expression.
*/
public class IsConstantSubquery {

static boolean apply(final MonoidComprehension expr) {
final GetColumnIdentifiers visitor = new GetColumnIdentifiers(true);
if (expr instanceof GroupByComprehension) {
final MonoidComprehension comprehension = ((GroupByComprehension) expr).getComprehension();
comprehension.getHead().getSelectExprs().forEach(visitor::visit);
comprehension.getQualifiers().forEach(visitor::visit);
} else {
expr.getHead().getSelectExprs().forEach(visitor::visit);
expr.getQualifiers().forEach(visitor::visit);
}
final LinkedHashSet<ColumnIdentifier> columnIdentifiers = visitor.getColumnIdentifiers();

final Set<String> accessedTables = expr.getQualifiers()
.stream().filter(q -> q instanceof TableRowGenerator)
.map(e -> ((TableRowGenerator ) e).getTable().getAliasedName())
.collect(Collectors.toSet());
return columnIdentifiers.stream().allMatch(
ci -> !ci.getField().isControllable() && accessedTables.contains(ci.getTableName())
);
}
}
92 changes: 23 additions & 69 deletions dcm/src/main/java/com/vmware/dcm/backend/ortools/OrToolsSolver.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,9 @@
import java.net.URLClassLoader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
Expand Down Expand Up @@ -265,7 +263,7 @@ public List<String> generateModelCode(final IRContext context,
final OutputIR.Block outerBlock = outputIR.newBlock("outer");
translationContext.enterScope(outerBlock);
final OutputIR.Block block = addView(name, rewrittenComprehension, false, translationContext);
translationContext.leaveScope();
output.addCode(translationContext.leaveScope().toString());
output.addCode(block.toString());
});
constraintViews
Expand All @@ -277,7 +275,7 @@ public List<String> generateModelCode(final IRContext context,
final OutputIR.Block outerBlock = outputIR.newBlock("outer");
translationContext.enterScope(outerBlock);
final OutputIR.Block block = addView(name, rewrittenComprehension, true, translationContext);
translationContext.leaveScope();
output.addCode(translationContext.leaveScope().toString());
output.addCode(block.toString());
} else {
final OutputIR.Block outerBlock = outputIR.newBlock("outer");
Expand All @@ -297,7 +295,6 @@ public List<String> generateModelCode(final IRContext context,
final OutputIR.Block outerBlock = outputIR.newBlock("outer");
objFunctionContext.enterScope(outerBlock);
final String exprStr = exprToStr(rewrittenComprehension, objFunctionContext);
objFunctionContext.leaveScope();
output.addCode(outerBlock.toString());
output.addStatement("final $T $L = $L", IntVar.class, name, exprStr);
});
Expand Down Expand Up @@ -568,7 +565,7 @@ private List<ColumnIdentifier> getColumnsAccessed(final List<? extends Expr> exp
}

private LinkedHashSet<ColumnIdentifier> getColumnsAccessed(final Expr expr) {
final GetColumnIdentifiers visitor = new GetColumnIdentifiers();
final GetColumnIdentifiers visitor = new GetColumnIdentifiers(false);
visitor.visit(expr);
return visitor.getColumnIdentifiers();
}
Expand Down Expand Up @@ -1590,12 +1587,12 @@ protected String visitMonoidComprehension(final MonoidComprehension node, final
visitor.visit(headSelectItem);
final boolean headSelectItemContainsMonoidFunction = visitor.getFound();

final OutputIR.Block currentBlock = context.currentScope();
attemptConstantSubqueryOptimization(node, subQueryBlock, context);

final TranslationContext newCtx = context.withEnterFunctionContext();
// If the head contains a function, then this is a scalar subquery
if (headSelectItemContainsMonoidFunction) {
newCtx.enterScope(subQueryBlock);
currentBlock.addBody(subQueryBlock);
final String ret = apply(innerVisitor.visit(headSelectItem, newCtx), context);
newCtx.leaveScope();
return ret;
Expand All @@ -1606,7 +1603,6 @@ protected String visitMonoidComprehension(final MonoidComprehension node, final
final String type = inferType(node.getHead().getSelectExprs().get(0));
final String listName =
extractListFromLoop(processedHeadItem, subQueryBlock, newSubqueryName, type);
currentBlock.addBody(subQueryBlock);
newCtx.leaveScope();
return apply(listName, subQueryBlock, context);
}
Expand All @@ -1624,12 +1620,13 @@ protected String visitGroupByComprehension(final GroupByComprehension node, fina
Preconditions.checkArgument(node.getComprehension().getHead().getSelectExprs().size() == 1);
final Expr headSelectItem = node.getComprehension().getHead().getSelectExprs().get(0);

final OutputIR.Block currentBlock = context.currentScope();
final TranslationContext newCtx = context.withEnterFunctionContext();

attemptConstantSubqueryOptimization(node, subQueryBlock, context);

// if scalar subquery
if (headSelectItem instanceof MonoidFunction) {
newCtx.enterScope(subQueryBlock);
currentBlock.addBody(subQueryBlock);
final String ret = apply(innerVisitor.visit(headSelectItem, newCtx), context);
newCtx.leaveScope();
return ret;
Expand All @@ -1640,7 +1637,6 @@ protected String visitGroupByComprehension(final GroupByComprehension node, fina
final String type = inferType(node.getComprehension().getHead().getSelectExprs().get(0));
final String listName =
extractListFromLoop(processedHeadItem, subQueryBlock, newSubqueryName, type);
currentBlock.addBody(subQueryBlock);
newCtx.leaveScope();
// Treat as a vector
return apply(listName, subQueryBlock, context);
Expand Down Expand Up @@ -1740,6 +1736,20 @@ private String createTermsForScalarProduct(final Expr variables, final Expr coef
extractListFromLoop(coefficientsItem, outerBlock, forLoop, coefficientsType);
return CodeBlock.of("o.scalProd($L, $L)", listOfVariablesItem, listOfCoefficientsItem).toString();
}

/**
* Constant sub-queries can be floated to the root block so that we evaluate them only once
*/
private void attemptConstantSubqueryOptimization(final MonoidComprehension node,
final OutputIR.Block subQueryBlock,
final TranslationContext context) {
if (IsConstantSubquery.apply(node)) {
final OutputIR.Block rootBlock = context.getRootBlock();
rootBlock.addBody(subQueryBlock);
} else {
context.currentScope().addBody(subQueryBlock);
}
}
}

/**
Expand Down Expand Up @@ -1778,7 +1788,7 @@ private String extractListFromLoop(final String variableToExtract, final OutputI
* @return the name of the list being extracted
*/
@SuppressFBWarnings("UPM_UNCALLED_PRIVATE_METHOD") // false positive
private String extractListFromLoop(final String variableToExtract, final OutputIR.Block outerBlock,
private String extractListFromLoop(final String variableToExtract, final OutputIR.Block outerBlock,
final String loopBlockName, final String variableType) {
final OutputIR.Block forLoop = outerBlock.getForLoopByName(loopBlockName);
return extractListFromLoop(variableToExtract, outerBlock, forLoop, variableType);
Expand Down Expand Up @@ -1865,62 +1875,6 @@ private String getTempViewName() {
return "tmp" + intermediateViewCounter.getAndIncrement();
}

/**
* Represents context required for code generation. It maintains a stack of blocks
* in the IR, that is used to correctly scope variable declarations and accesses.
*/
private static class TranslationContext {
private final Deque<OutputIR.Block> scopeStack;
private final boolean isFunctionContext;

private TranslationContext(final Deque<OutputIR.Block> declarations, final boolean isFunctionContext) {
this.scopeStack = declarations;
this.isFunctionContext = isFunctionContext;
}

private TranslationContext(final boolean isFunctionContext) {
this(new ArrayDeque<>(), isFunctionContext);
}

TranslationContext withEnterFunctionContext() {
final Deque<OutputIR.Block> stackCopy = new ArrayDeque<>(scopeStack);
return new TranslationContext(stackCopy, true);
}

boolean isFunctionContext() {
return isFunctionContext;
}

void enterScope(final OutputIR.Block block) {
scopeStack.addLast(block);
}

OutputIR.Block currentScope() {
return Objects.requireNonNull(scopeStack.getLast());
}

OutputIR.Block leaveScope() {
return scopeStack.removeLast();
}

String declareVariable(final String expression) {
for (final OutputIR.Block block: scopeStack) {
if (block.hasDeclaration(expression)) {
return block.getDeclaredName(expression);
}
}
return scopeStack.getLast().declare(expression);
}

String declareVariable(final String expression, final OutputIR.Block block) {
return block.declare(expression);
}

String getTupleVarName() {
return currentScope().getTupleName();
}
}

private static CodeBlock statement(final String format, final Object... args) {
return CodeBlock.builder().addStatement(format, args).build();
}
Expand Down

0 comments on commit 7d59066

Please sign in to comment.