-
Notifications
You must be signed in to change notification settings - Fork 412
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add decrease interface for aggregation #9737
base: master
Are you sure you want to change the base?
Changes from 28 commits
3f80eed
171fd46
943c578
e0b05a2
75d2f12
7022607
b011bfa
8628a2e
db4e399
a9360da
e3c31bd
e3b9756
a159857
0d4401d
9ff2bd3
a9aa879
bf652bf
61eaef0
6017780
2358efb
3227b50
b2772ea
37f4cc5
dd52c9f
6379b0f
9e41b3a
e6adcbf
b84c911
aa12221
2b9e991
c54f45f
ad8f002
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,12 @@ struct AggregateFunctionAvgData | |
T sum; | ||
UInt64 count; | ||
|
||
void reset() | ||
{ | ||
sum = T(0); | ||
count = 0; | ||
} | ||
|
||
AggregateFunctionAvgData() | ||
: sum(0) | ||
, count(0) | ||
|
@@ -67,6 +73,8 @@ class AggregateFunctionAvg final | |
return std::make_shared<DataTypeFloat64>(); | ||
} | ||
|
||
void prepareWindow(AggregateDataPtr __restrict) const override {} | ||
|
||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override | ||
{ | ||
if constexpr (IsDecimal<T>) | ||
|
@@ -78,6 +86,19 @@ class AggregateFunctionAvg final | |
++this->data(place).count; | ||
} | ||
|
||
void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override | ||
{ | ||
if constexpr (IsDecimal<T>) | ||
this->data(place).sum -= static_cast<const ColumnDecimal<T> &>(*columns[0]).getData()[row_num]; | ||
else | ||
{ | ||
this->data(place).sum -= static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]; | ||
} | ||
--this->data(place).count; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add |
||
} | ||
|
||
void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } | ||
|
||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override | ||
{ | ||
this->data(place).sum += this->data(rhs).sum; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,7 +66,7 @@ class AggregateFunctionForEach final | |
AggregateFunctionForEachData & ensureAggregateData( | ||
AggregateDataPtr __restrict place, | ||
size_t new_size, | ||
Arena & arena) const | ||
Arena * arena) const | ||
{ | ||
AggregateFunctionForEachData & state = data(place); | ||
|
||
|
@@ -75,7 +75,10 @@ class AggregateFunctionForEach final | |
size_t old_size = state.dynamic_array_size; | ||
if (old_size < new_size) | ||
{ | ||
state.array_of_aggregate_datas = arena.realloc( | ||
if unlikely (arena == nullptr) | ||
throw Exception("Get nullptr in ensureAggregateData"); | ||
|
||
state.array_of_aggregate_datas = arena->realloc( | ||
state.array_of_aggregate_datas, | ||
old_size * nested_size_of_data, | ||
new_size * nested_size_of_data); | ||
|
@@ -149,13 +152,25 @@ class AggregateFunctionForEach final | |
bool hasTrivialDestructor() const override { return nested_func->hasTrivialDestructor(); } | ||
|
||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override | ||
{ | ||
addOrDecrease<true>(place, columns, row_num, arena); | ||
} | ||
|
||
void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) | ||
const override | ||
{ | ||
addOrDecrease<false>(place, columns, row_num, arena); | ||
} | ||
|
||
template <bool is_add> | ||
void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const | ||
{ | ||
const IColumn * nested[num_arguments]; | ||
|
||
for (size_t i = 0; i < num_arguments; ++i) | ||
nested[i] = &static_cast<const ColumnArray &>(*columns[i]).getData(); | ||
|
||
const ColumnArray & first_array_column = static_cast<const ColumnArray &>(*columns[0]); | ||
const auto & first_array_column = static_cast<const ColumnArray &>(*columns[0]); | ||
const IColumn::Offsets & offsets = first_array_column.getOffsets(); | ||
|
||
size_t begin = row_num == 0 ? 0 : offsets[row_num - 1]; | ||
|
@@ -164,7 +179,7 @@ class AggregateFunctionForEach final | |
/// Sanity check. NOTE We can implement specialization for a case with single argument, if the check will hurt performance. | ||
for (size_t i = 1; i < num_arguments; ++i) | ||
{ | ||
const ColumnArray & ith_column = static_cast<const ColumnArray &>(*columns[i]); | ||
const auto & ith_column = static_cast<const ColumnArray &>(*columns[i]); | ||
const IColumn::Offsets & ith_offsets = ith_column.getOffsets(); | ||
|
||
if (ith_offsets[row_num] != end || (row_num != 0 && ith_offsets[row_num - 1] != begin)) | ||
|
@@ -173,20 +188,45 @@ class AggregateFunctionForEach final | |
ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH); | ||
} | ||
|
||
AggregateFunctionForEachData & state = ensureAggregateData(place, end - begin, *arena); | ||
AggregateFunctionForEachData & state = ensureAggregateData(place, end - begin, arena); | ||
|
||
char * nested_state = state.array_of_aggregate_datas; | ||
for (size_t i = begin; i < end; ++i) | ||
{ | ||
nested_func->add(nested_state, nested, i, arena); | ||
if constexpr (is_add) | ||
nested_func->add(nested_state, nested, i, arena); | ||
else | ||
nested_func->decrease(nested_state, nested, i, arena); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not just call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Because |
||
nested_state += nested_size_of_data; | ||
} | ||
} | ||
|
||
void reset(AggregateDataPtr __restrict place) const override | ||
{ | ||
AggregateFunctionForEachData & state = ensureAggregateData(place, 0, nullptr); | ||
char * nested_state = state.array_of_aggregate_datas; | ||
for (size_t i = 0; i < state.dynamic_array_size; i++) | ||
{ | ||
nested_func->reset(nested_state); | ||
nested_state += nested_size_of_data; | ||
} | ||
} | ||
|
||
void prepareWindow(AggregateDataPtr __restrict place) const override | ||
{ | ||
AggregateFunctionForEachData & state = ensureAggregateData(place, 0, nullptr); | ||
char * nested_state = state.array_of_aggregate_datas; | ||
for (size_t i = 0; i < state.dynamic_array_size; i++) | ||
{ | ||
nested_func->prepareWindow(nested_state); | ||
nested_state += nested_size_of_data; | ||
} | ||
} | ||
|
||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override | ||
{ | ||
const AggregateFunctionForEachData & rhs_state = data(rhs); | ||
AggregateFunctionForEachData & state = ensureAggregateData(place, rhs_state.dynamic_array_size, *arena); | ||
AggregateFunctionForEachData & state = ensureAggregateData(place, rhs_state.dynamic_array_size, arena); | ||
|
||
const char * rhs_nested_state = rhs_state.array_of_aggregate_datas; | ||
char * nested_state = state.array_of_aggregate_datas; | ||
|
@@ -220,7 +260,7 @@ class AggregateFunctionForEach final | |
size_t new_size = 0; | ||
readBinary(new_size, buf); | ||
|
||
ensureAggregateData(place, new_size, *arena); | ||
ensureAggregateData(place, new_size, arena); | ||
|
||
char * nested_state = state.array_of_aggregate_datas; | ||
for (size_t i = 0; i < new_size; ++i) | ||
|
@@ -234,7 +274,7 @@ class AggregateFunctionForEach final | |
{ | ||
const AggregateFunctionForEachData & state = data(place); | ||
|
||
ColumnArray & arr_to = static_cast<ColumnArray &>(to); | ||
auto & arr_to = static_cast<ColumnArray &>(to); | ||
ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); | ||
IColumn & elems_to = arr_to.getData(); | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -116,32 +116,36 @@ class AggregateFunctionGroupConcat final | |
|
||
DataTypePtr getReturnType() const override { return result_is_nullable ? makeNullable(ret_type) : ret_type; } | ||
|
||
/// reject nulls before add() of nested agg | ||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override | ||
template <bool is_add> | ||
void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const | ||
{ | ||
if constexpr (only_one_column) | ||
{ | ||
if (is_nullable[0]) | ||
{ | ||
const ColumnNullable * column = static_cast<const ColumnNullable *>(columns[0]); | ||
const auto * column = static_cast<const ColumnNullable *>(columns[0]); | ||
if (!column->isNullAt(row_num)) | ||
{ | ||
this->setFlag(place); | ||
const IColumn * nested_column = &column->getNestedColumn(); | ||
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); | ||
|
||
if constexpr (is_add) | ||
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); | ||
else | ||
this->nested_function->decrease(this->nestedPlace(place), &nested_column, row_num, arena); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ditto There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Because |
||
} | ||
return; | ||
} | ||
} | ||
else | ||
{ | ||
/// remove the row with null, except for sort columns | ||
const ColumnTuple & tuple = static_cast<const ColumnTuple &>(*columns[0]); | ||
const auto & tuple = static_cast<const ColumnTuple &>(*columns[0]); | ||
for (size_t i = 0; i < number_of_concat_items; ++i) | ||
{ | ||
if (is_nullable[i]) | ||
{ | ||
const ColumnNullable & nullable_col = static_cast<const ColumnNullable &>(tuple.getColumn(i)); | ||
const auto & nullable_col = static_cast<const ColumnNullable &>(tuple.getColumn(i)); | ||
if (nullable_col.isNullAt(row_num)) | ||
{ | ||
/// If at least one column has a null value in the current row, | ||
|
@@ -152,7 +156,21 @@ class AggregateFunctionGroupConcat final | |
} | ||
} | ||
this->setFlag(place); | ||
this->nested_function->add(this->nestedPlace(place), columns, row_num, arena); | ||
if constexpr (is_add) | ||
this->nested_function->add(this->nestedPlace(place), columns, row_num, arena); | ||
else | ||
this->nested_function->decrease(this->nestedPlace(place), columns, row_num, arena); | ||
} | ||
|
||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override | ||
{ | ||
addOrDecrease<true>(place, columns, row_num, arena); | ||
} | ||
|
||
void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) | ||
const override | ||
{ | ||
addOrDecrease<false>(place, columns, row_num, arena); | ||
} | ||
|
||
void insertResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove
{ }
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
okk