Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add decrease interface for aggregation #9737

Open
wants to merge 32 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
3f80eed
init
xzhangxian1008 Nov 26, 2024
171fd46
add gtest
xzhangxian1008 Nov 27, 2024
943c578
add ft
xzhangxian1008 Nov 27, 2024
e0b05a2
refine test framework
xzhangxian1008 Nov 27, 2024
75d2f12
fix compilation phase
xzhangxian1008 Nov 27, 2024
7022607
init
xzhangxian1008 Dec 4, 2024
b011bfa
codes done, need tests
xzhangxian1008 Dec 5, 2024
8628a2e
save
xzhangxian1008 Dec 12, 2024
db4e399
add tests
xzhangxian1008 Dec 13, 2024
a9360da
add sum tests
xzhangxian1008 Dec 17, 2024
e3c31bd
refine tests
xzhangxian1008 Dec 17, 2024
e3b9756
format
xzhangxian1008 Dec 17, 2024
a159857
fix bugs
xzhangxian1008 Dec 18, 2024
0d4401d
refine test
xzhangxian1008 Dec 18, 2024
9ff2bd3
fix tests
xzhangxian1008 Dec 18, 2024
a9aa879
add test for string type
xzhangxian1008 Dec 19, 2024
bf652bf
add test for SingleValueDataGeneric type
xzhangxian1008 Dec 19, 2024
61eaef0
tweaking
xzhangxian1008 Dec 20, 2024
6017780
Merge branch 'master' into wagg
xzhangxian1008 Dec 20, 2024
2358efb
remove something
xzhangxian1008 Dec 20, 2024
3227b50
revoke
xzhangxian1008 Dec 20, 2024
b2772ea
add AlignedBuffer
xzhangxian1008 Dec 20, 2024
37f4cc5
remove something
xzhangxian1008 Dec 20, 2024
dd52c9f
tweaking
xzhangxian1008 Dec 20, 2024
6379b0f
remove useless change
xzhangxian1008 Dec 20, 2024
9e41b3a
fix ci
xzhangxian1008 Dec 23, 2024
e6adcbf
fix ci
xzhangxian1008 Dec 23, 2024
b84c911
fix ut
xzhangxian1008 Dec 24, 2024
aa12221
address some comments
xzhangxian1008 Dec 25, 2024
2b9e991
create new class for window
xzhangxian1008 Dec 26, 2024
c54f45f
add ut tests
xzhangxian1008 Dec 27, 2024
ad8f002
add static
xzhangxian1008 Dec 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions dbms/src/AggregateFunctions/AggregateFunctionArray.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,25 @@ class AggregateFunctionArray final : public IAggregateFunctionHelper<AggregateFu
size_t alignOfData() const override { return nested_func->alignOfData(); }

void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
addOrDecrease<true>(place, columns, row_num, arena);
}

void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena)
const override
{
addOrDecrease<false>(place, columns, row_num, arena);
}

template <bool is_add>
void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const
{
const IColumn * nested[num_arguments];

for (size_t i = 0; i < num_arguments; ++i)
nested[i] = &static_cast<const ColumnArray &>(*columns[i]).getData();

const ColumnArray & first_array_column = static_cast<const ColumnArray &>(*columns[0]);
const auto & first_array_column = static_cast<const ColumnArray &>(*columns[0]);
const IColumn::Offsets & offsets = first_array_column.getOffsets();

size_t begin = row_num == 0 ? 0 : offsets[row_num - 1];
Expand All @@ -82,7 +94,7 @@ class AggregateFunctionArray final : public IAggregateFunctionHelper<AggregateFu
/// Sanity check. NOTE We can implement specialization for a case with single argument, if the check will hurt performance.
for (size_t i = 1; i < num_arguments; ++i)
{
const ColumnArray & ith_column = static_cast<const ColumnArray &>(*columns[i]);
const auto & ith_column = static_cast<const ColumnArray &>(*columns[i]);
const IColumn::Offsets & ith_offsets = ith_column.getOffsets();

if (ith_offsets[row_num] != end || (row_num != 0 && ith_offsets[row_num - 1] != begin))
Expand All @@ -92,9 +104,16 @@ class AggregateFunctionArray final : public IAggregateFunctionHelper<AggregateFu
}

for (size_t i = begin; i < end; ++i)
nested_func->add(place, nested, i, arena);
if constexpr (is_add)
nested_func->add(place, nested, i, arena);
else
nested_func->decrease(place, nested, i, arena);
}

void reset(AggregateDataPtr __restrict place) const override { nested_func->reset(place); }

void prepareWindow(AggregateDataPtr __restrict place) const override { nested_func->prepareWindow(place); }

void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
nested_func->merge(place, rhs, arena);
Expand Down
21 changes: 21 additions & 0 deletions dbms/src/AggregateFunctions/AggregateFunctionAvg.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ struct AggregateFunctionAvgData
T sum;
UInt64 count;

void reset()
{
sum = T(0);
count = 0;
}

AggregateFunctionAvgData()
: sum(0)
, count(0)
Expand Down Expand Up @@ -67,6 +73,8 @@ class AggregateFunctionAvg final
return std::make_shared<DataTypeFloat64>();
}

void prepareWindow(AggregateDataPtr __restrict) const override {}

void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{
if constexpr (IsDecimal<T>)
Expand All @@ -78,6 +86,19 @@ class AggregateFunctionAvg final
++this->data(place).count;
}

void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{
if constexpr (IsDecimal<T>)
this->data(place).sum -= static_cast<const ColumnDecimal<T> &>(*columns[0]).getData()[row_num];
else
{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove { }?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove { }?

okk

this->data(place).sum -= static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
}
--this->data(place).count;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add assert(this->data(place).count >= 0)?

}

void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }

void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
{
this->data(place).sum += this->data(rhs).sum;
Expand Down
33 changes: 33 additions & 0 deletions dbms/src/AggregateFunctions/AggregateFunctionCount.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ namespace DB
struct AggregateFunctionCountData
{
UInt64 count = 0;

inline void reset() noexcept { count = 0; }
};

namespace ErrorCodes
Expand All @@ -52,6 +54,15 @@ class AggregateFunctionCount final
++data(place).count;
}

void decrease(AggregateDataPtr __restrict place, const IColumn **, size_t, Arena *) const override
{
--data(place).count;
}

void reset(AggregateDataPtr __restrict place) const override { data(place).reset(); }

void prepareWindow(AggregateDataPtr __restrict) const override {}

void addBatchSinglePlace(
size_t start_offset,
size_t batch_size,
Expand Down Expand Up @@ -173,6 +184,15 @@ class AggregateFunctionCountNotNullUnary final
data(place).count += !static_cast<const ColumnNullable &>(*columns[0]).isNullAt(row_num);
}

void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{
data(place).count -= !static_cast<const ColumnNullable &>(*columns[0]).isNullAt(row_num);
}

void reset(AggregateDataPtr __restrict place) const override { data(place).reset(); }

void prepareWindow(AggregateDataPtr __restrict) const override {}

void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
{
data(place).count += data(rhs).count;
Expand Down Expand Up @@ -234,6 +254,19 @@ class AggregateFunctionCountNotNullVariadic final
++data(place).count;
}

void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{
for (size_t i = 0; i < number_of_arguments; ++i)
if (is_nullable[i] && static_cast<const ColumnNullable &>(*columns[i]).isNullAt(row_num))
return;

--data(place).count;
}

void reset(AggregateDataPtr __restrict place) const override { data(place).reset(); }

void prepareWindow(AggregateDataPtr __restrict) const override {}

void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
{
data(place).count += data(rhs).count;
Expand Down
58 changes: 49 additions & 9 deletions dbms/src/AggregateFunctions/AggregateFunctionForEach.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class AggregateFunctionForEach final
AggregateFunctionForEachData & ensureAggregateData(
AggregateDataPtr __restrict place,
size_t new_size,
Arena & arena) const
Arena * arena) const
{
AggregateFunctionForEachData & state = data(place);

Expand All @@ -75,7 +75,10 @@ class AggregateFunctionForEach final
size_t old_size = state.dynamic_array_size;
if (old_size < new_size)
{
state.array_of_aggregate_datas = arena.realloc(
if unlikely (arena == nullptr)
throw Exception("Get nullptr in ensureAggregateData");

state.array_of_aggregate_datas = arena->realloc(
state.array_of_aggregate_datas,
old_size * nested_size_of_data,
new_size * nested_size_of_data);
Expand Down Expand Up @@ -149,13 +152,25 @@ class AggregateFunctionForEach final
bool hasTrivialDestructor() const override { return nested_func->hasTrivialDestructor(); }

void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
addOrDecrease<true>(place, columns, row_num, arena);
}

void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena)
const override
{
addOrDecrease<false>(place, columns, row_num, arena);
}

template <bool is_add>
void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const
{
const IColumn * nested[num_arguments];

for (size_t i = 0; i < num_arguments; ++i)
nested[i] = &static_cast<const ColumnArray &>(*columns[i]).getData();

const ColumnArray & first_array_column = static_cast<const ColumnArray &>(*columns[0]);
const auto & first_array_column = static_cast<const ColumnArray &>(*columns[0]);
const IColumn::Offsets & offsets = first_array_column.getOffsets();

size_t begin = row_num == 0 ? 0 : offsets[row_num - 1];
Expand All @@ -164,7 +179,7 @@ class AggregateFunctionForEach final
/// Sanity check. NOTE We can implement specialization for a case with single argument, if the check will hurt performance.
for (size_t i = 1; i < num_arguments; ++i)
{
const ColumnArray & ith_column = static_cast<const ColumnArray &>(*columns[i]);
const auto & ith_column = static_cast<const ColumnArray &>(*columns[i]);
const IColumn::Offsets & ith_offsets = ith_column.getOffsets();

if (ith_offsets[row_num] != end || (row_num != 0 && ith_offsets[row_num - 1] != begin))
Expand All @@ -173,20 +188,45 @@ class AggregateFunctionForEach final
ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
}

AggregateFunctionForEachData & state = ensureAggregateData(place, end - begin, *arena);
AggregateFunctionForEachData & state = ensureAggregateData(place, end - begin, arena);

char * nested_state = state.array_of_aggregate_datas;
for (size_t i = begin; i < end; ++i)
{
nested_func->add(nested_state, nested, i, arena);
if constexpr (is_add)
nested_func->add(nested_state, nested, i, arena);
else
nested_func->decrease(nested_state, nested, i, arena);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not just call nested_func->addOrDecrease<is_add>()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not just call nested_func->addOrDecrease<is_add>()?

Because addOrDecrease<>() is not an interface in IAggregateFunction.

nested_state += nested_size_of_data;
}
}

void reset(AggregateDataPtr __restrict place) const override
{
AggregateFunctionForEachData & state = ensureAggregateData(place, 0, nullptr);
char * nested_state = state.array_of_aggregate_datas;
for (size_t i = 0; i < state.dynamic_array_size; i++)
{
nested_func->reset(nested_state);
nested_state += nested_size_of_data;
}
}

void prepareWindow(AggregateDataPtr __restrict place) const override
{
AggregateFunctionForEachData & state = ensureAggregateData(place, 0, nullptr);
char * nested_state = state.array_of_aggregate_datas;
for (size_t i = 0; i < state.dynamic_array_size; i++)
{
nested_func->prepareWindow(nested_state);
nested_state += nested_size_of_data;
}
}

void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
const AggregateFunctionForEachData & rhs_state = data(rhs);
AggregateFunctionForEachData & state = ensureAggregateData(place, rhs_state.dynamic_array_size, *arena);
AggregateFunctionForEachData & state = ensureAggregateData(place, rhs_state.dynamic_array_size, arena);

const char * rhs_nested_state = rhs_state.array_of_aggregate_datas;
char * nested_state = state.array_of_aggregate_datas;
Expand Down Expand Up @@ -220,7 +260,7 @@ class AggregateFunctionForEach final
size_t new_size = 0;
readBinary(new_size, buf);

ensureAggregateData(place, new_size, *arena);
ensureAggregateData(place, new_size, arena);

char * nested_state = state.array_of_aggregate_datas;
for (size_t i = 0; i < new_size; ++i)
Expand All @@ -234,7 +274,7 @@ class AggregateFunctionForEach final
{
const AggregateFunctionForEachData & state = data(place);

ColumnArray & arr_to = static_cast<ColumnArray &>(to);
auto & arr_to = static_cast<ColumnArray &>(to);
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
IColumn & elems_to = arr_to.getData();

Expand Down
32 changes: 25 additions & 7 deletions dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,32 +116,36 @@ class AggregateFunctionGroupConcat final

DataTypePtr getReturnType() const override { return result_is_nullable ? makeNullable(ret_type) : ret_type; }

/// reject nulls before add() of nested agg
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
template <bool is_add>
void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const
{
if constexpr (only_one_column)
{
if (is_nullable[0])
{
const ColumnNullable * column = static_cast<const ColumnNullable *>(columns[0]);
const auto * column = static_cast<const ColumnNullable *>(columns[0]);
if (!column->isNullAt(row_num))
{
this->setFlag(place);
const IColumn * nested_column = &column->getNestedColumn();
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);

if constexpr (is_add)
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
else
this->nested_function->decrease(this->nestedPlace(place), &nested_column, row_num, arena);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

Because addOrDecrease<>() is not an interface in IAggregateFunction.

}
return;
}
}
else
{
/// remove the row with null, except for sort columns
const ColumnTuple & tuple = static_cast<const ColumnTuple &>(*columns[0]);
const auto & tuple = static_cast<const ColumnTuple &>(*columns[0]);
for (size_t i = 0; i < number_of_concat_items; ++i)
{
if (is_nullable[i])
{
const ColumnNullable & nullable_col = static_cast<const ColumnNullable &>(tuple.getColumn(i));
const auto & nullable_col = static_cast<const ColumnNullable &>(tuple.getColumn(i));
if (nullable_col.isNullAt(row_num))
{
/// If at least one column has a null value in the current row,
Expand All @@ -152,7 +156,21 @@ class AggregateFunctionGroupConcat final
}
}
this->setFlag(place);
this->nested_function->add(this->nestedPlace(place), columns, row_num, arena);
if constexpr (is_add)
this->nested_function->add(this->nestedPlace(place), columns, row_num, arena);
else
this->nested_function->decrease(this->nestedPlace(place), columns, row_num, arena);
}

void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
addOrDecrease<true>(place, columns, row_num, arena);
}

void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena)
const override
{
addOrDecrease<false>(place, columns, row_num, arena);
}

void insertResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
Expand Down
11 changes: 11 additions & 0 deletions dbms/src/AggregateFunctions/AggregateFunctionIf.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,17 @@ class AggregateFunctionIf final : public IAggregateFunctionHelper<AggregateFunct
nested_func->add(place, columns, row_num, arena);
}

void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena)
const override
{
if (static_cast<const ColumnUInt8 &>(*columns[num_arguments - 1]).getData()[row_num])
nested_func->decrease(place, columns, row_num, arena);
}

void reset(AggregateDataPtr __restrict place) const override { nested_func->reset(place); }

void prepareWindow(AggregateDataPtr __restrict place) const override { nested_func->prepareWindow(place); }

void addBatch(
size_t start_offset,
size_t batch_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ class AggregateFunctionIntersectionsMax final
PointType right = static_cast<const ColumnVector<PointType> &>(*columns[1]).getData()[row_num];

if (!isNaN(left))
this->data(place).value.push_back(std::make_pair(left, Int64(1)), arena);
this->data(place).value.push_back(std::make_pair(left, static_cast<Int64>(1)), arena);

if (!isNaN(right))
this->data(place).value.push_back(std::make_pair(right, Int64(-1)), arena);
this->data(place).value.push_back(std::make_pair(right, static_cast<Int64>(-1)), arena);
}

void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
Expand Down
Loading