diff --git a/dbms/src/AggregateFunctions/AggregateFunctionArray.h b/dbms/src/AggregateFunctions/AggregateFunctionArray.h index eacd499f86d..144bbfdc378 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionArray.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionArray.h @@ -67,13 +67,25 @@ class AggregateFunctionArray final : public IAggregateFunctionHelperalignOfData(); } void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + addOrDecrease(place, columns, row_num, arena); + } + + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const override + { + addOrDecrease(place, columns, row_num, arena); + } + + template + void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const { const IColumn * nested[num_arguments]; for (size_t i = 0; i < num_arguments; ++i) nested[i] = &static_cast(*columns[i]).getData(); - const ColumnArray & first_array_column = static_cast(*columns[0]); + const auto & first_array_column = static_cast(*columns[0]); const IColumn::Offsets & offsets = first_array_column.getOffsets(); size_t begin = row_num == 0 ? 0 : offsets[row_num - 1]; @@ -82,7 +94,7 @@ class AggregateFunctionArray final : public IAggregateFunctionHelper(*columns[i]); + const auto & ith_column = static_cast(*columns[i]); const IColumn::Offsets & ith_offsets = ith_column.getOffsets(); if (ith_offsets[row_num] != end || (row_num != 0 && ith_offsets[row_num - 1] != begin)) @@ -92,9 +104,16 @@ class AggregateFunctionArray final : public IAggregateFunctionHelperadd(place, nested, i, arena); + if constexpr (is_add) + nested_func->add(place, nested, i, arena); + else + nested_func->decrease(place, nested, i, arena); } + void reset(AggregateDataPtr __restrict place) const override { nested_func->reset(place); } + + void prepareWindow(AggregateDataPtr __restrict place) const override { nested_func->prepareWindow(place); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { nested_func->merge(place, rhs, arena); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionAvg.h b/dbms/src/AggregateFunctions/AggregateFunctionAvg.h index 1879c3bca4a..b528f9f10f1 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionAvg.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionAvg.h @@ -29,6 +29,12 @@ struct AggregateFunctionAvgData T sum; UInt64 count; + void reset() + { + sum = T(0); + count = 0; + } + AggregateFunctionAvgData() : sum(0) , count(0) @@ -67,6 +73,8 @@ class AggregateFunctionAvg final return std::make_shared(); } + void prepareWindow(AggregateDataPtr __restrict) const override {} + void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override { if constexpr (IsDecimal) @@ -78,6 +86,19 @@ class AggregateFunctionAvg final ++this->data(place).count; } + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override + { + if constexpr (IsDecimal) + this->data(place).sum -= static_cast &>(*columns[0]).getData()[row_num]; + else + this->data(place).sum -= static_cast &>(*columns[0]).getData()[row_num]; + + --this->data(place).count; + assert(this->data(place).count >= 0); + } + + void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).sum += this->data(rhs).sum; diff --git a/dbms/src/AggregateFunctions/AggregateFunctionCount.h b/dbms/src/AggregateFunctions/AggregateFunctionCount.h index c8cdc4d4b27..c7ab49cb76e 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionCount.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionCount.h @@ -29,6 +29,8 @@ namespace DB struct AggregateFunctionCountData { UInt64 count = 0; + + inline void reset() noexcept { count = 0; } }; namespace ErrorCodes @@ -52,6 +54,15 @@ class AggregateFunctionCount final ++data(place).count; } + void decrease(AggregateDataPtr __restrict place, const IColumn **, size_t, Arena *) const override + { + --data(place).count; + } + + void reset(AggregateDataPtr __restrict place) const override { data(place).reset(); } + + void prepareWindow(AggregateDataPtr __restrict) const override {} + void addBatchSinglePlace( size_t start_offset, size_t batch_size, @@ -173,6 +184,15 @@ class AggregateFunctionCountNotNullUnary final data(place).count += !static_cast(*columns[0]).isNullAt(row_num); } + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override + { + data(place).count -= !static_cast(*columns[0]).isNullAt(row_num); + } + + void reset(AggregateDataPtr __restrict place) const override { data(place).reset(); } + + void prepareWindow(AggregateDataPtr __restrict) const override {} + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { data(place).count += data(rhs).count; @@ -234,6 +254,19 @@ class AggregateFunctionCountNotNullVariadic final ++data(place).count; } + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override + { + for (size_t i = 0; i < number_of_arguments; ++i) + if (is_nullable[i] && static_cast(*columns[i]).isNullAt(row_num)) + return; + + --data(place).count; + } + + void reset(AggregateDataPtr __restrict place) const override { data(place).reset(); } + + void prepareWindow(AggregateDataPtr __restrict) const override {} + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { data(place).count += data(rhs).count; diff --git a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp index a62db765523..2c8624d8e90 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp @@ -107,6 +107,26 @@ AggregateFunctionPtr AggregateFunctionFactory::get( return res; } +AggregateFunctionPtr AggregateFunctionFactory::getForWindow( + const Context & context, + const String & name, + const DataTypes & argument_types, + const Array & parameters, + int recursion_level) const +{ + AggregateFunctionCombinatorPtr combinator + = AggregateFunctionCombinatorFactory::instance().tryFindSuffix("NullForWindow"); + if (!combinator) + throw Exception( + "Logical error: cannot find aggregate function combinator to apply a function to Nullable for window " + "arguments.", + ErrorCodes::LOGICAL_ERROR); + + DataTypes nested_types = combinator->transformArguments(argument_types); + AggregateFunctionPtr nested_function = getImpl(context, name, nested_types, parameters, recursion_level); + return combinator->transformAggregateFunction(nested_function, argument_types, parameters); +} + AggregateFunctionPtr AggregateFunctionFactory::getImpl( const Context & context, const String & name, diff --git a/dbms/src/AggregateFunctions/AggregateFunctionFactory.h b/dbms/src/AggregateFunctions/AggregateFunctionFactory.h index 18d48296dee..756decb0692 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionFactory.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionFactory.h @@ -67,6 +67,14 @@ class AggregateFunctionFactory final : public ext::Singletonrealloc( state.array_of_aggregate_datas, old_size * nested_size_of_data, new_size * nested_size_of_data); @@ -149,13 +152,25 @@ class AggregateFunctionForEach final bool hasTrivialDestructor() const override { return nested_func->hasTrivialDestructor(); } void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + addOrDecrease(place, columns, row_num, arena); + } + + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const override + { + addOrDecrease(place, columns, row_num, arena); + } + + template + void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const { const IColumn * nested[num_arguments]; for (size_t i = 0; i < num_arguments; ++i) nested[i] = &static_cast(*columns[i]).getData(); - const ColumnArray & first_array_column = static_cast(*columns[0]); + const auto & first_array_column = static_cast(*columns[0]); const IColumn::Offsets & offsets = first_array_column.getOffsets(); size_t begin = row_num == 0 ? 0 : offsets[row_num - 1]; @@ -164,7 +179,7 @@ class AggregateFunctionForEach final /// Sanity check. NOTE We can implement specialization for a case with single argument, if the check will hurt performance. for (size_t i = 1; i < num_arguments; ++i) { - const ColumnArray & ith_column = static_cast(*columns[i]); + const auto & ith_column = static_cast(*columns[i]); const IColumn::Offsets & ith_offsets = ith_column.getOffsets(); if (ith_offsets[row_num] != end || (row_num != 0 && ith_offsets[row_num - 1] != begin)) @@ -173,12 +188,37 @@ class AggregateFunctionForEach final ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH); } - AggregateFunctionForEachData & state = ensureAggregateData(place, end - begin, *arena); + AggregateFunctionForEachData & state = ensureAggregateData(place, end - begin, arena); char * nested_state = state.array_of_aggregate_datas; for (size_t i = begin; i < end; ++i) { - nested_func->add(nested_state, nested, i, arena); + if constexpr (is_add) + nested_func->add(nested_state, nested, i, arena); + else + nested_func->decrease(nested_state, nested, i, arena); + nested_state += nested_size_of_data; + } + } + + void reset(AggregateDataPtr __restrict place) const override + { + AggregateFunctionForEachData & state = ensureAggregateData(place, 0, nullptr); + char * nested_state = state.array_of_aggregate_datas; + for (size_t i = 0; i < state.dynamic_array_size; i++) + { + nested_func->reset(nested_state); + nested_state += nested_size_of_data; + } + } + + void prepareWindow(AggregateDataPtr __restrict place) const override + { + AggregateFunctionForEachData & state = ensureAggregateData(place, 0, nullptr); + char * nested_state = state.array_of_aggregate_datas; + for (size_t i = 0; i < state.dynamic_array_size; i++) + { + nested_func->prepareWindow(nested_state); nested_state += nested_size_of_data; } } @@ -186,7 +226,7 @@ class AggregateFunctionForEach final void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { const AggregateFunctionForEachData & rhs_state = data(rhs); - AggregateFunctionForEachData & state = ensureAggregateData(place, rhs_state.dynamic_array_size, *arena); + AggregateFunctionForEachData & state = ensureAggregateData(place, rhs_state.dynamic_array_size, arena); const char * rhs_nested_state = rhs_state.array_of_aggregate_datas; char * nested_state = state.array_of_aggregate_datas; @@ -220,7 +260,7 @@ class AggregateFunctionForEach final size_t new_size = 0; readBinary(new_size, buf); - ensureAggregateData(place, new_size, *arena); + ensureAggregateData(place, new_size, arena); char * nested_state = state.array_of_aggregate_datas; for (size_t i = 0; i < new_size; ++i) @@ -234,7 +274,7 @@ class AggregateFunctionForEach final { const AggregateFunctionForEachData & state = data(place); - ColumnArray & arr_to = static_cast(to); + auto & arr_to = static_cast(to); ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); IColumn & elems_to = arr_to.getData(); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h index baee40e62d4..29a32b58cc5 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h @@ -116,19 +116,23 @@ class AggregateFunctionGroupConcat final DataTypePtr getReturnType() const override { return result_is_nullable ? makeNullable(ret_type) : ret_type; } - /// reject nulls before add() of nested agg - void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + template + void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const { if constexpr (only_one_column) { if (is_nullable[0]) { - const ColumnNullable * column = static_cast(columns[0]); + const auto * column = static_cast(columns[0]); if (!column->isNullAt(row_num)) { this->setFlag(place); const IColumn * nested_column = &column->getNestedColumn(); - this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); + + if constexpr (is_add) + this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); + else + this->nested_function->decrease(this->nestedPlace(place), &nested_column, row_num, arena); } return; } @@ -136,12 +140,12 @@ class AggregateFunctionGroupConcat final else { /// remove the row with null, except for sort columns - const ColumnTuple & tuple = static_cast(*columns[0]); + const auto & tuple = static_cast(*columns[0]); for (size_t i = 0; i < number_of_concat_items; ++i) { if (is_nullable[i]) { - const ColumnNullable & nullable_col = static_cast(tuple.getColumn(i)); + const auto & nullable_col = static_cast(tuple.getColumn(i)); if (nullable_col.isNullAt(row_num)) { /// If at least one column has a null value in the current row, @@ -152,7 +156,21 @@ class AggregateFunctionGroupConcat final } } this->setFlag(place); - this->nested_function->add(this->nestedPlace(place), columns, row_num, arena); + if constexpr (is_add) + this->nested_function->add(this->nestedPlace(place), columns, row_num, arena); + else + this->nested_function->decrease(this->nestedPlace(place), columns, row_num, arena); + } + + void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + addOrDecrease(place, columns, row_num, arena); + } + + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const override + { + addOrDecrease(place, columns, row_num, arena); } void insertResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override diff --git a/dbms/src/AggregateFunctions/AggregateFunctionIf.h b/dbms/src/AggregateFunctions/AggregateFunctionIf.h index fbd5bd242d8..7d3dec9b6c2 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionIf.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionIf.h @@ -76,6 +76,17 @@ class AggregateFunctionIf final : public IAggregateFunctionHelperadd(place, columns, row_num, arena); } + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const override + { + if (static_cast(*columns[num_arguments - 1]).getData()[row_num]) + nested_func->decrease(place, columns, row_num, arena); + } + + void reset(AggregateDataPtr __restrict place) const override { nested_func->reset(place); } + + void prepareWindow(AggregateDataPtr __restrict place) const override { nested_func->prepareWindow(place); } + void addBatch( size_t start_offset, size_t batch_size, diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h b/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h index ad029472f2c..5455d6674e3 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h @@ -103,10 +103,10 @@ class AggregateFunctionIntersectionsMax final PointType right = static_cast &>(*columns[1]).getData()[row_num]; if (!isNaN(left)) - this->data(place).value.push_back(std::make_pair(left, Int64(1)), arena); + this->data(place).value.push_back(std::make_pair(left, static_cast(1)), arena); if (!isNaN(right)) - this->data(place).value.push_back(std::make_pair(right, Int64(-1)), arena); + this->data(place).value.push_back(std::make_pair(right, static_cast(-1)), arena); } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMerge.h b/dbms/src/AggregateFunctions/AggregateFunctionMerge.h index 710443509dc..224ed534aff 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMerge.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMerge.h @@ -37,7 +37,7 @@ class AggregateFunctionMerge final : public IAggregateFunctionHelper(&argument); + const auto * data_type = typeid_cast(&argument); if (!data_type || data_type->getFunctionName() != nested_func->getName()) throw Exception( diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h index 2f41c931b91..73cb4ec7d84 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h @@ -24,19 +24,24 @@ #include #include - namespace DB { /** Aggregate functions that store one of passed values. * For example: min, max, any, anyLast. */ +struct CommonImpl +{ + static void decrease() { throw Exception("Not implemented yet"); } + static void reset() { throw Exception("Not implemented yet"); } + static void prepareWindow() { throw Exception("Not implemented yet"); } +}; /// For numeric values. template -struct SingleValueDataFixed +struct SingleValueDataFixed : public CommonImpl { -private: +protected: using Self = SingleValueDataFixed; bool has_value @@ -72,7 +77,6 @@ struct SingleValueDataFixed readBinary(value, buf); } - void change(const IColumn & column, size_t row_num, Arena *) { has_value = true; @@ -181,9 +185,9 @@ struct SingleValueDataFixed /** For strings. Short strings are stored in the object itself, and long strings are allocated separately. * NOTE It could also be suitable for arrays of numbers. */ -struct SingleValueDataString +struct SingleValueDataString : public CommonImpl { -private: +protected: using Self = SingleValueDataString; Int32 size = -1; /// -1 indicates that there is no value. @@ -215,9 +219,9 @@ struct SingleValueDataString public: static constexpr Int32 AUTOMATIC_STORAGE_SIZE = 64; static constexpr Int32 MAX_SMALL_STRING_SIZE - = AUTOMATIC_STORAGE_SIZE - sizeof(size) - sizeof(capacity) - sizeof(large_data) - sizeof(collator); + = AUTOMATIC_STORAGE_SIZE - sizeof(size) - sizeof(capacity) - sizeof(large_data) - sizeof(TiDB::TiDBCollatorPtr); -private: +protected: char small_data[MAX_SMALL_STRING_SIZE]{}; /// Including the terminating zero. public: @@ -425,9 +429,9 @@ static_assert( /// For any other value types. -struct SingleValueDataGeneric +struct SingleValueDataGeneric : public CommonImpl { -private: +protected: using Self = SingleValueDataGeneric; Field value; @@ -519,6 +523,7 @@ struct SingleValueDataGeneric { Field new_value; column.get(row_num, new_value); + if (new_value < value) { value = new_value; @@ -551,6 +556,7 @@ struct SingleValueDataGeneric { Field new_value; column.get(row_num, new_value); + if (new_value > value) { value = new_value; @@ -594,6 +600,10 @@ struct AggregateFunctionMinData : Data } bool changeIfBetter(const Self & to, Arena * arena) { return this->changeIfLess(to, arena); } + void prepareWindow() { throw Exception("Not implemented yet"); } + + void insertResultInto(IColumn & to) const { Data::insertResultInto(to); } + static const char * name() { return "min"; } }; @@ -602,10 +612,13 @@ struct AggregateFunctionMaxData : Data { using Self = AggregateFunctionMaxData; + void insertResultInto(IColumn & to) const { Data::insertResultInto(to); } + bool changeIfBetter(const IColumn & column, size_t row_num, Arena * arena) { return this->changeIfGreater(column, row_num, arena); } + bool changeIfBetter(const Self & to, Arena * arena) { return this->changeIfGreater(to, arena); } static const char * name() { return "max"; } @@ -732,7 +745,9 @@ class AggregateFunctionsSingleValue final explicit AggregateFunctionsSingleValue(const DataTypePtr & type) : type(type) { - if (StringRef(Data::name()) == StringRef("min") || StringRef(Data::name()) == StringRef("max")) + if (StringRef(Data::name()) == StringRef("min") || StringRef(Data::name()) == StringRef("max") + || StringRef(Data::name()) == StringRef("max_for_window") + || StringRef(Data::name()) == StringRef("min_for_window")) { if (!type->isComparable()) throw Exception( @@ -751,6 +766,15 @@ class AggregateFunctionsSingleValue final this->data(place).changeIfBetter(*columns[0], row_num, arena); } + void decrease(AggregateDataPtr __restrict place, const IColumn **, size_t, Arena *) const override + { + this->data(place).decrease(); + } + + void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } + + void prepareWindow(AggregateDataPtr __restrict place) const override { this->data(place).prepareWindow(); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { this->data(place).changeIfBetter(this->data(rhs), arena); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxWindow.cpp b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxWindow.cpp new file mode 100644 index 00000000000..d43d8afc38f --- /dev/null +++ b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxWindow.cpp @@ -0,0 +1,62 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +namespace DB +{ +namespace +{ +AggregateFunctionPtr createAggregateFunctionMinForWindow( + const Context & /* context not used */, + const std::string & name, + const DataTypes & argument_types, + const Array & parameters) +{ + return AggregateFunctionPtr( + createAggregateFunctionSingleValueForWindow( + name, + argument_types, + parameters)); +} + +AggregateFunctionPtr createAggregateFunctionMaxForWindow( + const Context & /* context not used */, + const std::string & name, + const DataTypes & argument_types, + const Array & parameters) +{ + return AggregateFunctionPtr( + createAggregateFunctionSingleValueForWindow( + name, + argument_types, + parameters)); +} +} // namespace + +void registerAggregateFunctionsMinMaxForWindow(AggregateFunctionFactory & factory) +{ + factory.registerFunction( + "min_for_window", + createAggregateFunctionMinForWindow, + AggregateFunctionFactory::CaseInsensitive); + factory.registerFunction( + "max_for_window", + createAggregateFunctionMaxForWindow, + AggregateFunctionFactory::CaseInsensitive); +} +} // namespace DB diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxWindow.h b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxWindow.h new file mode 100644 index 00000000000..831d9874135 --- /dev/null +++ b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxWindow.h @@ -0,0 +1,393 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +template +struct SingleValueDataFixedForWindow : public SingleValueDataFixed +{ +private: + using Self = SingleValueDataFixedForWindow; + using ColumnType = std::conditional_t, ColumnDecimal, ColumnVector>; + + mutable std::deque * saved_values; + +public: + SingleValueDataFixedForWindow() + : saved_values(nullptr) + {} + + ~SingleValueDataFixedForWindow() { delete saved_values; } + + void insertMaxResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } + + void insertMinResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } + + template + void insertMinOrMaxResultInto(IColumn & to) const + { + if (this->has()) + { + auto size = saved_values->size(); + T tmp = (*saved_values)[0]; + for (size_t i = 1; i < size; i++) + { + if constexpr (is_min) + { + if ((*saved_values)[i] < tmp) + tmp = (*saved_values)[i]; + } + else + { + if (tmp < (*saved_values)[i]) + tmp = (*saved_values)[i]; + } + } + static_cast(to).getData().push_back(tmp); + } + else + { + static_cast(to).insertDefault(); + } + } + + void prepareWindow() { saved_values = new std::deque(); } + + void reset() + { + this->has_value = false; + saved_values->clear(); + } + + void decrease() + { + saved_values->pop_front(); + if unlikely (saved_values->empty()) + this->has_value = false; + } + + bool changeIfLess(const IColumn & column, size_t row_num, Arena * arena) + { + auto to_value = static_cast(column).getData()[row_num]; + if (saved_values != nullptr) + saved_values->push_back(to_value); + + return SingleValueDataFixed::changeIfLess(column, row_num, arena); + } + + bool changeIfLess(const Self & to, Arena * arena) + { + if (saved_values != nullptr) + saved_values->push_back(to.value); + + return SingleValueDataFixed::changeIfLess(to, arena); + } + + bool changeIfGreater(const IColumn & column, size_t row_num, Arena * arena) + { + auto to_value = static_cast(column).getData()[row_num]; + if (saved_values != nullptr) + saved_values->push_back(to_value); + + return SingleValueDataFixed::changeIfGreater(column, row_num, arena); + } + + bool changeIfGreater(const Self & to, Arena * arena) + { + if (saved_values != nullptr) + saved_values->push_back(to.value); + + return SingleValueDataFixed::changeIfGreater(to, arena); + } +}; + +struct SingleValueDataStringForWindow : public SingleValueDataString +{ +private: + using Self = SingleValueDataStringForWindow; + + // TODO use std::string is inefficient + mutable std::deque * saved_values{}; + +public: + SingleValueDataStringForWindow() + : saved_values(nullptr) + {} + ~SingleValueDataStringForWindow() { delete saved_values; } + + void insertMaxResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } + + void insertMinResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } + + template + void insertMinOrMaxResultInto(IColumn & to) const + { + if (has()) + { + auto elem_num = saved_values->size(); + StringRef value((*saved_values)[0].c_str(), (*saved_values)[0].size()); + for (size_t i = 1; i < elem_num; i++) + { + String cmp_value((*saved_values)[i].c_str(), (*saved_values)[i].size()); + if constexpr (is_min) + { + if (less(cmp_value, value)) + value = (*saved_values)[i]; + } + else + { + if (less(value, cmp_value)) + value = (*saved_values)[i]; + } + } + + static_cast(to).insertDataWithTerminatingZero(value.data, value.size); + } + else + { + static_cast(to).insertDefault(); + } + } + + void prepareWindow() { saved_values = new std::deque(); } + + void reset() + { + size = -1; + saved_values->clear(); + } + + void decrease() + { + saved_values->pop_front(); + if unlikely (saved_values->empty()) + size = -1; + } + + void saveValue(StringRef value) { saved_values->push_back(value.toString()); } + + bool changeIfLess(const IColumn & column, size_t row_num, Arena * arena) + { + if (saved_values != nullptr) + saveValue(static_cast(column).getDataAtWithTerminatingZero(row_num)); + + return SingleValueDataString::changeIfLess(column, row_num, arena); + } + + bool changeIfLess(const Self & to, Arena * arena) + { + if (saved_values != nullptr) + saveValue(to.getStringRef()); + + return SingleValueDataString::changeIfLess(to, arena); + } + + bool changeIfGreater(const IColumn & column, size_t row_num, Arena * arena) + { + if (saved_values != nullptr) + saveValue(static_cast(column).getDataAtWithTerminatingZero(row_num)); + + + return SingleValueDataString::changeIfGreater(column, row_num, arena); + } + + bool changeIfGreater(const Self & to, Arena * arena) + { + if (saved_values != nullptr) + saveValue(to.getStringRef()); + + return SingleValueDataString::changeIfGreater(to, arena); + } +}; + +struct SingleValueDataGenericForWindow : public SingleValueDataGeneric +{ +private: + using Self = SingleValueDataGenericForWindow; + mutable std::deque * saved_values; + +public: + SingleValueDataGenericForWindow() + : saved_values(nullptr) + {} + ~SingleValueDataGenericForWindow() { delete saved_values; } + + void insertMaxResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } + + void insertMinResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } + + template + void insertMinOrMaxResultInto(IColumn & to) const + { + if (has()) + { + auto size = saved_values->size(); + Field tmp = (*saved_values)[0]; + for (size_t i = 1; i < size; i++) + { + if constexpr (is_min) + { + if ((*saved_values)[i] < tmp) + tmp = (*saved_values)[i]; + } + else + { + if (tmp < (*saved_values)[i]) + tmp = (*saved_values)[i]; + } + } + to.insert(tmp); + } + else + { + to.insertDefault(); + } + } + + void prepareWindow() { saved_values = new std::deque(); } + + void reset() + { + value = Field(); + saved_values->clear(); + } + + // Only used for window aggregation + void decrease() + { + saved_values->pop_front(); + if unlikely (saved_values->empty()) + value = Field(); + } + + bool changeIfLess(const IColumn & column, size_t row_num, Arena * arena) + { + if (!has()) + { + change(column, row_num, arena); + + if (saved_values != nullptr) + saved_values->push_back(value); + return true; + } + else + { + Field new_value; + column.get(row_num, new_value); + + if (saved_values != nullptr) + saved_values->push_back(new_value); + + if (new_value < value) + { + value = new_value; + return true; + } + else + return false; + } + } + + bool changeIfLess(const Self & to, Arena * arena) + { + if (saved_values != nullptr) + saved_values->push_back(to.value); + + return SingleValueDataGeneric::changeIfLess(to, arena); + } + + bool changeIfGreater(const IColumn & column, size_t row_num, Arena * arena) + { + if (!has()) + { + change(column, row_num, arena); + + if (saved_values != nullptr) + saved_values->push_back(value); + return true; + } + else + { + Field new_value; + column.get(row_num, new_value); + + if (saved_values != nullptr) + saved_values->push_back(new_value); + + if (new_value > value) + { + value = new_value; + return true; + } + else + return false; + } + } + + bool changeIfGreater(const Self & to, Arena * arena) + { + if (saved_values != nullptr) + saved_values->push_back(to.value); + + return SingleValueDataGeneric::changeIfGreater(to, arena); + } +}; + +template +struct AggregateFunctionMinDataForWindow : Data +{ + using Self = AggregateFunctionMinDataForWindow; + + bool changeIfBetter(const IColumn & column, size_t row_num, Arena * arena) + { + return this->changeIfLess(column, row_num, arena); + } + bool changeIfBetter(const Self & to, Arena * arena) { return this->changeIfLess(to, arena); } + + void insertResultInto(IColumn & to) const { Data::insertMinResultInto(to); } + + static const char * name() { return "min_for_window"; } +}; + +template +struct AggregateFunctionMaxDataForWindow : Data +{ + using Self = AggregateFunctionMaxDataForWindow; + + void insertResultInto(IColumn & to) const { Data::insertMaxResultInto(to); } + + bool changeIfBetter(const IColumn & column, size_t row_num, Arena * arena) + { + return this->changeIfGreater(column, row_num, arena); + } + + bool changeIfBetter(const Self & to, Arena * arena) { return this->changeIfGreater(to, arena); } + + static const char * name() { return "max_for_window"; } +}; + +} // namespace DB diff --git a/dbms/src/AggregateFunctions/AggregateFunctionNothing.h b/dbms/src/AggregateFunctions/AggregateFunctionNothing.h index 003a52b8f47..d4d30db7064 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionNothing.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionNothing.h @@ -46,6 +46,8 @@ class AggregateFunctionNothing final : public IAggregateFunctionHelper #include #include +#include #include @@ -33,7 +34,7 @@ extern const std::unordered_set hacking_return_non_null_agg_func_names; class AggregateFunctionCombinatorNull final : public IAggregateFunctionCombinator { public: - String getName() const override { return "Null"; }; + String getName() const override { return "Null"; } DataTypes transformArguments(const DataTypes & arguments) const override { @@ -141,9 +142,73 @@ class AggregateFunctionCombinatorNull final : public IAggregateFunctionCombinato } }; +class AggregateFunctionCombinatorNullForWindow final : public IAggregateFunctionCombinator +{ +public: + String getName() const override { return "NullForWindow"; } + + DataTypes transformArguments(const DataTypes & arguments) const override + { + size_t size = arguments.size(); + if (size != 1) + throw Exception(fmt::format("Aggregation in window accepts exact 1 argument, but gets {} arguments", size)); + DataTypes res(size); + res[0] = removeNullable(arguments[0]); + return res; + } + + AggregateFunctionPtr transformAggregateFunction( + const AggregateFunctionPtr & nested_function, + const DataTypes & arguments, + const Array &) const override + { + bool has_nullable_types = false; + bool has_null_types = false; + for (const auto & arg_type : arguments) + { + if (arg_type->isNullable()) + { + has_nullable_types = true; + if (arg_type->onlyNull()) + { + has_null_types = true; + break; + } + } + } + + if (nested_function && nested_function->getName() == "count") + { + if (has_nullable_types) + return std::make_shared(arguments[0]); + else + return std::make_shared(); + } + + if (has_null_types) + return std::make_shared(); + + if (nested_function->getReturnType()->canBeInsideNullable()) + { + if (has_nullable_types) + return std::make_shared>(nested_function); + else + return std::make_shared>(nested_function); + } + else + { + if (has_nullable_types) + return std::make_shared>(nested_function); + else + return std::make_shared>(nested_function); + } + } +}; + void registerAggregateFunctionCombinatorNull(AggregateFunctionCombinatorFactory & factory) { factory.registerCombinator(std::make_shared()); + factory.registerCombinator(std::make_shared()); } } // namespace DB diff --git a/dbms/src/AggregateFunctions/AggregateFunctionNull.h b/dbms/src/AggregateFunctions/AggregateFunctionNull.h index 7dbe7db7807..9ea4882434a 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionNull.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionNull.h @@ -33,7 +33,6 @@ #include #include -#include namespace DB { @@ -91,13 +90,23 @@ class AggregateFunctionNullBase : public IAggregateFunctionHelper } public: - explicit AggregateFunctionNullBase(AggregateFunctionPtr nested_function_) + explicit AggregateFunctionNullBase(AggregateFunctionPtr nested_function_, bool need_counter = false) : nested_function{nested_function_} { if (result_is_nullable) prefix_size = nested_function->alignOfData(); else prefix_size = 0; + + if (result_is_nullable && need_counter) + { + auto tmp = prefix_size; + prefix_size += sizeof(Int64); + if ((prefix_size % tmp) != 0) + prefix_size += tmp - (prefix_size % tmp); + assert((prefix_size % tmp) == 0); + assert(prefix_size >= (tmp + sizeof(Int64))); + } } String getName() const override @@ -505,4 +514,107 @@ class AggregateFunctionNullVariadic final std::array is_nullable; /// Plain array is better than std::vector due to one indirection less. }; +template +class AggregateFunctionNullUnaryForWindow final + : public AggregateFunctionNullBase< + result_is_nullable, + AggregateFunctionNullUnaryForWindow> +{ +public: + explicit AggregateFunctionNullUnaryForWindow(AggregateFunctionPtr nested_function) + : AggregateFunctionNullBase< + result_is_nullable, + AggregateFunctionNullUnaryForWindow>(nested_function, true) + {} + + void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + addOrDecrease(place, columns, row_num, arena); + } + + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const override + { + addOrDecrease(place, columns, row_num, arena); + } + + template + void addOrDecreaseImpl(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const + { + if constexpr (is_add) + { + this->addCounter(place); + this->setFlag(place); + this->nested_function->add(this->nestedPlace(place), columns, row_num, arena); + } + else + { + this->decreaseCounter(place); + if (this->getCounter(place) == 0) + this->resetFlag(place); + this->nested_function->decrease(this->nestedPlace(place), columns, row_num, arena); + } + } + + template + void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const + { + if constexpr (input_is_nullable) + { + const auto * column = static_cast(columns[0]); + if (!column->isNullAt(row_num)) + { + const IColumn * nested_column = &column->getNestedColumn(); + addOrDecreaseImpl(place, &nested_column, row_num, arena); + } + } + else + { + addOrDecreaseImpl(place, columns, row_num, arena); + } + } + + void reset(AggregateDataPtr __restrict place) const override + { + this->resetFlag(place); + this->resetCounter(place); + this->nested_function->reset(this->nestedPlace(place)); + } + + void prepareWindow(AggregateDataPtr __restrict place) const override + { + this->nested_function->prepareWindow(this->nestedPlace(place)); + reset(place); + } + +private: + inline void resetFlag(AggregateDataPtr __restrict place) const noexcept { this->initFlag(place); } + + inline void addCounter(AggregateDataPtr __restrict place) const noexcept + { + auto * counter = reinterpret_cast(place + 1); + ++(*counter); + } + + inline void decreaseCounter(AggregateDataPtr __restrict place) const noexcept + { + auto * counter = reinterpret_cast(place + 1); + --(*counter); + assert((*counter) >= 0); + } + + inline Int64 getCounter(AggregateDataPtr __restrict place) const noexcept + { + auto * counter = reinterpret_cast(place + 1); + return *counter; + } + + inline void resetCounter(AggregateDataPtr __restrict place) const noexcept + { + auto * counter = reinterpret_cast(place + 1); + (*counter) = 0; + } +}; + } // namespace DB diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h b/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h index 6b2b96d60b2..7c3935861b5 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h @@ -27,7 +27,6 @@ #include #include - namespace DB { namespace ErrorCodes @@ -170,7 +169,7 @@ class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper(time_arg)) throw Exception{ "Illegal type " + time_arg->getName() + " of first argument of aggregate function " @@ -179,7 +178,7 @@ class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper(cond_arg)) throw Exception{ "Illegal type " + cond_arg->getName() + " of argument " + toString(i + 1) @@ -298,7 +297,7 @@ class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelperadd(place, columns, row_num, arena); } + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const override + { + nested_func->decrease(place, columns, row_num, arena); + } + + void reset(AggregateDataPtr __restrict place) const override { nested_func->reset(place); } + + void prepareWindow(AggregateDataPtr __restrict place) const override { nested_func->prepareWindow(place); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { nested_func->merge(place, rhs, arena); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSum.h b/dbms/src/AggregateFunctions/AggregateFunctionSum.h index 872d42415ec..f768394ce3e 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSum.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSum.h @@ -42,10 +42,27 @@ struct AggregateFunctionSumAddImpl> } }; +template +struct AggregateFunctionSumMinusImpl +{ + static void NO_SANITIZE_UNDEFINED ALWAYS_INLINE decrease(T & lhs, const T & rhs) { lhs -= rhs; } +}; + +template +struct AggregateFunctionSumMinusImpl> +{ + template + static void NO_SANITIZE_UNDEFINED ALWAYS_INLINE decrease(Decimal & lhs, const Decimal & rhs) + { + lhs.value -= static_cast(rhs.value); + } +}; + template struct AggregateFunctionSumData { using Impl = AggregateFunctionSumAddImpl; + using DescreaseImpl = AggregateFunctionSumMinusImpl; T sum{}; AggregateFunctionSumData() = default; @@ -56,6 +73,14 @@ struct AggregateFunctionSumData Impl::add(sum, value); } + template + void NO_SANITIZE_UNDEFINED ALWAYS_INLINE decrease(U value) + { + DescreaseImpl::decrease(sum, value); + } + + void NO_SANITIZE_UNDEFINED ALWAYS_INLINE reset() { sum = T(0); } + /// Vectorized version template void NO_SANITIZE_UNDEFINED NO_INLINE addMany(const Value * __restrict ptr, size_t count) @@ -165,6 +190,10 @@ struct AggregateFunctionSumKahanData void ALWAYS_INLINE add(T value) { addImpl(value, sum, compensation); } + void ALWAYS_INLINE decrease(T) { throw Exception("Not implemented yet"); } + + void ALWAYS_INLINE reset() { throw Exception("Not implemented yet"); } + /// Vectorized version template void NO_INLINE addMany(const Value * __restrict ptr, size_t count) @@ -311,7 +340,7 @@ class AggregateFunctionSum final AggregateFunctionSum(PrecType prec, ScaleType scale) { std::tie(result_prec, result_scale) = Name::decimalInfer(prec, scale); - }; + } DataTypePtr getReturnType() const override { @@ -336,6 +365,16 @@ class AggregateFunctionSum final this->data(place).add(column.getData()[row_num]); } + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override + { + const auto & column = assert_cast(*columns[0]); + this->data(place).decrease(column.getData()[row_num]); + } + + void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } + + void prepareWindow(AggregateDataPtr __restrict) const override {} + /// Vectorized version when there is no GROUP BY keys. void addBatchSinglePlace( size_t start_offset, diff --git a/dbms/src/AggregateFunctions/HelpersMinMaxAny.h b/dbms/src/AggregateFunctions/HelpersMinMaxAny.h index db96f948610..ad34a592e6d 100644 --- a/dbms/src/AggregateFunctions/HelpersMinMaxAny.h +++ b/dbms/src/AggregateFunctions/HelpersMinMaxAny.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -61,6 +62,40 @@ static IAggregateFunction * createAggregateFunctionSingleValue( return new AggregateFunctionTemplate>(argument_type); } +template