Skip to content

Commit

Permalink
Fixed NPE in OrderTask (fixes #199)
Browse files Browse the repository at this point in the history
  • Loading branch information
vania-pooh committed Dec 10, 2016
1 parent fa8e1c9 commit 25f3860
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ public Set<String> validateInput(List<Object> args) {
firstArg instanceof Integer ||
firstArg instanceof Long
)){
return Collections.singleton(String.format("Function argument should be a number but a %s is given", firstArg.getClass().getCanonicalName()));
return Collections.singleton(String.format(
"Function argument should be a number but a %s is given",
firstArg != null ? firstArg.getClass().getCanonicalName() : "NULL"
));
}
return Collections.emptySet();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@ public class GroupTask implements Task {

private final List<Object> expressions = new ArrayList<>();

private final ExpressionEvaluator expressionEvaluator;

@Autowired
private ExpressionEvaluator expressionEvaluator;
public GroupTask(ExpressionEvaluator expressionEvaluator) {
this.expressionEvaluator = expressionEvaluator;
}

public void addExpression(Object expression) {
this.expressions.add(expression);
Expand Down Expand Up @@ -65,10 +69,16 @@ private Map<List<Object>, List<DataRow>> groupData(Map<List<Object>, List<DataRo
Map<List<Object>, List<DataRow>> newData = new HashMap<>();
previousData.keySet().forEach(k -> {
List<DataRow> currentKeyData = previousData.get(k);
Map<Object, List<DataRow>> groupedData = currentKeyData.stream()
.collect(Collectors.groupingBy( //Here we group by current expression
dr -> expressionEvaluator.evaluate(currentExpression, dr)
));
Map<Object, List<DataRow>> groupedData = new HashMap<>();
try {
groupedData.putAll(currentKeyData.stream()
.collect(Collectors.groupingBy( //Here we group by current expression
dr -> expressionEvaluator.evaluate(currentExpression, dr)
)));
} catch (NullPointerException e) {
//NPE occurs only when expression return NULL
throw new RuntimeException("Can not group by NULL column values");
}
groupedData.keySet().forEach(gk -> {
List<Object> newKey = new ArrayList<Object>(){
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
import org.springframework.context.annotation.Scope;
import org.springframework.stereotype.Component;

import java.io.Serializable;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

@Component
Expand All @@ -25,9 +26,13 @@
public class OrderTask implements Task {

private final List<OrderExpression> expressions = new ArrayList<>();


private final ExpressionEvaluator expressionEvaluator;

@Autowired
private ExpressionEvaluator expressionEvaluator;
public OrderTask(ExpressionEvaluator expressionEvaluator) {
this.expressionEvaluator = expressionEvaluator;
}

public void addExpression(OrderExpression orderExpression) {
this.expressions.add(orderExpression);
Expand All @@ -36,15 +41,15 @@ public void addExpression(OrderExpression orderExpression) {
@Override
public ExecutionResult execute(ExecutionResult previousTaskResult) throws SQLException {
try {
Optional<Comparator<DataRow>> comparatorCandidate = createComparator(Optional.empty(), expressions);
if (comparatorCandidate.isPresent()) {
Comparator<DataRow> comparator = createComparator(null, expressions);
if (comparator != null) {
return new ExecutionResult(){
{
setCount(previousTaskResult.getCount());
DataContainer newData = new DataContainer(
previousTaskResult.getData(),
rows -> rows.stream()
.sorted(comparatorCandidate.get())
.sorted(comparator)
.collect(Collectors.toList())
);
setData(newData);
Expand All @@ -56,25 +61,45 @@ public ExecutionResult execute(ExecutionResult previousTaskResult) throws SQLExc
throw new SQLException(e);
}
}
private Optional<Comparator<DataRow>> createComparator(Optional<Comparator<DataRow>> comparator, List<OrderExpression> remainingExpressions) {

private Comparator<DataRow> createComparator(Comparator<DataRow> comparator, List<OrderExpression> remainingExpressions) {
if (remainingExpressions.isEmpty()) {
return comparator;
}
OrderExpression currentExpression = remainingExpressions.remove(0);
Comparator<DataRow> nextComparator = (!comparator.isPresent()) ?
Comparator<DataRow> nextComparator = comparator == null ?
getComparator(currentExpression) :
comparator.get().thenComparing(getComparator(currentExpression));
return createComparator(Optional.of(nextComparator), remainingExpressions);
comparator.thenComparing(getComparator(currentExpression));
return createComparator(nextComparator, remainingExpressions);
}

private Comparator<DataRow> getComparator(OrderExpression orderExpression) {
Comparator<DataRow> comparator = Comparator.comparing(dr -> expressionEvaluator.evaluate(orderExpression.getExpression(), dr));
Comparator<DataRow> comparator = getComparator(
dr -> expressionEvaluator.evaluate(orderExpression.getExpression(), dr)
);
return orderExpression.getOrderDirection() == OrderDirection.ASC ?
comparator :
comparator.reversed();
}

private <T, U extends Comparable<? super U>> Comparator<T> getComparator(
Function<? super T, ? extends U> keyExtractor
) {
//Currently we use nullsLast policy. We may want to change this
//if respective SQL expressions like NULLS FIRST are introduced.
return (Comparator<T> & Serializable)
(c1, c2) -> {
U left = keyExtractor.apply(c1);
U right = keyExtractor.apply(c2);
if (left == null) {
return (right == null) ? 0 : 1;
} else if (right == null) {
return -1;
}
return left.compareTo(right);
};
}

public List<OrderExpression> getExpressions() {
return new ArrayList<>(expressions);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;

import java.sql.SQLException;
import java.util.*;
import java.util.stream.Collectors;

Expand All @@ -32,13 +33,7 @@ public class GroupTaskTest {

@Test
public void testExecute() throws Exception {
GroupTask groupTask = applicationContext.getBean(GroupTask.class);
groupTask.addExpression(new ColumnExpression(FIRST_COLUMN, TABLE_NAME));
groupTask.addExpression(new FunctionExpression(
FUNCTION_NAME,
Collections.singletonList(new ColumnExpression(SECOND_COLUMN, TABLE_NAME))
));
ExecutionResult output = groupTask.execute(createInput());
ExecutionResult output = groupBy(false);
assertThat(output.getCount(), equalTo(3));
List<DataRow> data = output.getData().getRows();
assertThat(data, hasSize(3));
Expand All @@ -51,8 +46,23 @@ public void testExecute() throws Exception {
createRow("two", 2)
));
}

private ExecutionResult createInput() {

private ExecutionResult groupBy(boolean withNullRow) throws Exception {
GroupTask groupTask = applicationContext.getBean(GroupTask.class);
groupTask.addExpression(new ColumnExpression(FIRST_COLUMN, TABLE_NAME));
groupTask.addExpression(new FunctionExpression(
FUNCTION_NAME,
Collections.singletonList(new ColumnExpression(SECOND_COLUMN, TABLE_NAME))
));
return groupTask.execute(createInput(withNullRow));
}

@Test(expected = SQLException.class)
public void testNullValue() throws Exception {
groupBy(true);
}

private ExecutionResult createInput(boolean withNullRow) {
ExecutionResult executionResult = new ExecutionResult();
Map<String, List<String>> columnsMap = new HashMap<String, List<String>>() {
{
Expand All @@ -64,6 +74,9 @@ private ExecutionResult createInput() {
dataContainer.addRow(createRow("two", 1));
dataContainer.addRow(createRow("two", 2));
dataContainer.addRow(createRow("two", -2));
if (withNullRow) {
dataContainer.addRow(createRow(null, -3));
}
executionResult.setData(dataContainer);
return executionResult;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ private static ExecutionResult getInput() {
dataContainer.addRow(createRow("b", 2));
dataContainer.addRow(createRow("a", 4));
dataContainer.addRow(createRow("a", 3));
dataContainer.addRow(createRow(null, 3));
input.setData(dataContainer);
return input;
}
Expand Down

0 comments on commit 25f3860

Please sign in to comment.