From 3f80eed6e71b6018f913adf160c3db2606f32fb7 Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Tue, 26 Nov 2024 16:58:53 +0800 Subject: [PATCH 01/32] init --- dbms/src/Common/AlignedBuffer.cpp | 61 ++++++++ dbms/src/Common/AlignedBuffer.h | 48 +++++++ .../DataStreams/WindowBlockInputStream.cpp | 133 +++++++++++++++++- dbms/src/DataStreams/WindowBlockInputStream.h | 13 +- dbms/src/Interpreters/WindowDescription.h | 5 +- dbms/src/WindowFunctions/IWindowFunction.h | 14 +- 6 files changed, 265 insertions(+), 9 deletions(-) create mode 100644 dbms/src/Common/AlignedBuffer.cpp create mode 100644 dbms/src/Common/AlignedBuffer.h diff --git a/dbms/src/Common/AlignedBuffer.cpp b/dbms/src/Common/AlignedBuffer.cpp new file mode 100644 index 00000000000..7529155f546 --- /dev/null +++ b/dbms/src/Common/AlignedBuffer.cpp @@ -0,0 +1,61 @@ +// Copyright 2023 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int CANNOT_ALLOCATE_MEMORY; +} + +void AlignedBuffer::alloc(size_t size, size_t alignment) +{ + void * new_buf; + int res = ::posix_memalign(&new_buf, std::max(alignment, sizeof(void *)), size); + if (0 != res) + throwFromErrno(fmt::format("Cannot allocate memory (posix_memalign), size: {}, alignment: {}.", + size, + alignment), + ErrorCodes::CANNOT_ALLOCATE_MEMORY, + res); + buf = new_buf; +} + +void AlignedBuffer::dealloc() +{ + if (buf) + ::free(buf); +} + +void AlignedBuffer::reset(size_t size, size_t alignment) +{ + dealloc(); + alloc(size, alignment); +} + +AlignedBuffer::AlignedBuffer(size_t size, size_t alignment) +{ + alloc(size, alignment); +} + +AlignedBuffer::~AlignedBuffer() +{ + dealloc(); +} + +} // namespace DB diff --git a/dbms/src/Common/AlignedBuffer.h b/dbms/src/Common/AlignedBuffer.h new file mode 100644 index 00000000000..cebf596cece --- /dev/null +++ b/dbms/src/Common/AlignedBuffer.h @@ -0,0 +1,48 @@ +// Copyright 2024 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace DB +{ + +/** Aligned piece of memory. + * It can only be allocated and destroyed. + * MemoryTracker is not used. AlignedBuffer is intended for small pieces of memory. + */ +class AlignedBuffer : private boost::noncopyable +{ +private: + void * buf = nullptr; + + void alloc(size_t size, size_t alignment); + void dealloc(); + +public: + AlignedBuffer() = default; + AlignedBuffer(size_t size, size_t alignment); + AlignedBuffer(AlignedBuffer && old) noexcept { std::swap(buf, old.buf); } + ~AlignedBuffer(); + + void reset(size_t size, size_t alignment); + + char * data() { return static_cast(buf); } + const char * data() const { return static_cast(buf); } +}; + +} // namespace DB diff --git a/dbms/src/DataStreams/WindowBlockInputStream.cpp b/dbms/src/DataStreams/WindowBlockInputStream.cpp index a31237f8a56..c9c26a56427 100644 --- a/dbms/src/DataStreams/WindowBlockInputStream.cpp +++ b/dbms/src/DataStreams/WindowBlockInputStream.cpp @@ -299,11 +299,30 @@ void WindowTransformAction::initialWorkspaces() WindowFunctionWorkspace workspace; workspace.window_function = window_function_description.window_function; workspace.arguments = window_function_description.arguments; + workspace.argument_column_indices = window_function_description.arguments; + workspace.argument_columns.assign(workspace.argument_column_indices.size(), nullptr); + initialAggregateFunction(workspace, window_function_description); workspaces.push_back(std::move(workspace)); } only_have_row_number = onlyHaveRowNumber(); } +void WindowTransformAction::initialAggregateFunction(WindowFunctionWorkspace & workspace, const WindowFunctionDescription & window_function_description) +{ + if (window_function_description.aggregate_function == nullptr) + return; + + workspace.aggregate_function = window_function_description.aggregate_function; + const auto & aggregate_function = workspace.aggregate_function; + if (!arena && aggregate_function->allocatesMemoryInArena()) + arena = std::make_unique(); + + workspace.aggregate_function_state.reset( + aggregate_function->sizeOfData(), + aggregate_function->alignOfData()); + aggregate_function->create(workspace.aggregate_function_state.data()); +} + bool WindowBlockInputStream::returnIfCancelledOrKilled() { if (isCancelledOrThrowIfKilled()) @@ -1231,7 +1250,18 @@ void WindowTransformAction::writeOutCurrentRow() for (size_t wi = 0; wi < workspaces.size(); ++wi) { auto & ws = workspaces[wi]; - ws.window_function->windowInsertResultInto(*this, wi, ws.arguments); + if (ws.window_function) + ws.window_function->windowInsertResultInto(*this, wi, ws.arguments); + else + { + const auto & block = blockAt(current_row); + IColumn * result_column = block.output_columns[wi].get(); + const auto * agg_func = ws.aggregate_function.get(); + auto * buf = ws.aggregate_function_state.data(); + + // TODO add `insertMergeResultInto` function? + agg_func->insertResultInto(buf, *result_column, arena.get()); + } } } @@ -1266,7 +1296,7 @@ bool WindowTransformAction::onlyHaveRowNumber() { for (const auto & workspace : workspaces) { - if (workspace.window_function->getName() != "row_number") + if (workspace.window_function != nullptr && workspace.window_function->getName() != "row_number") return false; } return true; @@ -1315,7 +1345,11 @@ void WindowTransformAction::appendBlock(Block & current_block) // Initialize output columns and add new columns to output block. for (auto & ws : workspaces) { - MutableColumnPtr res = ws.window_function->getReturnType()->createColumn(); + MutableColumnPtr res; + if (ws.window_function != nullptr) + res = ws.window_function->getReturnType()->createColumn(); + else + res = ws.aggregate_function->getReturnType()->createColumn(); res->reserve(window_block.rows); window_block.output_columns.push_back(std::move(res)); } @@ -1323,6 +1357,99 @@ void WindowTransformAction::appendBlock(Block & current_block) window_block.input_columns = current_block.getColumns(); } +// Update the aggregation states after the frame has changed. +void WindowTransformAction::updateAggregationState() +{ + assert(frame_started); + assert(frame_ended); + assert(frame_start <= frame_end); + assert(prev_frame_start <= prev_frame_end); + assert(prev_frame_start <= frame_start); + assert(prev_frame_end <= frame_end); + assert(partition_start <= frame_start); + assert(frame_end <= partition_end); + + bool reset_aggregation = false; + RowNumber rows_to_add_start; + RowNumber rows_to_add_end; + if (frame_start == prev_frame_start) + { + // The frame start didn't change, add the tail rows. + reset_aggregation = false; + rows_to_add_start = prev_frame_end; + rows_to_add_end = frame_end; + } + else + { + // The frame start changed, reset the state and aggregate over the + // entire frame. This can be made per-function after we learn to + // subtract rows from some types of aggregation states, but for now we + // always have to reset when the frame start changes. + reset_aggregation = true; + rows_to_add_start = frame_start; + rows_to_add_end = frame_end; + } + + for (auto & ws : workspaces) + { + if (ws.window_function) + continue; // No need to do anything for true window functions. + + const auto * agg_func = ws.aggregate_function.get(); + auto * buf = ws.aggregate_function_state.data(); + + if (reset_aggregation) + { + agg_func->destroy(buf); + agg_func->create(buf); + } + + // To achieve better performance, we will have to loop over blocks and + // rows manually, instead of using advanceRowNumber(). + // For this purpose, the past-the-end block can be different than the + // block of the past-the-end row (it's usually the next block). + const auto past_the_end_block = rows_to_add_end.row == 0 + ? rows_to_add_end.block + : rows_to_add_end.block + 1; + + for (auto block_number = rows_to_add_start.block; + block_number < past_the_end_block; + ++block_number) + { + auto & block = blockAt(block_number); + + if (ws.cached_block_number != block_number) + { + for (size_t i = 0; i < ws.argument_column_indices.size(); ++i) + { + ws.argument_columns[i] = block.input_columns[ws.argument_column_indices[i]].get(); + } + ws.cached_block_number = block_number; + } + + // First and last blocks may be processed partially, and other blocks + // are processed in full. + const auto first_row = block_number == rows_to_add_start.block + ? rows_to_add_start.row + : 0; + const auto past_the_end_row = block_number == rows_to_add_end.block + ? rows_to_add_end.row + : block.rows; + + // TODO Add an addBatch analog that can accept a starting offset. + // For now, add the values one by one. + auto * columns = ws.argument_columns.data(); + + // Removing arena.get() from the loop makes it faster somehow... + auto * arena_ptr = arena.get(); + for (auto row = first_row; row < past_the_end_row; ++row) + { + agg_func->add(buf, columns, row, arena_ptr); + } + } + } +} + void WindowTransformAction::tryCalculate() { // Start the calculations. First, advance the partition end. diff --git a/dbms/src/DataStreams/WindowBlockInputStream.h b/dbms/src/DataStreams/WindowBlockInputStream.h index e46424e32ae..ba8dbc2bf7f 100644 --- a/dbms/src/DataStreams/WindowBlockInputStream.h +++ b/dbms/src/DataStreams/WindowBlockInputStream.h @@ -52,9 +52,6 @@ struct WindowTransformAction Block tryGetOutputBlock(); void releaseAlreadyOutputWindowBlock(); - void initialWorkspaces(); - void initialPartitionAndOrderColumnIndices(); - Columns & inputAt(const RowNumber & x) { assert(x.block >= first_block_number); @@ -165,6 +162,14 @@ struct WindowTransformAction // distance is left - right. UInt64 distance(RowNumber left, RowNumber right); + void initialWorkspaces(); + void initialPartitionAndOrderColumnIndices(); + void initialAggregateFunction(WindowFunctionWorkspace & workspace, const WindowFunctionDescription & window_function_description); + + void updateAggregationState(); + + void reinitializeAggFuncBeforeNextPartition(); + public: LoggerPtr log; @@ -244,6 +249,8 @@ struct WindowTransformAction // Auxiliary variable for range frame type when calculating frame_end RowNumber prev_frame_end; + std::unique_ptr arena; + //TODO: used as template parameters bool only_have_row_number = false; }; diff --git a/dbms/src/Interpreters/WindowDescription.h b/dbms/src/Interpreters/WindowDescription.h index 96270416bb2..6fd1021c5e5 100644 --- a/dbms/src/Interpreters/WindowDescription.h +++ b/dbms/src/Interpreters/WindowDescription.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -24,16 +25,16 @@ #include #include - namespace DB { struct WindowFunctionDescription { WindowFunctionPtr window_function; + AggregateFunctionPtr aggregate_function; Array parameters; ColumnNumbers arguments; Names argument_names; - std::string column_name; + String column_name; }; using WindowFunctionDescriptions = std::vector; diff --git a/dbms/src/WindowFunctions/IWindowFunction.h b/dbms/src/WindowFunctions/IWindowFunction.h index e912efa809f..d75dfd9f2a8 100644 --- a/dbms/src/WindowFunctions/IWindowFunction.h +++ b/dbms/src/WindowFunctions/IWindowFunction.h @@ -14,6 +14,8 @@ #pragma once +#include +#include #include #include #include @@ -52,8 +54,18 @@ using WindowFunctionPtr = std::shared_ptr; // Runtime data for computing one window function. struct WindowFunctionWorkspace { - // TODO add aggregation function WindowFunctionPtr window_function = nullptr; + AggregateFunctionPtr aggregate_function; + + // Will not be initialized for a pure window function. + mutable AlignedBuffer aggregate_function_state; + + // Argument columns. Be careful, this is a per-block cache. + std::vector argument_columns; + + UInt64 cached_block_number = std::numeric_limits::max(); + + ColumnNumbers argument_column_indices; ColumnNumbers arguments; }; From 171fd461ca0c166c3190e0f45defdebf4707b1ce Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Wed, 27 Nov 2024 13:14:31 +0800 Subject: [PATCH 02/32] add gtest --- dbms/src/WindowFunctions/tests/gtest_agg.cpp | 130 +++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 dbms/src/WindowFunctions/tests/gtest_agg.cpp diff --git a/dbms/src/WindowFunctions/tests/gtest_agg.cpp b/dbms/src/WindowFunctions/tests/gtest_agg.cpp new file mode 100644 index 00000000000..cd6b572ae9a --- /dev/null +++ b/dbms/src/WindowFunctions/tests/gtest_agg.cpp @@ -0,0 +1,130 @@ +// Copyright 2024 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + + +namespace DB::tests +{ +class WindowAggFuncTest : public DB::tests::WindowTest +{ +public: + const ASTPtr value_col = col(VALUE_COL_NAME); + + void initializeContext() override + { + ExecutorTest::initializeContext(); + } +}; + +TEST_F(WindowAggFuncTest, windowAggSumTests) +try +{ + { + // rows frame + MockWindowFrame frame; + frame.type = tipb::WindowFrameType::Rows; + frame.start = mock::MockWindowFrameBound(tipb::WindowBoundType::Preceding, false, 0); + frame.end = mock::MockWindowFrameBound(tipb::WindowBoundType::Following, false, 3); + std::vector frame_start_offset{0, 1, 3, 10}; + + std::vector> res{ + {0, 15, 14, 12, 8, 26, 41, 38, 28, 15, 18, 32, 49, 75, 66, 51, 31}, + {0, 15, 15, 14, 12, 26, 41, 41, 38, 28, 18, 33, 52, 80, 75, 66, 51}, + {0, 15, 15, 15, 15, 26, 41, 41, 41, 41, 18, 33, 53, 84, 83, 80, 75}, + {0, 15, 15, 15, 15, 26, 41, 41, 41, 41, 18, 33, 53, 84, 84, 84, 84}}; + + for (size_t i = 0; i < frame_start_offset.size(); ++i) + { + frame.start = mock::MockWindowFrameBound(tipb::WindowBoundType::Preceding, false, frame_start_offset[i]); + + executeFunctionAndAssert( + toVec(res[i]), + Sum(value_col), + {toVec(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3}), + toVec(/*order*/ {0, 1, 2, 4, 8, 0, 3, 10, 13, 15, 1, 3, 5, 9, 15, 20, 31}), + toVec(/*value*/ {0, 1, 2, 4, 8, 0, 3, 10, 13, 15, 1, 3, 5, 9, 15, 20, 31})}, + frame); + } + } + + // TODO uncomment these test after range frame is merged + // { + // // range frame + // MockWindowFrame frame; + // frame.type = tipb::WindowFrameType::Rows; + // frame.start = buildRangeFrameBound(tipb::WindowBoundType::Preceding, tipb::RangeCmpDataType::Int, ORDER_COL_NAME, false, 0); + // frame.end = buildRangeFrameBound(tipb::WindowBoundType::Following, tipb::RangeCmpDataType::Int, ORDER_COL_NAME, true, 3); + // std::vector frame_start_offset{0, 1, 3, 10}; + + // std::vector> res_not_null{ + // {0, 7, 6, 4, 8, 3, 3, 23, 28, 15, 4, 8, 5, 9, 15, 20, 31}, + // {0, 7, 7, 4, 8, 3, 3, 23, 28, 15, 4, 8, 5, 9, 15, 20, 31}, + // {0, 7, 7, 7, 8, 3, 3, 23, 38, 28, 4, 9, 8, 9, 15, 20, 31}, + // {0, 7, 7, 7, 15, 3, 3, 26, 41, 38, 4, 9, 9, 18, 29, 35, 31}}; + + // for (size_t i = 0; i < frame_start_offset.size(); ++i) + // { + // frame.start = buildRangeFrameBound(tipb::WindowBoundType::Preceding, tipb::RangeCmpDataType::Int, ORDER_COL_NAME, false, 0); + + // executeFunctionAndAssert( + // toVec(res_not_null[i]), + // Sum(value_col), + // {toVec(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3}), + // toVec(/*order*/ {0, 1, 2, 4, 8, 0, 3, 10, 13, 15, 1, 3, 5, 9, 15, 20, 31}), + // toVec(/*value*/ {0, 1, 2, 4, 8, 0, 3, 10, 13, 15, 1, 3, 5, 9, 15, 20, 31})}, + // frame); + // } + // } +} +CATCH + +TEST_F(WindowAggFuncTest, windowAggCountTests) +try +{ + { + // rows frame + MockWindowFrame frame; + frame.type = tipb::WindowFrameType::Rows; + frame.start = mock::MockWindowFrameBound(tipb::WindowBoundType::Preceding, false, 0); + frame.end = mock::MockWindowFrameBound(tipb::WindowBoundType::Following, false, 3); + std::vector frame_start_offset{0, 1, 3, 10}; + + std::vector> res{ + {1, 4, 3, 2, 1, 4, 4, 3, 2, 1, 4, 4, 4, 4, 3, 2, 1}, + {1, 4, 4, 3, 2, 4, 5, 4, 3, 2, 4, 5, 5, 5, 4, 3, 2}, + {1, 4, 4, 4, 4, 4, 5, 5, 5, 4, 4, 5, 6, 7, 6, 5, 4}, + {1, 4, 4, 4, 4, 4, 5, 5, 5, 5, 4, 5, 6, 7, 7, 7, 7}}; + + for (size_t i = 0; i < frame_start_offset.size(); ++i) + { + frame.start = mock::MockWindowFrameBound(tipb::WindowBoundType::Preceding, false, frame_start_offset[i]); + + executeFunctionAndAssert( + toVec(res[i]), + Count(value_col), + {toVec(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3}), + toVec(/*order*/ {0, 1, 2, 4, 8, 0, 3, 10, 13, 15, 1, 3, 5, 9, 15, 20, 31}), + toVec(/*value*/ {0, 1, 2, 4, 8, 0, 3, 10, 13, 15, 1, 3, 5, 9, 15, 20, 31})}, + frame); + } + } + // TODO add range frame tests after that is merged +} +CATCH +} // namespace DB::tests From 943c578dd6d3b54b03b9c9afa3153cdc754c6661 Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Wed, 27 Nov 2024 13:17:06 +0800 Subject: [PATCH 03/32] add ft --- tests/fullstack-test/mpp/window_agg.test | 139 +++++++++++++++++++++++ 1 file changed, 139 insertions(+) create mode 100644 tests/fullstack-test/mpp/window_agg.test diff --git a/tests/fullstack-test/mpp/window_agg.test b/tests/fullstack-test/mpp/window_agg.test new file mode 100644 index 00000000000..701eb4916f2 --- /dev/null +++ b/tests/fullstack-test/mpp/window_agg.test @@ -0,0 +1,139 @@ +# Copyright 2024 PingCAP, Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +mysql> drop table if exists test.agg; +mysql> create table test.agg(p int not null, o int not null, v int not null); +mysql> insert into test.agg (p, o, v) values (0, 0, 0), (1, 1, 1), (1, 2, 2), (1, 4, 4), (1, 8, 8), (2, 0, 0), (2, 3, 3), (2, 10, 10), (2, 13, 13), (2, 15, 15), (3, 1, 1), (3, 3, 3), (3, 5, 5), (3, 9, 9), (3, 15, 15), (3, 20, 20), (3, 31, 31); +mysql> alter table agg set tiflash replica 1; + +func> wait_table test test.agg + +mysql> use test; set tidb_enforce_mpp=1; + +//TODO ast.AggFuncSum, ast.AggFuncCount, ast.AggFuncAvg, ast.AggFuncMax, ast.AggFuncMin ... + +mysql> use test; set tidb_enforce_mpp=1; select *, sum(v) over (partition by p order by o rows between 1 preceding and 1 following) as a from test.agg; ++---+----+----+------+ +| p | o | v | a | ++---+----+----+------+ +| 0 | 0 | 0 | 0 | +| 1 | 1 | 1 | 3 | +| 1 | 2 | 2 | 7 | +| 1 | 4 | 4 | 14 | +| 1 | 8 | 8 | 12 | +| 2 | 0 | 0 | 3 | +| 2 | 3 | 3 | 13 | +| 2 | 10 | 10 | 26 | +| 2 | 13 | 13 | 38 | +| 2 | 15 | 15 | 28 | +| 3 | 1 | 1 | 4 | +| 3 | 3 | 3 | 9 | +| 3 | 5 | 5 | 17 | +| 3 | 9 | 9 | 29 | +| 3 | 15 | 15 | 44 | +| 3 | 20 | 20 | 66 | +| 3 | 31 | 31 | 51 | ++---+----+----+------+ + +mysql> use test; set tidb_enforce_mpp=1; select *, count(v) over (partition by p order by o rows between 1 preceding and 1 following) as a from test.agg; ++---+----+----+---+ +| p | o | v | a | ++---+----+----+---+ +| 0 | 0 | 0 | 1 | +| 1 | 1 | 1 | 2 | +| 1 | 2 | 2 | 3 | +| 1 | 4 | 4 | 3 | +| 1 | 8 | 8 | 2 | +| 2 | 0 | 0 | 2 | +| 2 | 3 | 3 | 3 | +| 2 | 10 | 10 | 3 | +| 2 | 13 | 13 | 3 | +| 2 | 15 | 15 | 2 | +| 3 | 1 | 1 | 2 | +| 3 | 3 | 3 | 3 | +| 3 | 5 | 5 | 3 | +| 3 | 9 | 9 | 3 | +| 3 | 15 | 15 | 3 | +| 3 | 20 | 20 | 3 | +| 3 | 31 | 31 | 2 | ++---+----+----+---+ + +mysql> use test; set tidb_enforce_mpp=1; select *, avg(v) over (partition by p order by o rows between 1 preceding and 1 following) as a from test.agg; ++---+----+----+---------+ +| p | o | v | a | ++---+----+----+---------+ +| 0 | 0 | 0 | 0.0000 | +| 1 | 1 | 1 | 1.5000 | +| 1 | 2 | 2 | 2.3333 | +| 1 | 4 | 4 | 4.6666 | +| 1 | 8 | 8 | 6.0000 | +| 2 | 0 | 0 | 1.5000 | +| 2 | 3 | 3 | 4.3333 | +| 2 | 10 | 10 | 8.6666 | +| 2 | 13 | 13 | 12.6666 | +| 2 | 15 | 15 | 14.0000 | +| 3 | 1 | 1 | 2.0000 | +| 3 | 3 | 3 | 3.0000 | +| 3 | 5 | 5 | 5.6666 | +| 3 | 9 | 9 | 9.6666 | +| 3 | 15 | 15 | 14.6666 | +| 3 | 20 | 20 | 22.0000 | +| 3 | 31 | 31 | 25.5000 | ++---+----+----+---------+ + +mysql> use test; set tidb_enforce_mpp=1; select *, min(v) over (partition by p order by o rows between 1 preceding and 1 following) as a from test.agg; ++---+----+----+------+ +| p | o | v | a | ++---+----+----+------+ +| 0 | 0 | 0 | 0 | +| 1 | 1 | 1 | 1 | +| 1 | 2 | 2 | 1 | +| 1 | 4 | 4 | 2 | +| 1 | 8 | 8 | 4 | +| 2 | 0 | 0 | 0 | +| 2 | 3 | 3 | 0 | +| 2 | 10 | 10 | 3 | +| 2 | 13 | 13 | 10 | +| 2 | 15 | 15 | 13 | +| 3 | 1 | 1 | 1 | +| 3 | 3 | 3 | 1 | +| 3 | 5 | 5 | 3 | +| 3 | 9 | 9 | 5 | +| 3 | 15 | 15 | 9 | +| 3 | 20 | 20 | 15 | +| 3 | 31 | 31 | 20 | ++---+----+----+------+ + +mysql> use test; set tidb_enforce_mpp=1; select *, max(v) over (partition by p order by o rows between 1 preceding and 1 following) as a from test.agg; ++---+----+----+------+ +| p | o | v | a | ++---+----+----+------+ +| 0 | 0 | 0 | 0 | +| 1 | 1 | 1 | 2 | +| 1 | 2 | 2 | 4 | +| 1 | 4 | 4 | 8 | +| 1 | 8 | 8 | 8 | +| 2 | 0 | 0 | 3 | +| 2 | 3 | 3 | 10 | +| 2 | 10 | 10 | 13 | +| 2 | 13 | 13 | 15 | +| 2 | 15 | 15 | 15 | +| 3 | 1 | 1 | 3 | +| 3 | 3 | 3 | 5 | +| 3 | 5 | 5 | 9 | +| 3 | 9 | 9 | 15 | +| 3 | 15 | 15 | 20 | +| 3 | 20 | 20 | 31 | +| 3 | 31 | 31 | 31 | ++---+----+----+------+ From e0b05a2d3f2e767c8f424e11c49e733fcf011c2d Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Wed, 27 Nov 2024 14:01:30 +0800 Subject: [PATCH 04/32] refine test framework --- dbms/src/Common/AlignedBuffer.cpp | 9 +- .../DataStreams/WindowBlockInputStream.cpp | 24 +-- dbms/src/DataStreams/WindowBlockInputStream.h | 6 +- dbms/src/Debug/MockExecutor/WindowBinder.cpp | 182 +++++++++++++----- dbms/src/Debug/MockExecutor/WindowBinder.h | 4 + dbms/src/WindowFunctions/tests/gtest_agg.cpp | 5 +- 6 files changed, 157 insertions(+), 73 deletions(-) diff --git a/dbms/src/Common/AlignedBuffer.cpp b/dbms/src/Common/AlignedBuffer.cpp index 7529155f546..c9783bc3dfc 100644 --- a/dbms/src/Common/AlignedBuffer.cpp +++ b/dbms/src/Common/AlignedBuffer.cpp @@ -28,11 +28,10 @@ void AlignedBuffer::alloc(size_t size, size_t alignment) void * new_buf; int res = ::posix_memalign(&new_buf, std::max(alignment, sizeof(void *)), size); if (0 != res) - throwFromErrno(fmt::format("Cannot allocate memory (posix_memalign), size: {}, alignment: {}.", - size, - alignment), - ErrorCodes::CANNOT_ALLOCATE_MEMORY, - res); + throwFromErrno( + fmt::format("Cannot allocate memory (posix_memalign), size: {}, alignment: {}.", size, alignment), + ErrorCodes::CANNOT_ALLOCATE_MEMORY, + res); buf = new_buf; } diff --git a/dbms/src/DataStreams/WindowBlockInputStream.cpp b/dbms/src/DataStreams/WindowBlockInputStream.cpp index c9c26a56427..3e4ee5b5134 100644 --- a/dbms/src/DataStreams/WindowBlockInputStream.cpp +++ b/dbms/src/DataStreams/WindowBlockInputStream.cpp @@ -307,7 +307,9 @@ void WindowTransformAction::initialWorkspaces() only_have_row_number = onlyHaveRowNumber(); } -void WindowTransformAction::initialAggregateFunction(WindowFunctionWorkspace & workspace, const WindowFunctionDescription & window_function_description) +void WindowTransformAction::initialAggregateFunction( + WindowFunctionWorkspace & workspace, + const WindowFunctionDescription & window_function_description) { if (window_function_description.aggregate_function == nullptr) return; @@ -317,9 +319,7 @@ void WindowTransformAction::initialAggregateFunction(WindowFunctionWorkspace & w if (!arena && aggregate_function->allocatesMemoryInArena()) arena = std::make_unique(); - workspace.aggregate_function_state.reset( - aggregate_function->sizeOfData(), - aggregate_function->alignOfData()); + workspace.aggregate_function_state.reset(aggregate_function->sizeOfData(), aggregate_function->alignOfData()); aggregate_function->create(workspace.aggregate_function_state.data()); } @@ -1408,13 +1408,9 @@ void WindowTransformAction::updateAggregationState() // rows manually, instead of using advanceRowNumber(). // For this purpose, the past-the-end block can be different than the // block of the past-the-end row (it's usually the next block). - const auto past_the_end_block = rows_to_add_end.row == 0 - ? rows_to_add_end.block - : rows_to_add_end.block + 1; + const auto past_the_end_block = rows_to_add_end.row == 0 ? rows_to_add_end.block : rows_to_add_end.block + 1; - for (auto block_number = rows_to_add_start.block; - block_number < past_the_end_block; - ++block_number) + for (auto block_number = rows_to_add_start.block; block_number < past_the_end_block; ++block_number) { auto & block = blockAt(block_number); @@ -1429,12 +1425,8 @@ void WindowTransformAction::updateAggregationState() // First and last blocks may be processed partially, and other blocks // are processed in full. - const auto first_row = block_number == rows_to_add_start.block - ? rows_to_add_start.row - : 0; - const auto past_the_end_row = block_number == rows_to_add_end.block - ? rows_to_add_end.row - : block.rows; + const auto first_row = block_number == rows_to_add_start.block ? rows_to_add_start.row : 0; + const auto past_the_end_row = block_number == rows_to_add_end.block ? rows_to_add_end.row : block.rows; // TODO Add an addBatch analog that can accept a starting offset. // For now, add the values one by one. diff --git a/dbms/src/DataStreams/WindowBlockInputStream.h b/dbms/src/DataStreams/WindowBlockInputStream.h index ba8dbc2bf7f..d69b9ca98ed 100644 --- a/dbms/src/DataStreams/WindowBlockInputStream.h +++ b/dbms/src/DataStreams/WindowBlockInputStream.h @@ -164,7 +164,9 @@ struct WindowTransformAction void initialWorkspaces(); void initialPartitionAndOrderColumnIndices(); - void initialAggregateFunction(WindowFunctionWorkspace & workspace, const WindowFunctionDescription & window_function_description); + void initialAggregateFunction( + WindowFunctionWorkspace & workspace, + const WindowFunctionDescription & window_function_description); void updateAggregationState(); @@ -265,7 +267,7 @@ class WindowBlockInputStream : public IProfilingBlockInputStream const WindowDescription & window_description_, const String & req_id); - Block getHeader() const override { return action.output_header; }; + Block getHeader() const override { return action.output_header; } String getName() const override { return NAME; } diff --git a/dbms/src/Debug/MockExecutor/WindowBinder.cpp b/dbms/src/Debug/MockExecutor/WindowBinder.cpp index 32f2e6b4d18..78bb8d168f9 100644 --- a/dbms/src/Debug/MockExecutor/WindowBinder.cpp +++ b/dbms/src/Debug/MockExecutor/WindowBinder.cpp @@ -32,12 +32,10 @@ namespace // // Other window or aggregation functions always return nullable column and we need to // remove the not null flag for them. -void setWindowFieldType( - const tipb::ExprType window_sig, - tipb::FieldType * window_field_type, - tipb::Expr * window_expr, - const int32_t collator_id) +void setFieldTypeForWindowFunc(tipb::Expr * window_expr, const tipb::ExprType window_sig, const int32_t collator_id) { + window_expr->set_tp(window_sig); + auto * window_field_type = window_expr->mutable_field_type(); switch (window_sig) { case tipb::ExprType::Lead: @@ -86,6 +84,55 @@ void setWindowFieldType( } } +void setFieldTypeForAggFunc( + const DB::ASTFunction * func, + tipb::Expr * expr, + const tipb::ExprType agg_sig, + int32_t collator_id) +{ + expr->set_tp(agg_sig); + if (agg_sig == tipb::ExprType::Count || agg_sig == tipb::ExprType::Sum) + { + auto * ft = expr->mutable_field_type(); + ft->set_tp(TiDB::TypeLongLong); + ft->set_flag(TiDB::ColumnFlagNotNull); + } + else if (agg_sig == tipb::ExprType::Min || agg_sig == tipb::ExprType::Max) + { + if (expr->children_size() != 1) + throw Exception(fmt::format("Agg function({}) only accept 1 argument", func->name)); + + auto * ft = expr->mutable_field_type(); + ft->set_tp(expr->children(0).field_type().tp()); + ft->set_decimal(expr->children(0).field_type().decimal()); + ft->set_flag(expr->children(0).field_type().flag() & (~TiDB::ColumnFlagNotNull)); + ft->set_collate(collator_id); + } + else + { + throw Exception("Window does not support this agg function"); + } + + expr->set_aggfuncmode(tipb::AggFunctionMode::FinalMode); +} + +void setFieldType(const DB::ASTFunction * func, tipb::Expr * expr, int32_t collator_id) +{ + auto window_sig_it = tests::window_func_name_to_sig.find(func->name); + if (window_sig_it != tests::window_func_name_to_sig.end()) + { + setFieldTypeForWindowFunc(expr, window_sig_it->second, collator_id); + return; + } + + auto agg_sig_it = tests::agg_func_name_to_sig.find(func->name); + if (agg_sig_it == tests::agg_func_name_to_sig.end()) + throw Exception("Unsupported agg function: " + func->name, ErrorCodes::LOGICAL_ERROR); + + auto agg_sig = agg_sig_it->second; + setFieldTypeForAggFunc(func, expr, agg_sig, collator_id); +} + tipb::ExprType getWindowSig(const String & window_func_name) { auto window_sig_it = tests::window_func_name_to_sig.find(window_func_name); @@ -155,7 +202,7 @@ bool WindowBinder::toTiPBExecutor( auto window_sig = getWindowSig(window_func->name); window_expr->set_tp(window_sig); - setWindowFieldType(window_sig, window_expr->mutable_field_type(), window_expr, collator_id); + setFieldType(window_func, window_expr, collator_id); } for (const auto & child : order_by_exprs) @@ -183,6 +230,88 @@ bool WindowBinder::toTiPBExecutor( return children[0]->toTiPBExecutor(window->mutable_child(), collator_id, mpp_info, context); } +void setColumnInfoForAgg( + TiDB::ColumnInfo & ci, + const DB::ASTFunction * func, + const std::vector & children_ci) +{ + // TODO: Other agg func. + if (func->name == "count") + { + ci.tp = TiDB::TypeLongLong; + ci.flag = TiDB::ColumnFlagUnsigned | TiDB::ColumnFlagNotNull; + } + else if (func->name == "max" || func->name == "min" || func->name == "sum") + { + ci = children_ci[0]; + ci.flag &= ~TiDB::ColumnFlagNotNull; + } + else + { + throw Exception("Unsupported agg function: " + func->name, ErrorCodes::LOGICAL_ERROR); + } +} + +void setColumnInfoForWindowFunc( + TiDB::ColumnInfo & ci, + const DB::ASTFunction * func, + const std::vector & children_ci, + tipb::ExprType expr_type) +{ + // TODO: add more window functions + switch (expr_type) + { + case tipb::ExprType::RowNumber: + case tipb::ExprType::Rank: + case tipb::ExprType::DenseRank: + { + ci.tp = TiDB::TypeLongLong; + ci.flag = TiDB::ColumnFlagBinary; + break; + } + case tipb::ExprType::Lead: + case tipb::ExprType::Lag: + { + // TODO handling complex situations + // like lead(col, offset, NULL), lead(data_type1, offset, data_type2) + assert(!children_ci.empty() && children_ci.size() <= 3); + if (children_ci.size() < 3) + { + ci = children_ci[0]; + ci.clearNotNullFlag(); + } + else + { + assert(children_ci[0].tp == children_ci[2].tp); + ci = children_ci[0].hasNotNullFlag() ? children_ci[2] : children_ci[0]; + } + break; + } + case tipb::ExprType::FirstValue: + case tipb::ExprType::LastValue: + { + ci = children_ci[0]; + break; + } + default: + throw Exception(fmt::format("Unsupported window function {}", func->name), ErrorCodes::LOGICAL_ERROR); + } +} + +TiDB::ColumnInfo createColumnInfo(const DB::ASTFunction * func, const std::vector & children_ci) +{ + TiDB::ColumnInfo ci; + auto iter = tests::window_func_name_to_sig.find(func->name); + if (iter != tests::window_func_name_to_sig.end()) + { + setColumnInfoForWindowFunc(ci, func, children_ci, iter->second); + return ci; + } + + setColumnInfoForAgg(ci, func, children_ci); + return ci; +} + ExecutorBinderPtr compileWindow( ExecutorBinderPtr input, size_t & executor_index, @@ -241,46 +370,7 @@ ExecutorBinderPtr compileWindow( { children_ci.push_back(compileExpr(input->output_schema, arg)); } - // TODO: add more window functions - TiDB::ColumnInfo ci; - switch (tests::window_func_name_to_sig[func->name]) - { - case tipb::ExprType::RowNumber: - case tipb::ExprType::Rank: - case tipb::ExprType::DenseRank: - { - ci.tp = TiDB::TypeLongLong; - ci.flag = TiDB::ColumnFlagBinary; - break; - } - case tipb::ExprType::Lead: - case tipb::ExprType::Lag: - { - // TODO handling complex situations - // like lead(col, offset, NULL), lead(data_type1, offset, data_type2) - assert(!children_ci.empty() && children_ci.size() <= 3); - if (children_ci.size() < 3) - { - ci = children_ci[0]; - ci.clearNotNullFlag(); - } - else - { - assert(children_ci[0].tp == children_ci[2].tp); - ci = children_ci[0].hasNotNullFlag() ? children_ci[2] : children_ci[0]; - } - break; - } - case tipb::ExprType::FirstValue: - case tipb::ExprType::LastValue: - { - ci = children_ci[0]; - ci.clearNotNullFlag(); - break; - } - default: - throw Exception(fmt::format("Unsupported window function {}", func->name), ErrorCodes::LOGICAL_ERROR); - } + TiDB::ColumnInfo ci = createColumnInfo(func, children_ci); output_schema.emplace_back(std::make_pair(func->getColumnName(), ci)); } } diff --git a/dbms/src/Debug/MockExecutor/WindowBinder.h b/dbms/src/Debug/MockExecutor/WindowBinder.h index 11324167d18..87b340fb834 100644 --- a/dbms/src/Debug/MockExecutor/WindowBinder.h +++ b/dbms/src/Debug/MockExecutor/WindowBinder.h @@ -165,6 +165,10 @@ class WindowBinder : public ExecutorBinder const MPPInfo & mpp_info, const Context & context) override; +private: + void buildWindowFunc(); + void buildAggFunc(); + private: std::vector func_descs; std::vector partition_by_exprs; diff --git a/dbms/src/WindowFunctions/tests/gtest_agg.cpp b/dbms/src/WindowFunctions/tests/gtest_agg.cpp index cd6b572ae9a..84611237967 100644 --- a/dbms/src/WindowFunctions/tests/gtest_agg.cpp +++ b/dbms/src/WindowFunctions/tests/gtest_agg.cpp @@ -26,10 +26,7 @@ class WindowAggFuncTest : public DB::tests::WindowTest public: const ASTPtr value_col = col(VALUE_COL_NAME); - void initializeContext() override - { - ExecutorTest::initializeContext(); - } + void initializeContext() override { ExecutorTest::initializeContext(); } }; TEST_F(WindowAggFuncTest, windowAggSumTests) From 75d2f12fa19d75c5171b81580d1f35e1da0627a4 Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Wed, 27 Nov 2024 15:07:44 +0800 Subject: [PATCH 05/32] fix compilation phase --- dbms/src/Debug/MockExecutor/WindowBinder.cpp | 11 --- .../Coprocessor/DAGExpressionAnalyzer.cpp | 75 +++++++++++++++---- .../Flash/Coprocessor/DAGExpressionAnalyzer.h | 5 +- 3 files changed, 65 insertions(+), 26 deletions(-) diff --git a/dbms/src/Debug/MockExecutor/WindowBinder.cpp b/dbms/src/Debug/MockExecutor/WindowBinder.cpp index 78bb8d168f9..bbcb6e8c7ad 100644 --- a/dbms/src/Debug/MockExecutor/WindowBinder.cpp +++ b/dbms/src/Debug/MockExecutor/WindowBinder.cpp @@ -133,15 +133,6 @@ void setFieldType(const DB::ASTFunction * func, tipb::Expr * expr, int32_t colla setFieldTypeForAggFunc(func, expr, agg_sig, collator_id); } -tipb::ExprType getWindowSig(const String & window_func_name) -{ - auto window_sig_it = tests::window_func_name_to_sig.find(window_func_name); - if (window_sig_it == tests::window_func_name_to_sig.end()) - throw Exception(fmt::format("Unsupported window function {}", window_func_name), ErrorCodes::LOGICAL_ERROR); - - return window_sig_it->second; -} - void setWindowFrame(MockWindowFrame & frame, tipb::Window * window) { if (frame.type.has_value()) @@ -200,8 +191,6 @@ bool WindowBinder::toTiPBExecutor( astToPB(input_schema, arg, window_expr->add_children(), collator_id, context); } - auto window_sig = getWindowSig(window_func->name); - window_expr->set_tp(window_sig); setFieldType(window_func, window_expr, collator_id); } diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp index f419a693ff0..83810a680d2 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp @@ -177,17 +177,19 @@ void appendAggDescription( /// Generate WindowFunctionDescription and append it to WindowDescription if need. void appendWindowDescription( + const Context & context, const Names & arg_names, const DataTypes & arg_types, TiDB::TiDBCollators & arg_collators, - const String & window_func_name, + const String & func_name, WindowDescription & window_description, NamesAndTypes & source_columns, - NamesAndTypes & window_columns) + NamesAndTypes & window_columns, + bool is_agg) { assert(arg_names.size() == arg_collators.size() && arg_names.size() == arg_types.size()); - String func_string = genFuncString(window_func_name, arg_names, arg_collators); + String func_string = genFuncString(func_name, arg_names, arg_collators); if (auto duplicated_return_type = findDuplicateAggWindowFunc(func_string, window_description.window_functions_descriptions)) { @@ -199,13 +201,44 @@ void appendWindowDescription( WindowFunctionDescription window_function_description; window_function_description.argument_names = arg_names; window_function_description.column_name = func_string; - window_function_description.window_function = WindowFunctionFactory::instance().get(window_func_name, arg_types); - DataTypePtr result_type = window_function_description.window_function->getReturnType(); + + DataTypePtr result_type; + if (is_agg) + { + window_function_description.aggregate_function + = AggregateFunctionFactory::instance().get(context, func_name, arg_types, {}, 0, true); + result_type = window_function_description.aggregate_function->getReturnType(); + } + else + { + window_function_description.window_function = WindowFunctionFactory::instance().get(func_name, arg_types); + result_type = window_function_description.window_function->getReturnType(); + } + window_description.window_functions_descriptions.emplace_back(std::move(window_function_description)); window_columns.emplace_back(func_string, result_type); source_columns.emplace_back(func_string, result_type); } +bool isWindowFunction(const tipb::ExprType expr_type) +{ + switch (expr_type) + { + case tipb::ExprType::FirstValue: + case tipb::ExprType::LastValue: + case tipb::ExprType::RowNumber: + case tipb::ExprType::Rank: + case tipb::ExprType::DenseRank: + case tipb::ExprType::CumeDist: + case tipb::ExprType::PercentRank: + case tipb::ExprType::Ntile: + case tipb::ExprType::NthValue: + return true; + default: + return false; + } +} + void setAuxiliaryColumnInfoImpl( const String & aux_col_name, const Block & tmp_block, @@ -831,22 +864,25 @@ void DAGExpressionAnalyzer::buildLeadLag( } appendWindowDescription( + context, arg_names, arg_types, arg_collators, window_func_name, window_description, source_columns, - window_columns); + window_columns, + false); } -void DAGExpressionAnalyzer::buildCommonWindowFunc( +void DAGExpressionAnalyzer::buildWindowOrAggFuncImpl( const tipb::Expr & expr, const ExpressionActionsPtr & actions, const String & window_func_name, WindowDescription & window_description, NamesAndTypes & source_columns, - NamesAndTypes & window_columns) + NamesAndTypes & window_columns, + bool is_agg) { auto child_size = expr.children_size(); Names arg_names; @@ -858,13 +894,15 @@ void DAGExpressionAnalyzer::buildCommonWindowFunc( } appendWindowDescription( + context, arg_names, arg_types, arg_collators, window_func_name, window_description, source_columns, - window_columns); + window_columns, + is_agg); } // This function will add new window function culumns to source_column @@ -879,7 +917,6 @@ void DAGExpressionAnalyzer::appendWindowColumns( NamesAndTypes window_columns; for (const tipb::Expr & expr : window.func_desc()) { - RUNTIME_CHECK_MSG(isWindowFunctionExpr(expr), "Now Window Operator only support window function."); if (expr.tp() == tipb::ExprType::Lead || expr.tp() == tipb::ExprType::Lag) { buildLeadLag( @@ -890,15 +927,27 @@ void DAGExpressionAnalyzer::appendWindowColumns( source_columns, window_columns); } - else + else if (isWindowFunction(expr.tp())) { - buildCommonWindowFunc( + buildWindowOrAggFuncImpl( expr, actions, getWindowFunctionName(expr), window_description, source_columns, - window_columns); + window_columns, + false); + } + else + { + buildWindowOrAggFuncImpl( + expr, + actions, + getAggFunctionName(expr), + window_description, + source_columns, + window_columns, + true); } } window_description.add_columns = window_columns; diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h index 8ef4dbc0b78..b79a2743f9c 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h @@ -228,13 +228,14 @@ class DAGExpressionAnalyzer : private boost::noncopyable NamesAndTypes & source_columns, NamesAndTypes & window_columns); - void buildCommonWindowFunc( + void buildWindowOrAggFuncImpl( const tipb::Expr & expr, const ExpressionActionsPtr & actions, const String & window_func_name, WindowDescription & window_description, NamesAndTypes & source_columns, - NamesAndTypes & window_columns); + NamesAndTypes & window_columns, + bool is_agg); void fillArgumentDetail( const ExpressionActionsPtr & actions, From 7022607850680cc2ff7638b12768bfbf529f5dd8 Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Wed, 4 Dec 2024 13:44:11 +0800 Subject: [PATCH 06/32] init --- .../AggregateFunctionArgMinMax.h | 6 ++ .../AggregateFunctionArray.h | 21 ++++- .../AggregateFunctions/AggregateFunctionAvg.h | 11 +++ .../AggregateFunctionBitwise.h | 6 ++ .../AggregateFunctionCount.h | 19 +++++ .../AggregateFunctionForEach.h | 21 ++++- .../AggregateFunctionGroupArray.h | 6 ++ .../AggregateFunctionGroupArrayInsertAt.h | 6 ++ .../AggregateFunctionGroupConcat.h | 32 ++++++-- .../AggregateFunctionGroupUniqArray.h | 19 ++++- .../AggregateFunctions/AggregateFunctionIf.h | 7 ++ .../AggregateFunctionMaxIntersections.h | 10 ++- .../AggregateFunctionMerge.h | 6 ++ .../AggregateFunctionMinMaxAny.cpp | 10 +-- .../AggregateFunctionMinMaxAny.h | 6 ++ .../AggregateFunctionNothing.h | 2 + .../AggregateFunctionNull.h | 59 ++++++++++++-- .../AggregateFunctionQuantile.h | 3 + .../AggregateFunctionSequenceMatch.h | 15 +++- .../AggregateFunctionState.h | 6 ++ .../AggregateFunctionStatistics.h | 6 ++ .../AggregateFunctionStatisticsSimple.h | 3 + .../AggregateFunctions/AggregateFunctionSum.h | 31 +++++++ .../AggregateFunctionSumMap.h | 6 ++ .../AggregateFunctionTopK.h | 5 ++ .../AggregateFunctionUniq.h | 12 +++ .../AggregateFunctionUniqUpTo.h | 8 ++ .../AggregateFunctions/IAggregateFunction.h | 11 +++ .../DataStreams/WindowBlockInputStream.cpp | 80 ++++--------------- dbms/src/DataStreams/WindowBlockInputStream.h | 55 +++++++++++++ dbms/src/Debug/MockExecutor/WindowBinder.cpp | 9 +-- .../Coprocessor/DAGExpressionAnalyzer.cpp | 6 +- 32 files changed, 399 insertions(+), 104 deletions(-) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionArgMinMax.h b/dbms/src/AggregateFunctions/AggregateFunctionArgMinMax.h index f252c61b577..cba7bb9facd 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionArgMinMax.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionArgMinMax.h @@ -72,6 +72,12 @@ class AggregateFunctionArgMinMax final : public IAggregateFunctionDataHelperdata(place).result.change(*columns[0], row_num, arena); } + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override + { + // TODO move to helper + throw Exception(""); + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { if (this->data(place).value.changeIfBetter(this->data(rhs).value, arena)) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionArray.h b/dbms/src/AggregateFunctions/AggregateFunctionArray.h index eacd499f86d..6439ff91688 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionArray.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionArray.h @@ -67,13 +67,25 @@ class AggregateFunctionArray final : public IAggregateFunctionHelperalignOfData(); } void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + addOrDecrease(place, columns, row_num, arena); + } + + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const override + { + addOrDecrease(place, columns, row_num, arena); + } + + template + void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const { const IColumn * nested[num_arguments]; for (size_t i = 0; i < num_arguments; ++i) nested[i] = &static_cast(*columns[i]).getData(); - const ColumnArray & first_array_column = static_cast(*columns[0]); + const auto & first_array_column = static_cast(*columns[0]); const IColumn::Offsets & offsets = first_array_column.getOffsets(); size_t begin = row_num == 0 ? 0 : offsets[row_num - 1]; @@ -82,7 +94,7 @@ class AggregateFunctionArray final : public IAggregateFunctionHelper(*columns[i]); + const auto & ith_column = static_cast(*columns[i]); const IColumn::Offsets & ith_offsets = ith_column.getOffsets(); if (ith_offsets[row_num] != end || (row_num != 0 && ith_offsets[row_num - 1] != begin)) @@ -92,7 +104,10 @@ class AggregateFunctionArray final : public IAggregateFunctionHelperadd(place, nested, i, arena); + if constexpr (is_add) + nested_func->add(place, nested, i, arena); + else + nested_func->decrease(place, nested, i, arena); } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override diff --git a/dbms/src/AggregateFunctions/AggregateFunctionAvg.h b/dbms/src/AggregateFunctions/AggregateFunctionAvg.h index 1879c3bca4a..d4a069d2006 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionAvg.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionAvg.h @@ -78,6 +78,17 @@ class AggregateFunctionAvg final ++this->data(place).count; } + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override + { + if constexpr (IsDecimal) + this->data(place).sum -= static_cast &>(*columns[0]).getData()[row_num]; + else + { + this->data(place).sum -= static_cast &>(*columns[0]).getData()[row_num]; + } + --this->data(place).count; + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).sum += this->data(rhs).sum; diff --git a/dbms/src/AggregateFunctions/AggregateFunctionBitwise.h b/dbms/src/AggregateFunctions/AggregateFunctionBitwise.h index 0c37c524770..28508aaf570 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionBitwise.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionBitwise.h @@ -62,6 +62,12 @@ class AggregateFunctionBitwise final : public IAggregateFunctionDataHelperdata(place).update(static_cast &>(*columns[0]).getData()[row_num]); } + // TODO move to helper + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override + { + throw Exception(""); + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).update(this->data(rhs).value); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionCount.h b/dbms/src/AggregateFunctions/AggregateFunctionCount.h index 3bf0c497a80..b93c4bdab03 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionCount.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionCount.h @@ -52,6 +52,11 @@ class AggregateFunctionCount final ++data(place).count; } + void decrease(AggregateDataPtr __restrict place, const IColumn **, size_t, Arena *) const override + { + --data(place).count; + } + void addBatchSinglePlace( size_t start_offset, size_t batch_size, @@ -173,6 +178,11 @@ class AggregateFunctionCountNotNullUnary final data(place).count += !static_cast(*columns[0]).isNullAt(row_num); } + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override + { + data(place).count -= !static_cast(*columns[0]).isNullAt(row_num); + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { data(place).count += data(rhs).count; @@ -234,6 +244,15 @@ class AggregateFunctionCountNotNullVariadic final ++data(place).count; } + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override + { + for (size_t i = 0; i < number_of_arguments; ++i) + if (is_nullable[i] && static_cast(*columns[i]).isNullAt(row_num)) + return; + + --data(place).count; + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { data(place).count += data(rhs).count; diff --git a/dbms/src/AggregateFunctions/AggregateFunctionForEach.h b/dbms/src/AggregateFunctions/AggregateFunctionForEach.h index 4f6bdf4f48e..b074bfe5d75 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionForEach.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionForEach.h @@ -149,13 +149,25 @@ class AggregateFunctionForEach final bool hasTrivialDestructor() const override { return nested_func->hasTrivialDestructor(); } void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + addOrDecrease(place, columns, row_num, arena); + } + + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const override + { + addOrDecrease(place, columns, row_num, arena); + } + + template + void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const { const IColumn * nested[num_arguments]; for (size_t i = 0; i < num_arguments; ++i) nested[i] = &static_cast(*columns[i]).getData(); - const ColumnArray & first_array_column = static_cast(*columns[0]); + const auto & first_array_column = static_cast(*columns[0]); const IColumn::Offsets & offsets = first_array_column.getOffsets(); size_t begin = row_num == 0 ? 0 : offsets[row_num - 1]; @@ -164,7 +176,7 @@ class AggregateFunctionForEach final /// Sanity check. NOTE We can implement specialization for a case with single argument, if the check will hurt performance. for (size_t i = 1; i < num_arguments; ++i) { - const ColumnArray & ith_column = static_cast(*columns[i]); + const auto & ith_column = static_cast(*columns[i]); const IColumn::Offsets & ith_offsets = ith_column.getOffsets(); if (ith_offsets[row_num] != end || (row_num != 0 && ith_offsets[row_num - 1] != begin)) @@ -178,7 +190,10 @@ class AggregateFunctionForEach final char * nested_state = state.array_of_aggregate_datas; for (size_t i = begin; i < end; ++i) { - nested_func->add(nested_state, nested, i, arena); + if constexpr (is_add) + nested_func->add(nested_state, nested, i, arena); + else + nested_func->decrease(nested_state, nested, i, arena); nested_state += nested_size_of_data; } } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h index c80babc0502..815b0baf275 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h @@ -80,6 +80,9 @@ class GroupArrayNumericImpl final this->data(place).value.push_back(static_cast &>(*columns[0]).getData()[row_num], arena); } + // TODO move to helper + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { throw Exception(""); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { auto & cur_elems = this->data(place); @@ -284,6 +287,9 @@ class GroupArrayGeneralListImpl final ++data(place).elems; } + // TODO move to helper + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { throw Exception(""); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { /// It is sadly, but rhs's Arena could be destroyed diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupArrayInsertAt.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupArrayInsertAt.h index 1a71df32e41..2ae55afea21 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupArrayInsertAt.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupArrayInsertAt.h @@ -150,6 +150,12 @@ class AggregateFunctionGroupArrayInsertAtGeneric final columns[0]->get(row_num, arr[position]); } + // TODO move to helper + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override + { + throw Exception(""); + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { Array & arr_lhs = data(place).value; diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h index baee40e62d4..28602c01766 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h @@ -116,19 +116,24 @@ class AggregateFunctionGroupConcat final DataTypePtr getReturnType() const override { return result_is_nullable ? makeNullable(ret_type) : ret_type; } - /// reject nulls before add() of nested agg - void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + /// reject nulls before add()/decrease() of nested agg + template + void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const { if constexpr (only_one_column) { if (is_nullable[0]) { - const ColumnNullable * column = static_cast(columns[0]); + const auto * column = static_cast(columns[0]); if (!column->isNullAt(row_num)) { this->setFlag(place); const IColumn * nested_column = &column->getNestedColumn(); - this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); + + if constexpr (is_add) + this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); + else + this->nested_function->decrease(this->nestedPlace(place), &nested_column, row_num, arena); } return; } @@ -136,12 +141,12 @@ class AggregateFunctionGroupConcat final else { /// remove the row with null, except for sort columns - const ColumnTuple & tuple = static_cast(*columns[0]); + const auto & tuple = static_cast(*columns[0]); for (size_t i = 0; i < number_of_concat_items; ++i) { if (is_nullable[i]) { - const ColumnNullable & nullable_col = static_cast(tuple.getColumn(i)); + const auto & nullable_col = static_cast(tuple.getColumn(i)); if (nullable_col.isNullAt(row_num)) { /// If at least one column has a null value in the current row, @@ -152,7 +157,20 @@ class AggregateFunctionGroupConcat final } } this->setFlag(place); - this->nested_function->add(this->nestedPlace(place), columns, row_num, arena); + if constexpr (is_add) + this->nested_function->add(this->nestedPlace(place), columns, row_num, arena); + else + this->nested_function->decrease(this->nestedPlace(place), columns, row_num, arena); + } + + void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + addOrDecrease(place, columns, row_num, arena); + } + + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + addOrDecrease(place, columns, row_num, arena); } void insertResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h index 06dd57edf66..74d2beb01c6 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h @@ -61,6 +61,12 @@ class AggregateFunctionGroupUniqArray this->data(place).value.insert(assert_cast &>(*columns[0]).getData()[row_num]); } + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override + { + const auto & key = AggregateFunctionGroupUniqArrayData::Set::Cell::getKey(assert_cast &>(*columns[0]).getData()[row_num]); + this->data(place).value.erase(key); + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).value.merge(this->data(rhs).value); @@ -82,7 +88,7 @@ class AggregateFunctionGroupUniqArray void insertResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena *) const override { - ColumnArray & arr_to = assert_cast(to); + auto & arr_to = assert_cast(to); ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); const typename State::Set & set = this->data(place).value; @@ -171,6 +177,15 @@ class AggregateFunctionGroupUniqArrayGeneric set.emplace(key_holder, it, inserted); } + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + auto & set = this->data(place).value; + + auto key_holder = getKeyHolder(*columns[0], row_num, *arena); + auto key = keyHolderGetKey(key_holder); + set.erase(key); + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { auto & cur_set = this->data(place).value; @@ -188,7 +203,7 @@ class AggregateFunctionGroupUniqArrayGeneric void insertResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena *) const override { - ColumnArray & arr_to = assert_cast(to); + auto & arr_to = assert_cast(to); ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); IColumn & data_to = arr_to.getData(); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionIf.h b/dbms/src/AggregateFunctions/AggregateFunctionIf.h index fbd5bd242d8..e8baf4114aa 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionIf.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionIf.h @@ -76,6 +76,13 @@ class AggregateFunctionIf final : public IAggregateFunctionHelperadd(place, columns, row_num, arena); } + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const override + { + if (static_cast(*columns[num_arguments - 1]).getData()[row_num]) + nested_func->decrease(place, columns, row_num, arena); + } + void addBatch( size_t start_offset, size_t batch_size, diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h b/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h index ad029472f2c..36df63b35de 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h @@ -103,10 +103,16 @@ class AggregateFunctionIntersectionsMax final PointType right = static_cast &>(*columns[1]).getData()[row_num]; if (!isNaN(left)) - this->data(place).value.push_back(std::make_pair(left, Int64(1)), arena); + this->data(place).value.push_back(std::make_pair(left, static_cast(1)), arena); if (!isNaN(right)) - this->data(place).value.push_back(std::make_pair(right, Int64(-1)), arena); + this->data(place).value.push_back(std::make_pair(right, static_cast(-1)), arena); + } + + // TODO move to helper + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override + { + throw Exception(""); } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMerge.h b/dbms/src/AggregateFunctions/AggregateFunctionMerge.h index 710443509dc..1f04e87edb0 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMerge.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMerge.h @@ -64,6 +64,12 @@ class AggregateFunctionMerge final : public IAggregateFunctionHelpermerge(place, static_cast(*columns[0]).getData()[row_num], arena); } + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const override + { + nested_func->merge(place, static_cast(*columns[0]).getData()[row_num], arena); + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { nested_func->merge(place, rhs, arena); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.cpp b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.cpp index 16bc94aee1b..34d56c0f4a4 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.cpp @@ -123,14 +123,14 @@ AggregateFunctionPtr createAggregateFunctionArgMax( void registerAggregateFunctionsMinMaxAny(AggregateFunctionFactory & factory) { - factory.registerFunction("any", createAggregateFunctionAny); + factory.registerFunction("any", createAggregateFunctionAny); // TODO no use factory.registerFunction("first_row", createAggregateFunctionFirstRow); - factory.registerFunction("anyLast", createAggregateFunctionAnyLast); - factory.registerFunction("anyHeavy", createAggregateFunctionAnyHeavy); + factory.registerFunction("anyLast", createAggregateFunctionAnyLast); // TODO no use + factory.registerFunction("anyHeavy", createAggregateFunctionAnyHeavy); // TODO no use factory.registerFunction("min", createAggregateFunctionMin, AggregateFunctionFactory::CaseInsensitive); factory.registerFunction("max", createAggregateFunctionMax, AggregateFunctionFactory::CaseInsensitive); - factory.registerFunction("argMin", createAggregateFunctionArgMin); - factory.registerFunction("argMax", createAggregateFunctionArgMax); + factory.registerFunction("argMin", createAggregateFunctionArgMin); // TODO no use + factory.registerFunction("argMax", createAggregateFunctionArgMax); // TODO no use } } // namespace DB diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h index 2f41c931b91..cb04bab9aa2 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h @@ -751,6 +751,12 @@ class AggregateFunctionsSingleValue final this->data(place).changeIfBetter(*columns[0], row_num, arena); } + // TODO implement decrease + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override + { + throw Exception(""); + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { this->data(place).changeIfBetter(this->data(rhs), arena); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionNothing.h b/dbms/src/AggregateFunctions/AggregateFunctionNothing.h index 003a52b8f47..d4d30db7064 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionNothing.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionNothing.h @@ -46,6 +46,8 @@ class AggregateFunctionNothing final : public IAggregateFunctionHelper #include -#include namespace DB { @@ -267,6 +266,17 @@ class AggregateFunctionFirstRowNull size_t alignOfData() const override { return nested_function->alignOfData(); } void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + addOrDecrease(place, columns, row_num, arena); + } + + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + addOrDecrease(place, columns, row_num, arena); + } + + template + void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const { if constexpr (input_is_nullable) { @@ -278,14 +288,20 @@ class AggregateFunctionFirstRowNull if (!is_null) { const IColumn * nested_column = &column->getNestedColumn(); - this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); + if constexpr (is_add) + this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); + else + this->nested_function->decrease(this->nestedPlace(place), &nested_column, row_num, arena); } } } else { this->setFlag(place, 1); - this->nested_function->add(this->nestedPlace(place), columns, row_num, arena); + if constexpr (is_add) + this->nested_function->add(this->nestedPlace(place), columns, row_num, arena); + else + this->nested_function->decrease(this->nestedPlace(place), columns, row_num, arena); } } @@ -380,6 +396,17 @@ class AggregateFunctionNullUnary final {} void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + addOrDecrease(place, columns, row_num, arena); + } + + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + addOrDecrease(place, columns, row_num, arena); + } + + template + void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const { if constexpr (input_is_nullable) { @@ -388,13 +415,19 @@ class AggregateFunctionNullUnary final { this->setFlag(place); const IColumn * nested_column = &column->getNestedColumn(); - this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); + if constexpr (is_add) + this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); + else + this->nested_function->decrease(this->nestedPlace(place), &nested_column, row_num, arena); } } else { this->setFlag(place); - this->nested_function->add(this->nestedPlace(place), columns, row_num, arena); + if constexpr (is_add) + this->nested_function->add(this->nestedPlace(place), columns, row_num, arena); + else + this->nested_function->decrease(this->nestedPlace(place), columns, row_num, arena); } } @@ -469,6 +502,17 @@ class AggregateFunctionNullVariadic final } void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + addOrDecrease(place, columns, row_num, arena); + } + + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + addOrDecrease(place, columns, row_num, arena); + } + + template + void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const { /// This container stores the columns we really pass to the nested function. const IColumn * nested_columns[number_of_arguments]; @@ -491,7 +535,10 @@ class AggregateFunctionNullVariadic final } this->setFlag(place); - this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena); + if constexpr (is_add) + this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena); + else + this->nested_function->decrease(this->nestedPlace(place), nested_columns, row_num, arena); } bool allocatesMemoryInArena() const override { return this->nested_function->allocatesMemoryInArena(); } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionQuantile.h b/dbms/src/AggregateFunctions/AggregateFunctionQuantile.h index 619eeeeab32..751930ce73b 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionQuantile.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionQuantile.h @@ -110,6 +110,9 @@ class AggregateFunctionQuantile final this->data(place).add(static_cast &>(*columns[0]).getData()[row_num]); } + // TODO move to helper + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { throw Exception("");} + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).merge(this->data(rhs)); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h b/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h index 6b2b96d60b2..9c001cd2c89 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h @@ -26,6 +26,7 @@ #include #include #include +#include "Common/Exception.h" namespace DB @@ -170,7 +171,7 @@ class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper(time_arg)) throw Exception{ "Illegal type " + time_arg->getName() + " of first argument of aggregate function " @@ -179,7 +180,7 @@ class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper(cond_arg)) throw Exception{ "Illegal type " + cond_arg->getName() + " of argument " + toString(i + 1) @@ -204,6 +205,12 @@ class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelperdata(place).add(timestamp, events); } + // TODO move to helper + void decrease(AggregateDataPtr __restrict, const IColumn **, const size_t, Arena *) const override + { + throw Exception("Not implemented yet"); + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).merge(this->data(rhs)); @@ -298,7 +305,7 @@ class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelperadd(place, columns, row_num, arena); } + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const override + { + nested_func->decrease(place, columns, row_num, arena); + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { nested_func->merge(place, rhs, arena); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h b/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h index 12d640180ff..93b0780cef2 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h @@ -129,6 +129,9 @@ class AggregateFunctionVariance final this->data(place).update(*columns[0], row_num); } + // TODO move to helper + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { throw Exception(""); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).mergeWith(this->data(rhs)); @@ -377,6 +380,9 @@ class AggregateFunctionCovariance final this->data(place).update(*columns[0], *columns[1], row_num); } + // TODO move to helper + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { throw Exception(""); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).mergeWith(this->data(rhs)); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h b/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h index cf387284d2d..9cc7769d009 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h @@ -221,6 +221,9 @@ class AggregateFunctionVarianceSimple final this->data(place).add(static_cast &>(*columns[0]).getData()[row_num]); } + // TODO move to helper + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { throw Exception(""); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).merge(this->data(rhs)); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSum.h b/dbms/src/AggregateFunctions/AggregateFunctionSum.h index 872d42415ec..9ca48385212 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSum.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSum.h @@ -42,10 +42,27 @@ struct AggregateFunctionSumAddImpl> } }; +template +struct AggregateFunctionSumMinusImpl +{ + static void NO_SANITIZE_UNDEFINED ALWAYS_INLINE decrease(T & lhs, const T & rhs) { lhs -= rhs; } +}; + +template +struct AggregateFunctionSumMinusImpl> +{ + template + static void NO_SANITIZE_UNDEFINED ALWAYS_INLINE decrease(Decimal & lhs, const Decimal & rhs) + { + lhs.value -= static_cast(rhs.value); + } +}; + template struct AggregateFunctionSumData { using Impl = AggregateFunctionSumAddImpl; + using DescreaseImpl = AggregateFunctionSumMinusImpl; T sum{}; AggregateFunctionSumData() = default; @@ -56,6 +73,12 @@ struct AggregateFunctionSumData Impl::add(sum, value); } + template + void NO_SANITIZE_UNDEFINED ALWAYS_INLINE decrease(U value) + { + DescreaseImpl::decrease(sum, value); + } + /// Vectorized version template void NO_SANITIZE_UNDEFINED NO_INLINE addMany(const Value * __restrict ptr, size_t count) @@ -165,6 +188,8 @@ struct AggregateFunctionSumKahanData void ALWAYS_INLINE add(T value) { addImpl(value, sum, compensation); } + void ALWAYS_INLINE decrease(T) { throw Exception("`decrease` function is not implemented in AggregateFunctionSumKahanData"); } + /// Vectorized version template void NO_INLINE addMany(const Value * __restrict ptr, size_t count) @@ -336,6 +361,12 @@ class AggregateFunctionSum final this->data(place).add(column.getData()[row_num]); } + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override + { + const auto & column = assert_cast(*columns[0]); + this->data(place).decrease(column.getData()[row_num]); + } + /// Vectorized version when there is no GROUP BY keys. void addBatchSinglePlace( size_t start_offset, diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h index 8e3ef88f2cc..ba4e7d08db2 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h @@ -134,6 +134,12 @@ class AggregateFunctionSumMap final } } + // TODO move to helper + void decrease(AggregateDataPtr __restrict, const IColumn **, const size_t, Arena *) const override + { + throw Exception(""); + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { auto & merged_maps = this->data(place).merged_maps; diff --git a/dbms/src/AggregateFunctions/AggregateFunctionTopK.h b/dbms/src/AggregateFunctions/AggregateFunctionTopK.h index e1e4b02e6ef..b1dbe46c12b 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionTopK.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionTopK.h @@ -71,6 +71,9 @@ class AggregateFunctionTopK set.insert(static_cast &>(*columns[0]).getData()[row_num]); } + // TODO move to helper + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { throw Exception(""); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).value.merge(this->data(rhs).value); @@ -197,6 +200,8 @@ class AggregateFunctionTopKGeneric } } + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { throw Exception(""); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).value.merge(this->data(rhs).value); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionUniq.h b/dbms/src/AggregateFunctions/AggregateFunctionUniq.h index a54ef4165fa..b3c1c5c6be7 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionUniq.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionUniq.h @@ -438,6 +438,12 @@ class AggregateFunctionUniq final : public IAggregateFunctionDataHelper::add(this->data(place), *columns[0], row_num); } + // TODO move to helper + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override + { + throw Exception(""); + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).set.merge(this->data(rhs).set); @@ -503,6 +509,12 @@ class AggregateFunctionUniqVariadic final UniqVariadicHash::apply(this->data(place), num_args, columns, row_num)); } + // TODO move to helper + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override + { + throw Exception(""); + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).set.merge(this->data(rhs).set); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h b/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h index 371d5eece87..e81c37c8dbf 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h @@ -158,6 +158,11 @@ class AggregateFunctionUniqUpTo final this->data(place).add(*columns[0], row_num, threshold); } + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override + { + throw Exception(""); + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).merge(this->data(rhs), threshold); @@ -226,6 +231,9 @@ class AggregateFunctionUniqUpToVariadic final threshold); } + // TODO move to helper + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { throw Exception(""); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).merge(this->data(rhs), threshold); diff --git a/dbms/src/AggregateFunctions/IAggregateFunction.h b/dbms/src/AggregateFunctions/IAggregateFunction.h index a7927085c3b..38313f1cbe2 100644 --- a/dbms/src/AggregateFunctions/IAggregateFunction.h +++ b/dbms/src/AggregateFunctions/IAggregateFunction.h @@ -96,6 +96,11 @@ class IAggregateFunction virtual void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const = 0; + /// The purpose of this function is the opposite of `add` function + virtual void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const + = 0; + /// Merges state (on which place points to) with other state of current aggregation function. virtual void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const = 0; @@ -507,6 +512,12 @@ class IAggregateFunctionDataHelper /// NOTE: Currently not used (structures with aggregation state are put without alignment). size_t alignOfData() const override { return alignof(Data); } + // TODO uncomment it + // void decrease(AggregateDataPtr __restrict, const IColumn **, const size_t, Arena *) const override + // { + // throw Exception("Not implemented yet"); + // } + void addBatchLookupTable8( size_t start_offset, size_t batch_size, diff --git a/dbms/src/DataStreams/WindowBlockInputStream.cpp b/dbms/src/DataStreams/WindowBlockInputStream.cpp index 3e4ee5b5134..887f3433267 100644 --- a/dbms/src/DataStreams/WindowBlockInputStream.cpp +++ b/dbms/src/DataStreams/WindowBlockInputStream.cpp @@ -246,6 +246,7 @@ WindowTransformAction::WindowTransformAction( const String & req_id) : log(Logger::get(req_id)) , window_description(window_description_) + , has_agg(false) { output_header = input_header; for (const auto & add_column : window_description_.add_columns) @@ -314,6 +315,8 @@ void WindowTransformAction::initialAggregateFunction( if (window_function_description.aggregate_function == nullptr) return; + has_agg = true; + workspace.aggregate_function = window_function_description.aggregate_function; const auto & aggregate_function = workspace.aggregate_function; if (!arena && aggregate_function->allocatesMemoryInArena()) @@ -1357,9 +1360,13 @@ void WindowTransformAction::appendBlock(Block & current_block) window_block.input_columns = current_block.getColumns(); } + // Update the aggregation states after the frame has changed. void WindowTransformAction::updateAggregationState() { + if (!has_agg) + return; + assert(frame_started); assert(frame_ended); assert(frame_start <= frame_end); @@ -1369,76 +1376,15 @@ void WindowTransformAction::updateAggregationState() assert(partition_start <= frame_start); assert(frame_end <= partition_end); - bool reset_aggregation = false; - RowNumber rows_to_add_start; - RowNumber rows_to_add_end; - if (frame_start == prev_frame_start) - { - // The frame start didn't change, add the tail rows. - reset_aggregation = false; - rows_to_add_start = prev_frame_end; - rows_to_add_end = frame_end; - } - else - { - // The frame start changed, reset the state and aggregate over the - // entire frame. This can be made per-function after we learn to - // subtract rows from some types of aggregation states, but for now we - // always have to reset when the frame start changes. - reset_aggregation = true; - rows_to_add_start = frame_start; - rows_to_add_end = frame_end; - } - for (auto & ws : workspaces) { if (ws.window_function) - continue; // No need to do anything for true window functions. - - const auto * agg_func = ws.aggregate_function.get(); - auto * buf = ws.aggregate_function_state.data(); - - if (reset_aggregation) - { - agg_func->destroy(buf); - agg_func->create(buf); - } - - // To achieve better performance, we will have to loop over blocks and - // rows manually, instead of using advanceRowNumber(). - // For this purpose, the past-the-end block can be different than the - // block of the past-the-end row (it's usually the next block). - const auto past_the_end_block = rows_to_add_end.row == 0 ? rows_to_add_end.block : rows_to_add_end.block + 1; - - for (auto block_number = rows_to_add_start.block; block_number < past_the_end_block; ++block_number) - { - auto & block = blockAt(block_number); - - if (ws.cached_block_number != block_number) - { - for (size_t i = 0; i < ws.argument_column_indices.size(); ++i) - { - ws.argument_columns[i] = block.input_columns[ws.argument_column_indices[i]].get(); - } - ws.cached_block_number = block_number; - } - - // First and last blocks may be processed partially, and other blocks - // are processed in full. - const auto first_row = block_number == rows_to_add_start.block ? rows_to_add_start.row : 0; - const auto past_the_end_row = block_number == rows_to_add_end.block ? rows_to_add_end.row : block.rows; - - // TODO Add an addBatch analog that can accept a starting offset. - // For now, add the values one by one. - auto * columns = ws.argument_columns.data(); + continue; - // Removing arena.get() from the loop makes it faster somehow... - auto * arena_ptr = arena.get(); - for (auto row = first_row; row < past_the_end_row; ++row) - { - agg_func->add(buf, columns, row, arena_ptr); - } - } + const RowNumber & end = frame_start <= prev_frame_end ? frame_start : prev_frame_end; + decreaseAggregationState(ws, prev_frame_start, end); + const RowNumber & start = frame_start <= prev_frame_end ? prev_frame_end : frame_start; + addAggregationState(ws, start, end); } } @@ -1500,6 +1446,8 @@ void WindowTransformAction::tryCalculate() assert(frame_started); assert(frame_ended); + updateAggregationState(); + // Write out the results. // TODO execute the window function by block instead of row. writeOutCurrentRow(); diff --git a/dbms/src/DataStreams/WindowBlockInputStream.h b/dbms/src/DataStreams/WindowBlockInputStream.h index d69b9ca98ed..002d4c47915 100644 --- a/dbms/src/DataStreams/WindowBlockInputStream.h +++ b/dbms/src/DataStreams/WindowBlockInputStream.h @@ -168,6 +168,59 @@ struct WindowTransformAction WindowFunctionWorkspace & workspace, const WindowFunctionDescription & window_function_description); + void addAggregationState(WindowFunctionWorkspace & ws, const RowNumber & start, const RowNumber & end) + { + addOrDecreaseAggregationState(ws, start, end); + } + + void decreaseAggregationState(WindowFunctionWorkspace & ws, const RowNumber & start, const RowNumber & end) + { + addOrDecreaseAggregationState(ws, start, end); + } + + template + void addOrDecreaseAggregationState(WindowFunctionWorkspace & ws, const RowNumber & start, const RowNumber & end) + { + if unlikely (start == end) + return; + + const auto * agg_func = ws.aggregate_function.get(); + auto * buf = ws.aggregate_function_state.data(); + + // Used for aggregate function. + // To achieve better performance, we will have to loop over blocks and + // rows manually, instead of using advanceRowNumber(). + // For this purpose, the end block can be different than the + // block of the end row (it's usually the next block). + const auto past_the_end_block = end.row == 0 ? end.block : end.block + 1; + + for (auto block_number = frame_start.block; block_number < past_the_end_block; ++block_number) + { + auto & block = blockAt(block_number); + + if (ws.cached_block_number != block_number) + { + for (size_t i = 0; i < ws.argument_column_indices.size(); ++i) + ws.argument_columns[i] = block.input_columns[ws.argument_column_indices[i]].get(); + ws.cached_block_number = block_number; + } + + // First and last blocks may be processed partially, and other blocks are processed in full. + const auto start_row = block_number == start.block ? start.row : 0; + const auto end_row = block_number == end.block ? end.row : block.rows; + auto * columns = ws.argument_columns.data(); + + auto * arena_ptr = arena.get(); + for (auto row = start_row; row < end_row; ++row) + { + if constexpr (is_add) + agg_func->add(buf, columns, row, arena_ptr); + else + agg_func->decrease(buf, columns, row, arena_ptr); + } + } + } + void updateAggregationState(); void reinitializeAggFuncBeforeNextPartition(); @@ -189,6 +242,8 @@ struct WindowTransformAction // Per-window-function scratch spaces. std::vector workspaces; + bool has_agg; + // A sliding window of blocks we currently need. We add the input blocks as // they arrive, and discard the blocks we don't need anymore. The blocks // have an always-incrementing index. The index of the first block is in diff --git a/dbms/src/Debug/MockExecutor/WindowBinder.cpp b/dbms/src/Debug/MockExecutor/WindowBinder.cpp index bbcb6e8c7ad..844cdf2298b 100644 --- a/dbms/src/Debug/MockExecutor/WindowBinder.cpp +++ b/dbms/src/Debug/MockExecutor/WindowBinder.cpp @@ -95,7 +95,6 @@ void setFieldTypeForAggFunc( { auto * ft = expr->mutable_field_type(); ft->set_tp(TiDB::TypeLongLong); - ft->set_flag(TiDB::ColumnFlagNotNull); } else if (agg_sig == tipb::ExprType::Min || agg_sig == tipb::ExprType::Max) { @@ -105,7 +104,6 @@ void setFieldTypeForAggFunc( auto * ft = expr->mutable_field_type(); ft->set_tp(expr->children(0).field_type().tp()); ft->set_decimal(expr->children(0).field_type().decimal()); - ft->set_flag(expr->children(0).field_type().flag() & (~TiDB::ColumnFlagNotNull)); ft->set_collate(collator_id); } else @@ -219,7 +217,8 @@ bool WindowBinder::toTiPBExecutor( return children[0]->toTiPBExecutor(window->mutable_child(), collator_id, mpp_info, context); } -void setColumnInfoForAgg( +// This function can only be used in window agg +void setColumnInfoForAggInWindow( TiDB::ColumnInfo & ci, const DB::ASTFunction * func, const std::vector & children_ci) @@ -228,7 +227,7 @@ void setColumnInfoForAgg( if (func->name == "count") { ci.tp = TiDB::TypeLongLong; - ci.flag = TiDB::ColumnFlagUnsigned | TiDB::ColumnFlagNotNull; + ci.flag = TiDB::ColumnFlagUnsigned; } else if (func->name == "max" || func->name == "min" || func->name == "sum") { @@ -297,7 +296,7 @@ TiDB::ColumnInfo createColumnInfo(const DB::ASTFunction * func, const std::vecto return ci; } - setColumnInfoForAgg(ci, func, children_ci); + setColumnInfoForAggInWindow(ci, func, children_ci); return ci; } diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp index 83810a680d2..a4808cce43f 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp @@ -963,7 +963,11 @@ WindowDescription DAGExpressionAnalyzer::buildWindowDescription(const tipb::Wind WindowDescription window_description = createAndInitWindowDesc(this, window); setOrderByColumnTypeAndDirectionForRangeFrame(window_description, step.actions, window); - buildActionsBeforeWindow(this, window_description, chain, window); + buildActionsBeforeWindow( + this, + window_description, + chain, + window); buildActionsAfterWindow(this, window_description, chain, window, source_size); return window_description; From b011bfa67c9b0181725ccb3df342b28c86a6daf3 Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Thu, 5 Dec 2024 17:06:54 +0800 Subject: [PATCH 07/32] codes done, need tests --- .../AggregateFunctions/AggregateFunctionAvg.h | 13 + .../AggregateFunctionCount.h | 9 + .../AggregateFunctionGroupArray.h | 3 - .../AggregateFunctionMinMaxAny.cpp | 2 +- .../AggregateFunctionMinMaxAny.h | 229 +++++++++++++++++- .../AggregateFunctions/AggregateFunctionSum.h | 13 + .../AggregateFunctions/IAggregateFunction.h | 29 ++- .../DataStreams/WindowBlockInputStream.cpp | 2 + 8 files changed, 282 insertions(+), 18 deletions(-) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionAvg.h b/dbms/src/AggregateFunctions/AggregateFunctionAvg.h index d4a069d2006..5ff7c6508e4 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionAvg.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionAvg.h @@ -29,6 +29,12 @@ struct AggregateFunctionAvgData T sum; UInt64 count; + void reset() + { + sum = 0; + count = 0; + } + AggregateFunctionAvgData() : sum(0) , count(0) @@ -67,6 +73,8 @@ class AggregateFunctionAvg final return std::make_shared(); } + void prepareWindow(AggregateDataPtr __restrict) const override {} + void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override { if constexpr (IsDecimal) @@ -89,6 +97,11 @@ class AggregateFunctionAvg final --this->data(place).count; } + void reset(AggregateDataPtr __restrict place) const override + { + this->data(place).reset(); + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).sum += this->data(rhs).sum; diff --git a/dbms/src/AggregateFunctions/AggregateFunctionCount.h b/dbms/src/AggregateFunctions/AggregateFunctionCount.h index b93c4bdab03..7d95c88d04b 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionCount.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionCount.h @@ -29,6 +29,8 @@ namespace DB struct AggregateFunctionCountData { UInt64 count = 0; + + void reset() { count = 0; } }; namespace ErrorCodes @@ -235,6 +237,8 @@ class AggregateFunctionCountNotNullVariadic final DataTypePtr getReturnType() const override { return std::make_shared(); } + void prepareWindow(AggregateDataPtr __restrict) const override {} + void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override { for (size_t i = 0; i < number_of_arguments; ++i) @@ -253,6 +257,11 @@ class AggregateFunctionCountNotNullVariadic final --data(place).count; } + void reset(AggregateDataPtr __restrict place) const override + { + this->data(place).reset(); + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { data(place).count += data(rhs).count; diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h index 815b0baf275..c2745d96bb8 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h @@ -287,9 +287,6 @@ class GroupArrayGeneralListImpl final ++data(place).elems; } - // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { throw Exception(""); } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { /// It is sadly, but rhs's Arena could be destroyed diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.cpp b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.cpp index 34d56c0f4a4..84ba1a12853 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.cpp @@ -124,7 +124,7 @@ AggregateFunctionPtr createAggregateFunctionArgMax( void registerAggregateFunctionsMinMaxAny(AggregateFunctionFactory & factory) { factory.registerFunction("any", createAggregateFunctionAny); // TODO no use - factory.registerFunction("first_row", createAggregateFunctionFirstRow); + factory.registerFunction("first_row", createAggregateFunctionFirstRow); // TODO not used in window agg factory.registerFunction("anyLast", createAggregateFunctionAnyLast); // TODO no use factory.registerFunction("anyHeavy", createAggregateFunctionAnyHeavy); // TODO no use factory.registerFunction("min", createAggregateFunctionMin, AggregateFunctionFactory::CaseInsensitive); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h index cb04bab9aa2..2fc17f2d915 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h @@ -23,6 +23,8 @@ #include #include #include +#include +#include namespace DB @@ -31,6 +33,7 @@ namespace DB * For example: min, max, any, anyLast. */ +// TODO maybe we can create a new class to be inherited by SingleValueDataFixed, SingleValueDataString and SingleValueDataGeneric /// For numeric values. template @@ -43,6 +46,9 @@ struct SingleValueDataFixed = false; /// We need to remember if at least one value has been passed. This is necessary for AggregateFunctionIf. T value; + // It's only used in window aggregation + mutable std::unique_ptr> saved_values; + using ColumnType = std::conditional_t, ColumnDecimal, ColumnVector>; public: @@ -72,11 +78,70 @@ struct SingleValueDataFixed readBinary(value, buf); } + void insertMaxResultInto(IColumn & to) + { + insertMinOrMaxResultInto(to); + } + + void insertMinResultInto(IColumn & to) + { + insertMinOrMaxResultInto(to); + } + + template + void insertMinOrMaxResultInto(IColumn & to) + { + if (has()) + { + auto size = saved_values->size(); + value = (*saved_values)[0]; + for (size_t i = 1; i < size; i++) + { + if constexpr (is_min) + { + if ((*saved_values)[i] < value) + value = (*saved_values)[i]; + } + else + { + if (value < (*saved_values)[i]) + value = (*saved_values)[i]; + } + } + static_cast(to).getData().push_back(value); + } + else + { + static_cast(to).insertDefault(); + } + } + + void prepareWindow() + { + saved_values = std::make_unique>(); + } + + void reset() + { + has_value = false; + saved_values->clear(); + } + + // Only used for window aggregation + void decrease() + { + saved_values->pop_front(); + if unlikely (saved_values->empty()) + has_value = false; + } void change(const IColumn & column, size_t row_num, Arena *) { has_value = true; value = static_cast(column).getData()[row_num]; + + if (saved_values) + saved_values->push_back(value); } /// Assuming to.has() @@ -84,6 +149,8 @@ struct SingleValueDataFixed { has_value = true; value = to.value; + if (saved_values) + saved_values->push_back(value); } bool changeFirstTime(const IColumn & column, size_t row_num, Arena * arena) @@ -191,6 +258,10 @@ struct SingleValueDataString char * large_data{}; TiDB::TiDBCollatorPtr collator{}; + // TODO use std::string is inefficient + // It's only used in window aggregation + mutable std::unique_ptr> saved_values; + bool less(const StringRef & a, const StringRef & b) const { if (unlikely(collator == nullptr)) @@ -213,9 +284,9 @@ struct SingleValueDataString } public: - static constexpr Int32 AUTOMATIC_STORAGE_SIZE = 64; + static constexpr Int32 AUTOMATIC_STORAGE_SIZE = 72; static constexpr Int32 MAX_SMALL_STRING_SIZE - = AUTOMATIC_STORAGE_SIZE - sizeof(size) - sizeof(capacity) - sizeof(large_data) - sizeof(collator); + = AUTOMATIC_STORAGE_SIZE - sizeof(size) - sizeof(capacity) - sizeof(large_data) - sizeof(TiDB::TiDBCollatorPtr) - sizeof(std::unique_ptr>); private: char small_data[MAX_SMALL_STRING_SIZE]{}; /// Including the terminating zero. @@ -290,6 +361,65 @@ struct SingleValueDataString } } + void insertMaxResultInto(IColumn & to) + { + insertMinOrMaxResultInto(to); + } + + void insertMinResultInto(IColumn & to) + { + insertMinOrMaxResultInto(to); + } + + template + void insertMinOrMaxResultInto(IColumn & to) + { + if (has()) + { + auto elem_num = saved_values->size(); + StringRef value((*saved_values)[0].c_str(), (*saved_values)[0].size()); + for (size_t i = 1; i < elem_num; i++) + { + String cmp_value((*saved_values)[i].c_str(), (*saved_values)[i].size()); + if constexpr (is_min) + { + if (less(cmp_value, value)) + value = (*saved_values)[i]; + } + else + { + if (less(value, cmp_value)) + value = (*saved_values)[i]; + } + } + + static_cast(to).insertDataWithTerminatingZero(getData(), size); + } + else + { + static_cast(to).insertDefault(); + } + } + + void prepareWindow() + { + saved_values = std::make_unique>(); + } + + void reset() + { + size = -1; + saved_values->clear(); + } + + // Only used for window aggregation + void decrease() + { + saved_values->pop_front(); + if unlikely (saved_values->empty()) + size = -1; + } + /// Assuming to.has() void changeImpl(StringRef value, Arena * arena) { @@ -302,6 +432,9 @@ struct SingleValueDataString if (size > 0) memcpy(small_data, value.data, size); + + if (saved_values) + saved_values->push_back(std::string(small_data, size)); } else { @@ -314,6 +447,9 @@ struct SingleValueDataString size = value_size; memcpy(large_data, value.data, size); + + if (saved_values) + saved_values->push_back(std::string(large_data, size)); } } @@ -432,6 +568,9 @@ struct SingleValueDataGeneric Field value; + // It's only used in window aggregation + std::unique_ptr> saved_values; + public: bool has() const { return !value.isNull(); } @@ -465,9 +604,76 @@ struct SingleValueDataGeneric data_type.deserializeBinary(value, buf); } - void change(const IColumn & column, size_t row_num, Arena *) { column.get(row_num, value); } + void insertMaxResultInto(IColumn & to) + { + insertMinOrMaxResultInto(to); + } - void change(const Self & to, Arena *) { value = to.value; } + void insertMinResultInto(IColumn & to) + { + insertMinOrMaxResultInto(to); + } + + template + void insertMinOrMaxResultInto(IColumn & to) + { + if (has()) + { + auto size = saved_values->size(); + value = (*saved_values)[0]; + for (size_t i = 1; i < size; i++) + { + if constexpr (is_min) + { + if ((*saved_values)[i] < value) + value = (*saved_values)[i]; + } + else + { + if (value < (*saved_values)[i]) + value = (*saved_values)[i]; + } + } + to.insert(value); + } + else + { + to.insertDefault(); + } + } + + void prepareWindow() + { + saved_values = std::make_unique>(); + } + + void reset() + { + value = Field(); + saved_values->clear(); + } + + // Only used for window aggregation + void decrease() + { + saved_values->pop_front(); + if unlikely (saved_values->empty()) + value = Field(); + } + + void change(const IColumn & column, size_t row_num, Arena *) + { + column.get(row_num, value); + if (saved_values) + saved_values->push_back(value); + } + + void change(const Self & to, Arena *) + { + value = to.value; + if (saved_values) + saved_values->push_back(value); + } bool changeFirstTime(const IColumn & column, size_t row_num, Arena * arena) { @@ -751,10 +957,19 @@ class AggregateFunctionsSingleValue final this->data(place).changeIfBetter(*columns[0], row_num, arena); } - // TODO implement decrease - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override + void prepareWindow(AggregateDataPtr __restrict place) const override + { + this->data(place).prepareWindow(); + } + + void decrease(AggregateDataPtr __restrict place, const IColumn **, size_t, Arena *) const override + { + this->data(place).decrease(); + } + + void reset(AggregateDataPtr __restrict place) const override { - throw Exception(""); + this->data(place).reset(); } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSum.h b/dbms/src/AggregateFunctions/AggregateFunctionSum.h index 9ca48385212..fa4003f3443 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSum.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSum.h @@ -79,6 +79,12 @@ struct AggregateFunctionSumData DescreaseImpl::decrease(sum, value); } + template + void NO_SANITIZE_UNDEFINED ALWAYS_INLINE reset() + { + sum = 0; + } + /// Vectorized version template void NO_SANITIZE_UNDEFINED NO_INLINE addMany(const Value * __restrict ptr, size_t count) @@ -355,6 +361,8 @@ class AggregateFunctionSum final new (place) Data; } + void prepareWindow(AggregateDataPtr __restrict) const override {} + void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override { const auto & column = assert_cast(*columns[0]); @@ -367,6 +375,11 @@ class AggregateFunctionSum final this->data(place).decrease(column.getData()[row_num]); } + void reset(AggregateDataPtr __restrict place) const override + { + this->data(place).reset(); + } + /// Vectorized version when there is no GROUP BY keys. void addBatchSinglePlace( size_t start_offset, diff --git a/dbms/src/AggregateFunctions/IAggregateFunction.h b/dbms/src/AggregateFunctions/IAggregateFunction.h index 38313f1cbe2..7609737740d 100644 --- a/dbms/src/AggregateFunctions/IAggregateFunction.h +++ b/dbms/src/AggregateFunctions/IAggregateFunction.h @@ -96,11 +96,17 @@ class IAggregateFunction virtual void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const = 0; - /// The purpose of this function is the opposite of `add` function + /// The purpose of this function is the opposite of `add` function, only used in window function. virtual void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const = 0; + // Some window aggregation functions need to prepare something before execution + virtual void prepareWindow(AggregateDataPtr __restrict) const = 0; + + // Only used in window aggregation functions + virtual void reset(AggregateDataPtr __restrict) const = 0; + /// Merges state (on which place points to) with other state of current aggregation function. virtual void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const = 0; @@ -375,6 +381,21 @@ class IAggregateFunctionHelper : public IAggregateFunction static_cast(this)->add(place + place_offset, columns, i, arena); } } + + void decrease(AggregateDataPtr __restrict, const IColumn **, const size_t, Arena *) const override + { + throw Exception("Not implemented yet"); + } + + void prepareWindow(AggregateDataPtr __restrict) const override + { + throw Exception("Not implemented yet"); + } + + void reset(AggregateDataPtr __restrict) const override + { + throw Exception("Not implemented yet"); + } }; namespace _IAggregateFunctionImpl @@ -512,12 +533,6 @@ class IAggregateFunctionDataHelper /// NOTE: Currently not used (structures with aggregation state are put without alignment). size_t alignOfData() const override { return alignof(Data); } - // TODO uncomment it - // void decrease(AggregateDataPtr __restrict, const IColumn **, const size_t, Arena *) const override - // { - // throw Exception("Not implemented yet"); - // } - void addBatchLookupTable8( size_t start_offset, size_t batch_size, diff --git a/dbms/src/DataStreams/WindowBlockInputStream.cpp b/dbms/src/DataStreams/WindowBlockInputStream.cpp index 887f3433267..de80f7d3a4c 100644 --- a/dbms/src/DataStreams/WindowBlockInputStream.cpp +++ b/dbms/src/DataStreams/WindowBlockInputStream.cpp @@ -1381,6 +1381,8 @@ void WindowTransformAction::updateAggregationState() if (ws.window_function) continue; + // TODO compare the decrease and add number in previous frame + // when decrease > add, we create a new agg data to recalculate from start. const RowNumber & end = frame_start <= prev_frame_end ? frame_start : prev_frame_end; decreaseAggregationState(ws, prev_frame_start, end); const RowNumber & start = frame_start <= prev_frame_end ? prev_frame_end : frame_start; From 8628a2ece1cda836460d5d79bd9d9de18a2c7049 Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Thu, 12 Dec 2024 15:06:18 +0800 Subject: [PATCH 08/32] save --- .../tests/gtest_window_agg.cpp | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp diff --git a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp new file mode 100644 index 00000000000..e8c9dfe04e6 --- /dev/null +++ b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp @@ -0,0 +1,67 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +namespace DB +{ +namespace tests +{ + +class ExecutorWindowAgg : public DB::tests::AggregationTest {}; + +TEST_F(ExecutorWindowAgg, Avg) +try +{ + auto data_type_uint8 = std::make_shared(); + auto context = TiFlashTestEnv::getContext(); + auto agg_func = AggregateFunctionFactory::instance().get(*context, "sum", {data_type_uint8}, {}, 0, true); + // result_type = window_function_description.aggregate_function->getReturnType(); +} +CATCH + +TEST_F(ExecutorWindowAgg, Count) +try +{ + +} +CATCH + +TEST_F(ExecutorWindowAgg, Sum) +try +{ + +} +CATCH + +TEST_F(ExecutorWindowAgg, Min) +try +{ + +} +CATCH + +TEST_F(ExecutorWindowAgg, Max) +try +{ + +} +CATCH + +} // namespace tests +} // namespace DB From db4e3994bd4876ee540537decabcf498e49abbeb Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Fri, 13 Dec 2024 20:10:43 +0800 Subject: [PATCH 09/32] add tests --- .../AggregateFunctionArray.h | 2 + .../AggregateFunctions/AggregateFunctionAvg.h | 7 +- .../AggregateFunctionCount.h | 11 +- .../AggregateFunctionForEach.h | 26 +++- .../AggregateFunctionGroupArray.h | 5 +- .../AggregateFunctionGroupConcat.h | 7 +- .../AggregateFunctionGroupUniqArray.h | 6 +- .../AggregateFunctions/AggregateFunctionIf.h | 2 + .../AggregateFunctionMerge.h | 4 +- .../AggregateFunctionMinMaxAny.h | 74 +++------ .../AggregateFunctionNull.h | 25 +++- .../AggregateFunctionQuantile.h | 5 +- .../AggregateFunctionSequenceMatch.h | 9 +- .../AggregateFunctionState.h | 2 + .../AggregateFunctionStatistics.h | 10 +- .../AggregateFunctionStatisticsSimple.h | 5 +- .../AggregateFunctions/AggregateFunctionSum.h | 17 +-- .../AggregateFunctionTopK.h | 10 +- .../AggregateFunctionUniqUpTo.h | 5 +- .../AggregateFunctions/IAggregateFunction.h | 10 +- .../tests/gtest_window_agg.cpp | 140 +++++++++++++++--- .../Coprocessor/DAGExpressionAnalyzer.cpp | 6 +- 22 files changed, 251 insertions(+), 137 deletions(-) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionArray.h b/dbms/src/AggregateFunctions/AggregateFunctionArray.h index 6439ff91688..9297f7ad543 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionArray.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionArray.h @@ -110,6 +110,8 @@ class AggregateFunctionArray final : public IAggregateFunctionHelperdecrease(place, nested, i, arena); } + void reset(AggregateDataPtr __restrict place) const override { nested_func->reset(place); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { nested_func->merge(place, rhs, arena); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionAvg.h b/dbms/src/AggregateFunctions/AggregateFunctionAvg.h index 5ff7c6508e4..1e13affca86 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionAvg.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionAvg.h @@ -31,7 +31,7 @@ struct AggregateFunctionAvgData void reset() { - sum = 0; + sum = T(0); count = 0; } @@ -97,10 +97,7 @@ class AggregateFunctionAvg final --this->data(place).count; } - void reset(AggregateDataPtr __restrict place) const override - { - this->data(place).reset(); - } + void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { diff --git a/dbms/src/AggregateFunctions/AggregateFunctionCount.h b/dbms/src/AggregateFunctions/AggregateFunctionCount.h index 7d95c88d04b..b474c7b5b49 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionCount.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionCount.h @@ -30,7 +30,7 @@ struct AggregateFunctionCountData { UInt64 count = 0; - void reset() { count = 0; } + inline void reset() noexcept { count = 0; } }; namespace ErrorCodes @@ -59,6 +59,8 @@ class AggregateFunctionCount final --data(place).count; } + void reset(AggregateDataPtr __restrict place) const override { data(place).reset(); } + void addBatchSinglePlace( size_t start_offset, size_t batch_size, @@ -185,6 +187,8 @@ class AggregateFunctionCountNotNullUnary final data(place).count -= !static_cast(*columns[0]).isNullAt(row_num); } + void reset(AggregateDataPtr __restrict place) const override { data(place).reset(); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { data(place).count += data(rhs).count; @@ -257,10 +261,7 @@ class AggregateFunctionCountNotNullVariadic final --data(place).count; } - void reset(AggregateDataPtr __restrict place) const override - { - this->data(place).reset(); - } + void reset(AggregateDataPtr __restrict place) const override { data(place).reset(); } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { diff --git a/dbms/src/AggregateFunctions/AggregateFunctionForEach.h b/dbms/src/AggregateFunctions/AggregateFunctionForEach.h index b074bfe5d75..e718077e2ea 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionForEach.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionForEach.h @@ -66,7 +66,7 @@ class AggregateFunctionForEach final AggregateFunctionForEachData & ensureAggregateData( AggregateDataPtr __restrict place, size_t new_size, - Arena & arena) const + Arena * arena) const { AggregateFunctionForEachData & state = data(place); @@ -75,7 +75,10 @@ class AggregateFunctionForEach final size_t old_size = state.dynamic_array_size; if (old_size < new_size) { - state.array_of_aggregate_datas = arena.realloc( + if unlikely (arena == nullptr) + throw Exception("Get nullptr in ensureAggregateData"); + + state.array_of_aggregate_datas = arena->realloc( state.array_of_aggregate_datas, old_size * nested_size_of_data, new_size * nested_size_of_data); @@ -185,7 +188,7 @@ class AggregateFunctionForEach final ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH); } - AggregateFunctionForEachData & state = ensureAggregateData(place, end - begin, *arena); + AggregateFunctionForEachData & state = ensureAggregateData(place, end - begin, arena); char * nested_state = state.array_of_aggregate_datas; for (size_t i = begin; i < end; ++i) @@ -198,10 +201,21 @@ class AggregateFunctionForEach final } } + void reset(AggregateDataPtr __restrict place) const override + { + AggregateFunctionForEachData & state = ensureAggregateData(place, 0, nullptr); + char * nested_state = state.array_of_aggregate_datas; + for (size_t i = 0; i < state.dynamic_array_size; i++) + { + nested_func->reset(nested_state); + nested_state += nested_size_of_data; + } + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { const AggregateFunctionForEachData & rhs_state = data(rhs); - AggregateFunctionForEachData & state = ensureAggregateData(place, rhs_state.dynamic_array_size, *arena); + AggregateFunctionForEachData & state = ensureAggregateData(place, rhs_state.dynamic_array_size, arena); const char * rhs_nested_state = rhs_state.array_of_aggregate_datas; char * nested_state = state.array_of_aggregate_datas; @@ -235,7 +249,7 @@ class AggregateFunctionForEach final size_t new_size = 0; readBinary(new_size, buf); - ensureAggregateData(place, new_size, *arena); + ensureAggregateData(place, new_size, arena); char * nested_state = state.array_of_aggregate_datas; for (size_t i = 0; i < new_size; ++i) @@ -249,7 +263,7 @@ class AggregateFunctionForEach final { const AggregateFunctionForEachData & state = data(place); - ColumnArray & arr_to = static_cast(to); + auto & arr_to = static_cast(to); ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); IColumn & elems_to = arr_to.getData(); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h index c2745d96bb8..7910783e143 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h @@ -81,7 +81,10 @@ class GroupArrayNumericImpl final } // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { throw Exception(""); } + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override + { + throw Exception(""); + } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h index 28602c01766..0dafef50a57 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h @@ -117,7 +117,7 @@ class AggregateFunctionGroupConcat final DataTypePtr getReturnType() const override { return result_is_nullable ? makeNullable(ret_type) : ret_type; } /// reject nulls before add()/decrease() of nested agg - template + template void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const { if constexpr (only_one_column) @@ -162,13 +162,14 @@ class AggregateFunctionGroupConcat final else this->nested_function->decrease(this->nestedPlace(place), columns, row_num, arena); } - + void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override { addOrDecrease(place, columns, row_num, arena); } - void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const override { addOrDecrease(place, columns, row_num, arena); } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h index 74d2beb01c6..872176d3f08 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h @@ -63,7 +63,8 @@ class AggregateFunctionGroupUniqArray void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override { - const auto & key = AggregateFunctionGroupUniqArrayData::Set::Cell::getKey(assert_cast &>(*columns[0]).getData()[row_num]); + const auto & key = AggregateFunctionGroupUniqArrayData::Set::Cell::getKey( + assert_cast &>(*columns[0]).getData()[row_num]); this->data(place).value.erase(key); } @@ -177,7 +178,8 @@ class AggregateFunctionGroupUniqArrayGeneric set.emplace(key_holder, it, inserted); } - void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const override { auto & set = this->data(place).value; diff --git a/dbms/src/AggregateFunctions/AggregateFunctionIf.h b/dbms/src/AggregateFunctions/AggregateFunctionIf.h index e8baf4114aa..45bf8092c21 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionIf.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionIf.h @@ -83,6 +83,8 @@ class AggregateFunctionIf final : public IAggregateFunctionHelperdecrease(place, columns, row_num, arena); } + void reset(AggregateDataPtr __restrict place) const override { nested_func->reset(place); } + void addBatch( size_t start_offset, size_t batch_size, diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMerge.h b/dbms/src/AggregateFunctions/AggregateFunctionMerge.h index 1f04e87edb0..32c87230f01 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMerge.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMerge.h @@ -37,7 +37,7 @@ class AggregateFunctionMerge final : public IAggregateFunctionHelper(&argument); + const auto * data_type = typeid_cast(&argument); if (!data_type || data_type->getFunctionName() != nested_func->getName()) throw Exception( @@ -70,6 +70,8 @@ class AggregateFunctionMerge final : public IAggregateFunctionHelpermerge(place, static_cast(*columns[0]).getData()[row_num], arena); } + void reset(AggregateDataPtr __restrict place) const override { nested_func->reset(place); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { nested_func->merge(place, rhs, arena); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h index 2fc17f2d915..63c39ac9613 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h @@ -23,6 +23,7 @@ #include #include #include + #include #include @@ -78,17 +79,11 @@ struct SingleValueDataFixed readBinary(value, buf); } - void insertMaxResultInto(IColumn & to) - { - insertMinOrMaxResultInto(to); - } + void insertMaxResultInto(IColumn & to) { insertMinOrMaxResultInto(to); } - void insertMinResultInto(IColumn & to) - { - insertMinOrMaxResultInto(to); - } + void insertMinResultInto(IColumn & to) { insertMinOrMaxResultInto(to); } - template + template void insertMinOrMaxResultInto(IColumn & to) { if (has()) @@ -102,7 +97,7 @@ struct SingleValueDataFixed if ((*saved_values)[i] < value) value = (*saved_values)[i]; } - else + else { if (value < (*saved_values)[i]) value = (*saved_values)[i]; @@ -116,10 +111,7 @@ struct SingleValueDataFixed } } - void prepareWindow() - { - saved_values = std::make_unique>(); - } + void prepareWindow() { saved_values = std::make_unique>(); } void reset() { @@ -285,8 +277,8 @@ struct SingleValueDataString public: static constexpr Int32 AUTOMATIC_STORAGE_SIZE = 72; - static constexpr Int32 MAX_SMALL_STRING_SIZE - = AUTOMATIC_STORAGE_SIZE - sizeof(size) - sizeof(capacity) - sizeof(large_data) - sizeof(TiDB::TiDBCollatorPtr) - sizeof(std::unique_ptr>); + static constexpr Int32 MAX_SMALL_STRING_SIZE = AUTOMATIC_STORAGE_SIZE - sizeof(size) - sizeof(capacity) + - sizeof(large_data) - sizeof(TiDB::TiDBCollatorPtr) - sizeof(std::unique_ptr>); private: char small_data[MAX_SMALL_STRING_SIZE]{}; /// Including the terminating zero. @@ -361,17 +353,11 @@ struct SingleValueDataString } } - void insertMaxResultInto(IColumn & to) - { - insertMinOrMaxResultInto(to); - } + void insertMaxResultInto(IColumn & to) { insertMinOrMaxResultInto(to); } - void insertMinResultInto(IColumn & to) - { - insertMinOrMaxResultInto(to); - } + void insertMinResultInto(IColumn & to) { insertMinOrMaxResultInto(to); } - template + template void insertMinOrMaxResultInto(IColumn & to) { if (has()) @@ -386,7 +372,7 @@ struct SingleValueDataString if (less(cmp_value, value)) value = (*saved_values)[i]; } - else + else { if (less(value, cmp_value)) value = (*saved_values)[i]; @@ -401,10 +387,7 @@ struct SingleValueDataString } } - void prepareWindow() - { - saved_values = std::make_unique>(); - } + void prepareWindow() { saved_values = std::make_unique>(); } void reset() { @@ -604,17 +587,11 @@ struct SingleValueDataGeneric data_type.deserializeBinary(value, buf); } - void insertMaxResultInto(IColumn & to) - { - insertMinOrMaxResultInto(to); - } + void insertMaxResultInto(IColumn & to) { insertMinOrMaxResultInto(to); } - void insertMinResultInto(IColumn & to) - { - insertMinOrMaxResultInto(to); - } + void insertMinResultInto(IColumn & to) { insertMinOrMaxResultInto(to); } - template + template void insertMinOrMaxResultInto(IColumn & to) { if (has()) @@ -628,7 +605,7 @@ struct SingleValueDataGeneric if ((*saved_values)[i] < value) value = (*saved_values)[i]; } - else + else { if (value < (*saved_values)[i]) value = (*saved_values)[i]; @@ -642,11 +619,8 @@ struct SingleValueDataGeneric } } - void prepareWindow() - { - saved_values = std::make_unique>(); - } - + void prepareWindow() { saved_values = std::make_unique>(); } + void reset() { value = Field(); @@ -957,20 +931,14 @@ class AggregateFunctionsSingleValue final this->data(place).changeIfBetter(*columns[0], row_num, arena); } - void prepareWindow(AggregateDataPtr __restrict place) const override - { - this->data(place).prepareWindow(); - } + void prepareWindow(AggregateDataPtr __restrict place) const override { this->data(place).prepareWindow(); } void decrease(AggregateDataPtr __restrict place, const IColumn **, size_t, Arena *) const override { this->data(place).decrease(); } - void reset(AggregateDataPtr __restrict place) const override - { - this->data(place).reset(); - } + void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { diff --git a/dbms/src/AggregateFunctions/AggregateFunctionNull.h b/dbms/src/AggregateFunctions/AggregateFunctionNull.h index d0a824bc030..1beadc9634d 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionNull.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionNull.h @@ -270,12 +270,13 @@ class AggregateFunctionFirstRowNull addOrDecrease(place, columns, row_num, arena); } - void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const override { addOrDecrease(place, columns, row_num, arena); } - template + template void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const { if constexpr (input_is_nullable) @@ -400,12 +401,13 @@ class AggregateFunctionNullUnary final addOrDecrease(place, columns, row_num, arena); } - void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const override { addOrDecrease(place, columns, row_num, arena); } - template + template void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const { if constexpr (input_is_nullable) @@ -431,6 +433,11 @@ class AggregateFunctionNullUnary final } } + void reset(AggregateDataPtr __restrict place) const override + { + this->nested_function->reset(this->nestedPlace(place)); + } + void addBatchSinglePlace( // NOLINT(google-default-arguments) size_t start_offset, size_t batch_size, @@ -506,12 +513,13 @@ class AggregateFunctionNullVariadic final addOrDecrease(place, columns, row_num, arena); } - void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const override { addOrDecrease(place, columns, row_num, arena); } - template + template void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const { /// This container stores the columns we really pass to the nested function. @@ -541,6 +549,11 @@ class AggregateFunctionNullVariadic final this->nested_function->decrease(this->nestedPlace(place), nested_columns, row_num, arena); } + void reset(AggregateDataPtr __restrict place) const override + { + this->nested_function->reset(this->nestedPlace(place)); + } + bool allocatesMemoryInArena() const override { return this->nested_function->allocatesMemoryInArena(); } private: diff --git a/dbms/src/AggregateFunctions/AggregateFunctionQuantile.h b/dbms/src/AggregateFunctions/AggregateFunctionQuantile.h index 751930ce73b..cf0aea645f2 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionQuantile.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionQuantile.h @@ -111,7 +111,10 @@ class AggregateFunctionQuantile final } // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { throw Exception("");} + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override + { + throw Exception(""); + } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h b/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h index 9c001cd2c89..bbd8e9fe686 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h @@ -26,6 +26,7 @@ #include #include #include + #include "Common/Exception.h" @@ -171,7 +172,7 @@ class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper(time_arg)) throw Exception{ "Illegal type " + time_arg->getName() + " of first argument of aggregate function " @@ -180,7 +181,7 @@ class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper(cond_arg)) throw Exception{ "Illegal type " + cond_arg->getName() + " of argument " + toString(i + 1) @@ -305,7 +306,7 @@ class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelperdecrease(place, columns, row_num, arena); } + void reset(AggregateDataPtr __restrict place) const override { nested_func->reset(place); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { nested_func->merge(place, rhs, arena); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h b/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h index 93b0780cef2..87610e3e7eb 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h @@ -130,7 +130,10 @@ class AggregateFunctionVariance final } // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { throw Exception(""); } + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override + { + throw Exception(""); + } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { @@ -381,7 +384,10 @@ class AggregateFunctionCovariance final } // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { throw Exception(""); } + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override + { + throw Exception(""); + } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { diff --git a/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h b/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h index 9cc7769d009..8db4e6c3133 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h @@ -222,7 +222,10 @@ class AggregateFunctionVarianceSimple final } // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { throw Exception(""); } + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override + { + throw Exception(""); + } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSum.h b/dbms/src/AggregateFunctions/AggregateFunctionSum.h index fa4003f3443..4025153008e 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSum.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSum.h @@ -79,11 +79,7 @@ struct AggregateFunctionSumData DescreaseImpl::decrease(sum, value); } - template - void NO_SANITIZE_UNDEFINED ALWAYS_INLINE reset() - { - sum = 0; - } + void NO_SANITIZE_UNDEFINED ALWAYS_INLINE reset() { sum = T(0); } /// Vectorized version template @@ -194,7 +190,9 @@ struct AggregateFunctionSumKahanData void ALWAYS_INLINE add(T value) { addImpl(value, sum, compensation); } - void ALWAYS_INLINE decrease(T) { throw Exception("`decrease` function is not implemented in AggregateFunctionSumKahanData"); } + void ALWAYS_INLINE decrease(T) { throw Exception("Not implemented yet"); } + + void ALWAYS_INLINE reset() { throw Exception("Not implemented yet"); } /// Vectorized version template @@ -342,7 +340,7 @@ class AggregateFunctionSum final AggregateFunctionSum(PrecType prec, ScaleType scale) { std::tie(result_prec, result_scale) = Name::decimalInfer(prec, scale); - }; + } DataTypePtr getReturnType() const override { @@ -375,10 +373,7 @@ class AggregateFunctionSum final this->data(place).decrease(column.getData()[row_num]); } - void reset(AggregateDataPtr __restrict place) const override - { - this->data(place).reset(); - } + void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } /// Vectorized version when there is no GROUP BY keys. void addBatchSinglePlace( diff --git a/dbms/src/AggregateFunctions/AggregateFunctionTopK.h b/dbms/src/AggregateFunctions/AggregateFunctionTopK.h index b1dbe46c12b..cd6ad86e08f 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionTopK.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionTopK.h @@ -72,7 +72,10 @@ class AggregateFunctionTopK } // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { throw Exception(""); } + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override + { + throw Exception(""); + } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { @@ -200,7 +203,10 @@ class AggregateFunctionTopKGeneric } } - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { throw Exception(""); } + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override + { + throw Exception(""); + } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { diff --git a/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h b/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h index e81c37c8dbf..6fa62a366be 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h @@ -232,7 +232,10 @@ class AggregateFunctionUniqUpToVariadic final } // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { throw Exception(""); } + void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override + { + throw Exception(""); + } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { diff --git a/dbms/src/AggregateFunctions/IAggregateFunction.h b/dbms/src/AggregateFunctions/IAggregateFunction.h index 7609737740d..2f70766c495 100644 --- a/dbms/src/AggregateFunctions/IAggregateFunction.h +++ b/dbms/src/AggregateFunctions/IAggregateFunction.h @@ -387,15 +387,9 @@ class IAggregateFunctionHelper : public IAggregateFunction throw Exception("Not implemented yet"); } - void prepareWindow(AggregateDataPtr __restrict) const override - { - throw Exception("Not implemented yet"); - } + void prepareWindow(AggregateDataPtr __restrict) const override { throw Exception("Not implemented yet"); } - void reset(AggregateDataPtr __restrict) const override - { - throw Exception("Not implemented yet"); - } + void reset(AggregateDataPtr __restrict) const override { throw Exception("Not implemented yet"); } }; namespace _IAggregateFunctionImpl diff --git a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp index e8c9dfe04e6..f13579a4ded 100644 --- a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp +++ b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp @@ -12,55 +12,155 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include +#include +#include +#include +#include +#include +#include #include #include #include +#include +#include +#include namespace DB { namespace tests { +class ExecutorWindowAgg : public DB::tests::AggregationTest +{ +public: + void SetUp() override { dre = std::default_random_engine(r()); } -class ExecutorWindowAgg : public DB::tests::AggregationTest {}; +private: + // range: [begin, end] + inline UInt32 rand(UInt32 begin, UInt32 end) noexcept + { + di.param(std::uniform_int_distribution::param_type{begin, end}); + return di(dre); + } -TEST_F(ExecutorWindowAgg, Avg) +protected: + inline UInt32 getResetNum() noexcept { return rand(1, 10); } + inline UInt32 getAddNum() noexcept { return rand(1, 5); } + inline UInt32 getDecreaseNum(UInt32 max_num) noexcept { return rand(0, max_num); } + inline UInt32 getRowIdx(UInt32 start, UInt32 end) noexcept { return rand(start, end); } + inline UInt32 getResultNum() noexcept { return rand(1, 3); } + + std::random_device r; + std::default_random_engine dre; + std::uniform_int_distribution di; + + static const ColumnPtr input_int_col; + static const ColumnPtr input_decimal64_col; + static const ColumnPtr input_decimal256_col; + + static DataTypePtr type_int; + static DataTypePtr type_decimal64; + static DataTypePtr type_decimal256; +}; + +const std::vector input_int_vec{1, -2, 7, 4, 0, -3, -1, 0, 0, 9, 2, 0, -4, 2, 6, -3, 5}; +const std::vector input_decimal_vec{"0.12", "0", "1.11", "0", "-0.23", "0", "0", "-0.98", "1.21"}; + +const ColumnPtr ExecutorWindowAgg::input_int_col = createColumn(input_int_vec).column; +const ColumnPtr ExecutorWindowAgg::input_decimal64_col + = createColumn(std::make_tuple(9, 4), input_decimal_vec).column; +const ColumnPtr ExecutorWindowAgg::input_decimal256_col + = createColumn(std::make_tuple(9, 4), input_decimal_vec).column; + +DataTypePtr ExecutorWindowAgg::type_int = std::make_shared(); +DataTypePtr ExecutorWindowAgg::type_decimal64 = std::make_shared(); +DataTypePtr ExecutorWindowAgg::type_decimal256 = std::make_shared(); + +TEST_F(ExecutorWindowAgg, Sum) try { - auto data_type_uint8 = std::make_shared(); + Arena arena; auto context = TiFlashTestEnv::getContext(); - auto agg_func = AggregateFunctionFactory::instance().get(*context, "sum", {data_type_uint8}, {}, 0, true); - // result_type = window_function_description.aggregate_function->getReturnType(); + auto agg_func + = AggregateFunctionFactory::instance().get(*context, "sum", {ExecutorWindowAgg::type_int}, {}, 0, true); + AlignedBuffer agg_state; + agg_state.reset(agg_func->sizeOfData(), agg_func->alignOfData()); + agg_func->create(agg_state.data()); + + std::deque added_row_idx_queue; + std::vector res_int; + res_int.reserve(10); + + const UInt32 col_int_size = input_int_vec.size(); + + UInt32 reset_num = getResetNum(); + auto res_col = ColumnInt64::create(); + auto null_map_col = ColumnUInt8::create(); + auto null_res_col = ColumnNullable::create(std::move(res_col), std::move(null_map_col)); + const IColumn * int_col = &(*ExecutorWindowAgg::input_int_col); + + { + // Test for int + for (UInt32 i = 0; i < reset_num; i++) + { + Int64 res = 0; + agg_func->reset(agg_state.data()); + + const UInt32 res_num = getResultNum(); + for (UInt32 j = 0; j < res_num; j++) + { + const UInt32 add_num = getAddNum(); + for (UInt32 k = 0; k < add_num; k++) + { + const UInt32 row_idx = getRowIdx(0, col_int_size - 1); + added_row_idx_queue.push_back(row_idx); + agg_func->add(agg_state.data(), &int_col, row_idx, &arena); + res += input_int_vec[row_idx]; + } + + const UInt32 decrease_num = getDecreaseNum(add_num); // todo change to length of deque + for (UInt32 k = 0; k < decrease_num; k++) + { + const UInt32 row_idx = added_row_idx_queue.front(); + added_row_idx_queue.pop_front(); + agg_func->decrease(agg_state.data(), &int_col, row_idx, &arena); + res -= input_int_vec[row_idx]; + } + + agg_func->insertResultInto(agg_state.data(), *null_res_col, &arena); + res_int.push_back(res); + } + } + + const auto nested_res_col = null_res_col->getNestedColumnPtr(); + size_t res_num = res_int.size(); + ASSERT_EQ(res_num, null_res_col->size()); + for (size_t i = 0; i < res_num; i++) + { + ASSERT_FALSE(null_res_col->isNullAt(i)); + ASSERT_EQ(res_int[i], nested_res_col->getInt(i)); + } + } } CATCH TEST_F(ExecutorWindowAgg, Count) try -{ - -} +{} CATCH -TEST_F(ExecutorWindowAgg, Sum) +TEST_F(ExecutorWindowAgg, Avg) try -{ - -} +{} CATCH TEST_F(ExecutorWindowAgg, Min) try -{ - -} +{} CATCH TEST_F(ExecutorWindowAgg, Max) try -{ - -} +{} CATCH } // namespace tests diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp index a4808cce43f..83810a680d2 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp @@ -963,11 +963,7 @@ WindowDescription DAGExpressionAnalyzer::buildWindowDescription(const tipb::Wind WindowDescription window_description = createAndInitWindowDesc(this, window); setOrderByColumnTypeAndDirectionForRangeFrame(window_description, step.actions, window); - buildActionsBeforeWindow( - this, - window_description, - chain, - window); + buildActionsBeforeWindow(this, window_description, chain, window); buildActionsAfterWindow(this, window_description, chain, window, source_size); return window_description; From a9360da9cdacd1e5ff3d5b539499b4ac8195bfc6 Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Tue, 17 Dec 2024 10:53:19 +0800 Subject: [PATCH 10/32] add sum tests --- .../tests/gtest_window_agg.cpp | 166 +++++++++++++++--- 1 file changed, 137 insertions(+), 29 deletions(-) diff --git a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp index f13579a4ded..9022ceb943c 100644 --- a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp +++ b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -29,6 +30,20 @@ namespace DB { namespace tests { +constexpr int scale = 2; +const std::vector input_string_vec{"0", "71.94", "12.34", "-34.26", "80.02", "-84.39", "28.41", "45.32"}; +const std::vector input_string_vec_aux{"0", "7194", "1234", "-3426", "8002", "-8439", "2841", "4532"}; +const std::vector input_int_vec{1, -2, 7, 4, 0, -3, -1, 0, 0, 9, 2, 0, -4, 2, 6, -3, 5}; +const std::vector input_decimal_vec{ + Decimal256(std::stoi(input_string_vec_aux[0])), + Decimal256(std::stoi(input_string_vec_aux[1])), + Decimal256(std::stoi(input_string_vec_aux[2])), + Decimal256(std::stoi(input_string_vec_aux[3])), + Decimal256(std::stoi(input_string_vec_aux[4])), + Decimal256(std::stoi(input_string_vec_aux[5])), + Decimal256(std::stoi(input_string_vec_aux[6])), + Decimal256(std::stoi(input_string_vec_aux[7]))}; + class ExecutorWindowAgg : public DB::tests::AggregationTest { public: @@ -43,6 +58,33 @@ class ExecutorWindowAgg : public DB::tests::AggregationTest } protected: + static const IColumn * getInputColumn(const IDataType * type) + { + if (const auto * tmp = dynamic_cast(type); tmp != nullptr) + return &(*ExecutorWindowAgg::input_int_col); + else if (const auto * tmp = dynamic_cast(type); tmp != nullptr) + return &(*ExecutorWindowAgg::input_decimal128_col); + else if (const auto * tmp = dynamic_cast(type); tmp != nullptr) + return &(*ExecutorWindowAgg::input_decimal256_col); + else + throw Exception("Invalid data type"); + } + + static String getValue(const Field & field) + { + switch (field.getType()) + { + case Field::Types::Which::Int64: + return std::to_string(field.template get()); + case Field::Types::Which::Decimal128: + return field.template get>().toString(); + case Field::Types::Which::Decimal256: + return field.template get>().toString(); + default: + throw Exception("Invalid data type"); + } + } + inline UInt32 getResetNum() noexcept { return rand(1, 10); } inline UInt32 getAddNum() noexcept { return rand(1, 5); } inline UInt32 getDecreaseNum(UInt32 max_num) noexcept { return rand(0, max_num); } @@ -54,52 +96,51 @@ class ExecutorWindowAgg : public DB::tests::AggregationTest std::uniform_int_distribution di; static const ColumnPtr input_int_col; - static const ColumnPtr input_decimal64_col; + static const ColumnPtr input_decimal128_col; static const ColumnPtr input_decimal256_col; static DataTypePtr type_int; - static DataTypePtr type_decimal64; + static DataTypePtr type_decimal128; static DataTypePtr type_decimal256; }; -const std::vector input_int_vec{1, -2, 7, 4, 0, -3, -1, 0, 0, 9, 2, 0, -4, 2, 6, -3, 5}; -const std::vector input_decimal_vec{"0.12", "0", "1.11", "0", "-0.23", "0", "0", "-0.98", "1.21"}; - const ColumnPtr ExecutorWindowAgg::input_int_col = createColumn(input_int_vec).column; -const ColumnPtr ExecutorWindowAgg::input_decimal64_col - = createColumn(std::make_tuple(9, 4), input_decimal_vec).column; +const ColumnPtr ExecutorWindowAgg::input_decimal128_col + = createColumn(std::make_tuple(10, scale), input_string_vec).column; const ColumnPtr ExecutorWindowAgg::input_decimal256_col - = createColumn(std::make_tuple(9, 4), input_decimal_vec).column; + = createColumn(std::make_tuple(30, scale), input_string_vec).column; DataTypePtr ExecutorWindowAgg::type_int = std::make_shared(); -DataTypePtr ExecutorWindowAgg::type_decimal64 = std::make_shared(); -DataTypePtr ExecutorWindowAgg::type_decimal256 = std::make_shared(); +DataTypePtr ExecutorWindowAgg::type_decimal128 = std::make_shared(10, scale); +DataTypePtr ExecutorWindowAgg::type_decimal256 = std::make_shared(30, scale); TEST_F(ExecutorWindowAgg, Sum) try { Arena arena; auto context = TiFlashTestEnv::getContext(); - auto agg_func - = AggregateFunctionFactory::instance().get(*context, "sum", {ExecutorWindowAgg::type_int}, {}, 0, true); - AlignedBuffer agg_state; - agg_state.reset(agg_func->sizeOfData(), agg_func->alignOfData()); - agg_func->create(agg_state.data()); - std::deque added_row_idx_queue; - std::vector res_int; - res_int.reserve(10); - - const UInt32 col_int_size = input_int_vec.size(); - - UInt32 reset_num = getResetNum(); - auto res_col = ColumnInt64::create(); - auto null_map_col = ColumnUInt8::create(); - auto null_res_col = ColumnNullable::create(std::move(res_col), std::move(null_map_col)); - const IColumn * int_col = &(*ExecutorWindowAgg::input_int_col); { // Test for int + added_row_idx_queue.clear(); + auto agg_func + = AggregateFunctionFactory::instance().get(*context, "sum", {ExecutorWindowAgg::type_int}, {}, 0, true); + AlignedBuffer agg_state; + agg_state.reset(agg_func->sizeOfData(), agg_func->alignOfData()); + agg_func->create(agg_state.data()); + + std::vector res_int; + res_int.reserve(10); + + const UInt32 col_int_size = input_int_vec.size(); + + UInt32 reset_num = getResetNum(); + auto res_col = ExecutorWindowAgg::type_int->createColumn(); + auto null_map_col = ColumnUInt8::create(); + auto null_res_col = ColumnNullable::create(std::move(res_col), std::move(null_map_col)); + const IColumn * input_col = getInputColumn(ExecutorWindowAgg::type_int.get()); + for (UInt32 i = 0; i < reset_num; i++) { Int64 res = 0; @@ -113,16 +154,16 @@ try { const UInt32 row_idx = getRowIdx(0, col_int_size - 1); added_row_idx_queue.push_back(row_idx); - agg_func->add(agg_state.data(), &int_col, row_idx, &arena); + agg_func->add(agg_state.data(), &input_col, row_idx, &arena); res += input_int_vec[row_idx]; } - const UInt32 decrease_num = getDecreaseNum(add_num); // todo change to length of deque + const UInt32 decrease_num = getDecreaseNum(added_row_idx_queue.size()); for (UInt32 k = 0; k < decrease_num; k++) { const UInt32 row_idx = added_row_idx_queue.front(); added_row_idx_queue.pop_front(); - agg_func->decrease(agg_state.data(), &int_col, row_idx, &arena); + agg_func->decrease(agg_state.data(), &input_col, row_idx, &arena); res -= input_int_vec[row_idx]; } @@ -140,6 +181,73 @@ try ASSERT_EQ(res_int[i], nested_res_col->getInt(i)); } } + + { + // Test for decimal + DataTypes test_types{type_decimal128, type_decimal256}; + for (const auto & type : test_types) + { + added_row_idx_queue.clear(); + auto agg_func = AggregateFunctionFactory::instance().get(*context, "sum", {type}, {}, 0, true); + AlignedBuffer agg_state; + agg_state.reset(agg_func->sizeOfData(), agg_func->alignOfData()); + agg_func->create(agg_state.data()); + + std::vector res_vec; + res_vec.reserve(10); + + const UInt32 col_decimal_size = input_decimal_vec.size(); + + UInt32 reset_num = getResetNum(); + auto res_col = type->createColumn(); + auto null_map_col = ColumnUInt8::create(); + auto null_res_col = ColumnNullable::create(std::move(res_col), std::move(null_map_col)); + const IColumn * input_col = getInputColumn(type.get()); + + for (UInt32 i = 0; i < reset_num; i++) + { + Decimal256 res(0); + agg_func->reset(agg_state.data()); + + const UInt32 res_num = getResultNum(); + for (UInt32 j = 0; j < res_num; j++) + { + const UInt32 add_num = getAddNum(); + for (UInt32 k = 0; k < add_num; k++) + { + const UInt32 row_idx = getRowIdx(0, col_decimal_size - 1); + added_row_idx_queue.push_back(row_idx); + agg_func->add(agg_state.data(), &input_col, row_idx, &arena); + res += input_decimal_vec[row_idx]; + } + + const UInt32 decrease_num = getDecreaseNum(added_row_idx_queue.size()); + for (UInt32 k = 0; k < decrease_num; k++) + { + const UInt32 row_idx = added_row_idx_queue.front(); + added_row_idx_queue.pop_front(); + agg_func->decrease(agg_state.data(), &input_col, row_idx, &arena); + res -= input_decimal_vec[row_idx]; + } + + agg_func->insertResultInto(agg_state.data(), *null_res_col, &arena); + res_vec.push_back(res.toString(scale)); + } + } + + const auto nested_res_col = null_res_col->getNestedColumnPtr(); + size_t res_num = res_vec.size(); + ASSERT_EQ(res_num, null_res_col->size()); + + Field res_field; + for (size_t i = 0; i < res_num; i++) + { + ASSERT_FALSE(null_res_col->isNullAt(i)); + nested_res_col->get(i, res_field); + ASSERT_EQ(res_vec[i], getValue(res_field)); + } + } + } } CATCH From e3c31bd5d5b026de8f814d6307b735671d4b1acb Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Tue, 17 Dec 2024 12:03:01 +0800 Subject: [PATCH 11/32] refine tests --- .../tests/gtest_window_agg.cpp | 226 ++++++++---------- 1 file changed, 98 insertions(+), 128 deletions(-) diff --git a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp index 9022ceb943c..9be33698a42 100644 --- a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp +++ b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp @@ -31,18 +31,38 @@ namespace DB namespace tests { constexpr int scale = 2; -const std::vector input_string_vec{"0", "71.94", "12.34", "-34.26", "80.02", "-84.39", "28.41", "45.32"}; -const std::vector input_string_vec_aux{"0", "7194", "1234", "-3426", "8002", "-8439", "2841", "4532"}; +const std::vector input_string_vec{"0", "71.94", "12.34", "-34.26", "80.02", "-84.39", "28.41", "45.32", "11.11", "-10.32"}; +const std::vector input_string_vec_aux{"0", "7194", "1234", "-3426", "8002", "-8439", "2841", "4532", "1111", "-1032"}; const std::vector input_int_vec{1, -2, 7, 4, 0, -3, -1, 0, 0, 9, 2, 0, -4, 2, 6, -3, 5}; -const std::vector input_decimal_vec{ - Decimal256(std::stoi(input_string_vec_aux[0])), - Decimal256(std::stoi(input_string_vec_aux[1])), - Decimal256(std::stoi(input_string_vec_aux[2])), - Decimal256(std::stoi(input_string_vec_aux[3])), - Decimal256(std::stoi(input_string_vec_aux[4])), - Decimal256(std::stoi(input_string_vec_aux[5])), - Decimal256(std::stoi(input_string_vec_aux[6])), - Decimal256(std::stoi(input_string_vec_aux[7]))}; +const std::vector input_decimal_vec{ + std::stoi(input_string_vec_aux[0]), + std::stoi(input_string_vec_aux[1]), + std::stoi(input_string_vec_aux[2]), + std::stoi(input_string_vec_aux[3]), + std::stoi(input_string_vec_aux[4]), + std::stoi(input_string_vec_aux[5]), + std::stoi(input_string_vec_aux[6]), + std::stoi(input_string_vec_aux[7])}; + +struct SumMocker +{ + inline static void add(Int64 & res, Int64 data) noexcept { res += data; } + inline static void decrease(Int64 & res, Int64 data) noexcept { res -= data; } +}; + +template +struct TestCase +{ + TestCase(DataTypePtr type_, const std::vector & input_vec_, int scale_) : type(type_), input_vec(input_vec_), scale(scale_) {} + + inline void addInMock(Int64 & res, Int64 row_idx) noexcept { mocker.add(res, input_vec[row_idx]); } + inline void decreaseInMock(Int64 & res, Int64 row_idx) noexcept { mocker.decrease(res, input_vec[row_idx]); } + + const DataTypePtr type; + const std::vector input_vec; + int scale; // scale is 0 when test type is int + OpMocker mocker; +}; class ExecutorWindowAgg : public DB::tests::AggregationTest { @@ -58,6 +78,9 @@ class ExecutorWindowAgg : public DB::tests::AggregationTest } protected: + template + void executeWindowAggTest(TestCase & test_case); + static const IColumn * getInputColumn(const IDataType * type) { if (const auto * tmp = dynamic_cast(type); tmp != nullptr) @@ -114,140 +137,87 @@ DataTypePtr ExecutorWindowAgg::type_int = std::make_shared(); DataTypePtr ExecutorWindowAgg::type_decimal128 = std::make_shared(10, scale); DataTypePtr ExecutorWindowAgg::type_decimal256 = std::make_shared(30, scale); -TEST_F(ExecutorWindowAgg, Sum) -try +template +void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) { Arena arena; auto context = TiFlashTestEnv::getContext(); std::deque added_row_idx_queue; - { - // Test for int - added_row_idx_queue.clear(); - auto agg_func - = AggregateFunctionFactory::instance().get(*context, "sum", {ExecutorWindowAgg::type_int}, {}, 0, true); - AlignedBuffer agg_state; - agg_state.reset(agg_func->sizeOfData(), agg_func->alignOfData()); - agg_func->create(agg_state.data()); - - std::vector res_int; - res_int.reserve(10); - - const UInt32 col_int_size = input_int_vec.size(); - - UInt32 reset_num = getResetNum(); - auto res_col = ExecutorWindowAgg::type_int->createColumn(); - auto null_map_col = ColumnUInt8::create(); - auto null_res_col = ColumnNullable::create(std::move(res_col), std::move(null_map_col)); - const IColumn * input_col = getInputColumn(ExecutorWindowAgg::type_int.get()); - - for (UInt32 i = 0; i < reset_num; i++) - { - Int64 res = 0; - agg_func->reset(agg_state.data()); + added_row_idx_queue.clear(); + auto agg_func + = AggregateFunctionFactory::instance().get(*context, "sum", {test_case.type}, {}, 0, true); + AlignedBuffer agg_state; + agg_state.reset(agg_func->sizeOfData(), agg_func->alignOfData()); + agg_func->create(agg_state.data()); - const UInt32 res_num = getResultNum(); - for (UInt32 j = 0; j < res_num; j++) - { - const UInt32 add_num = getAddNum(); - for (UInt32 k = 0; k < add_num; k++) - { - const UInt32 row_idx = getRowIdx(0, col_int_size - 1); - added_row_idx_queue.push_back(row_idx); - agg_func->add(agg_state.data(), &input_col, row_idx, &arena); - res += input_int_vec[row_idx]; - } - - const UInt32 decrease_num = getDecreaseNum(added_row_idx_queue.size()); - for (UInt32 k = 0; k < decrease_num; k++) - { - const UInt32 row_idx = added_row_idx_queue.front(); - added_row_idx_queue.pop_front(); - agg_func->decrease(agg_state.data(), &input_col, row_idx, &arena); - res -= input_int_vec[row_idx]; - } - - agg_func->insertResultInto(agg_state.data(), *null_res_col, &arena); - res_int.push_back(res); - } - } + std::vector res_vec; + res_vec.reserve(10); - const auto nested_res_col = null_res_col->getNestedColumnPtr(); - size_t res_num = res_int.size(); - ASSERT_EQ(res_num, null_res_col->size()); - for (size_t i = 0; i < res_num; i++) - { - ASSERT_FALSE(null_res_col->isNullAt(i)); - ASSERT_EQ(res_int[i], nested_res_col->getInt(i)); - } - } + const UInt32 col_row_num = test_case.input_vec.size(); - { - // Test for decimal - DataTypes test_types{type_decimal128, type_decimal256}; - for (const auto & type : test_types) - { - added_row_idx_queue.clear(); - auto agg_func = AggregateFunctionFactory::instance().get(*context, "sum", {type}, {}, 0, true); - AlignedBuffer agg_state; - agg_state.reset(agg_func->sizeOfData(), agg_func->alignOfData()); - agg_func->create(agg_state.data()); - - std::vector res_vec; - res_vec.reserve(10); + UInt32 reset_num = getResetNum(); + auto res_col = test_case.type->createColumn(); + auto null_map_col = ColumnUInt8::create(); + auto null_res_col = ColumnNullable::create(std::move(res_col), std::move(null_map_col)); + const IColumn * input_col = getInputColumn(test_case.type.get()); - const UInt32 col_decimal_size = input_decimal_vec.size(); - - UInt32 reset_num = getResetNum(); - auto res_col = type->createColumn(); - auto null_map_col = ColumnUInt8::create(); - auto null_res_col = ColumnNullable::create(std::move(res_col), std::move(null_map_col)); - const IColumn * input_col = getInputColumn(type.get()); + for (UInt32 i = 0; i < reset_num; i++) + { + Int64 res = 0; + agg_func->reset(agg_state.data()); - for (UInt32 i = 0; i < reset_num; i++) + const UInt32 res_num = getResultNum(); + for (UInt32 j = 0; j < res_num; j++) + { + const UInt32 add_num = getAddNum(); + for (UInt32 k = 0; k < add_num; k++) { - Decimal256 res(0); - agg_func->reset(agg_state.data()); - - const UInt32 res_num = getResultNum(); - for (UInt32 j = 0; j < res_num; j++) - { - const UInt32 add_num = getAddNum(); - for (UInt32 k = 0; k < add_num; k++) - { - const UInt32 row_idx = getRowIdx(0, col_decimal_size - 1); - added_row_idx_queue.push_back(row_idx); - agg_func->add(agg_state.data(), &input_col, row_idx, &arena); - res += input_decimal_vec[row_idx]; - } - - const UInt32 decrease_num = getDecreaseNum(added_row_idx_queue.size()); - for (UInt32 k = 0; k < decrease_num; k++) - { - const UInt32 row_idx = added_row_idx_queue.front(); - added_row_idx_queue.pop_front(); - agg_func->decrease(agg_state.data(), &input_col, row_idx, &arena); - res -= input_decimal_vec[row_idx]; - } - - agg_func->insertResultInto(agg_state.data(), *null_res_col, &arena); - res_vec.push_back(res.toString(scale)); - } + const UInt32 row_idx = getRowIdx(0, col_row_num - 1); + added_row_idx_queue.push_back(row_idx); + agg_func->add(agg_state.data(), &input_col, row_idx, &arena); + test_case.addInMock(res, row_idx); } - const auto nested_res_col = null_res_col->getNestedColumnPtr(); - size_t res_num = res_vec.size(); - ASSERT_EQ(res_num, null_res_col->size()); - - Field res_field; - for (size_t i = 0; i < res_num; i++) + const UInt32 decrease_num = getDecreaseNum(added_row_idx_queue.size()); + for (UInt32 k = 0; k < decrease_num; k++) { - ASSERT_FALSE(null_res_col->isNullAt(i)); - nested_res_col->get(i, res_field); - ASSERT_EQ(res_vec[i], getValue(res_field)); + const UInt32 row_idx = added_row_idx_queue.front(); + added_row_idx_queue.pop_front(); + agg_func->decrease(agg_state.data(), &input_col, row_idx, &arena); + test_case.decreaseInMock(res, row_idx); } + + agg_func->insertResultInto(agg_state.data(), *null_res_col, &arena); + res_vec.push_back(res); } } + + const auto nested_res_col = null_res_col->getNestedColumnPtr(); + size_t res_num = res_vec.size(); + ASSERT_EQ(res_num, null_res_col->size()); + + Field res_field; + for (size_t i = 0; i < res_num; i++) + { + ASSERT_FALSE(null_res_col->isNullAt(i)); + nested_res_col->get(i, res_field); + + // No matter what type the result is, we always use decimal to convert the result to string so that it's easy to check result + ASSERT_EQ(Decimal256(res_vec[i]).toString(test_case.scale), getValue(res_field)); + } +} + +TEST_F(ExecutorWindowAgg, Sum) +try +{ + TestCase int_case(ExecutorWindowAgg::type_int, input_int_vec, 0); + TestCase decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, scale); + TestCase decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, scale); + + executeWindowAggTest(int_case); + executeWindowAggTest(decimal128_case); + executeWindowAggTest(decimal256_case); } CATCH From e3b9756c75cb63153cd286492c3759e984e4da73 Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Tue, 17 Dec 2024 12:03:24 +0800 Subject: [PATCH 12/32] format --- .../tests/gtest_window_agg.cpp | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp index 9be33698a42..e22604518f6 100644 --- a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp +++ b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp @@ -31,8 +31,10 @@ namespace DB namespace tests { constexpr int scale = 2; -const std::vector input_string_vec{"0", "71.94", "12.34", "-34.26", "80.02", "-84.39", "28.41", "45.32", "11.11", "-10.32"}; -const std::vector input_string_vec_aux{"0", "7194", "1234", "-3426", "8002", "-8439", "2841", "4532", "1111", "-1032"}; +const std::vector + input_string_vec{"0", "71.94", "12.34", "-34.26", "80.02", "-84.39", "28.41", "45.32", "11.11", "-10.32"}; +const std::vector + input_string_vec_aux{"0", "7194", "1234", "-3426", "8002", "-8439", "2841", "4532", "1111", "-1032"}; const std::vector input_int_vec{1, -2, 7, 4, 0, -3, -1, 0, 0, 9, 2, 0, -4, 2, 6, -3, 5}; const std::vector input_decimal_vec{ std::stoi(input_string_vec_aux[0]), @@ -53,7 +55,11 @@ struct SumMocker template struct TestCase { - TestCase(DataTypePtr type_, const std::vector & input_vec_, int scale_) : type(type_), input_vec(input_vec_), scale(scale_) {} + TestCase(DataTypePtr type_, const std::vector & input_vec_, int scale_) + : type(type_) + , input_vec(input_vec_) + , scale(scale_) + {} inline void addInMock(Int64 & res, Int64 row_idx) noexcept { mocker.add(res, input_vec[row_idx]); } inline void decreaseInMock(Int64 & res, Int64 row_idx) noexcept { mocker.decrease(res, input_vec[row_idx]); } @@ -78,7 +84,7 @@ class ExecutorWindowAgg : public DB::tests::AggregationTest } protected: - template + template void executeWindowAggTest(TestCase & test_case); static const IColumn * getInputColumn(const IDataType * type) @@ -137,7 +143,7 @@ DataTypePtr ExecutorWindowAgg::type_int = std::make_shared(); DataTypePtr ExecutorWindowAgg::type_decimal128 = std::make_shared(10, scale); DataTypePtr ExecutorWindowAgg::type_decimal256 = std::make_shared(30, scale); -template +template void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) { Arena arena; @@ -145,8 +151,7 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) std::deque added_row_idx_queue; added_row_idx_queue.clear(); - auto agg_func - = AggregateFunctionFactory::instance().get(*context, "sum", {test_case.type}, {}, 0, true); + auto agg_func = AggregateFunctionFactory::instance().get(*context, "sum", {test_case.type}, {}, 0, true); AlignedBuffer agg_state; agg_state.reset(agg_func->sizeOfData(), agg_func->alignOfData()); agg_func->create(agg_state.data()); @@ -196,7 +201,7 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) const auto nested_res_col = null_res_col->getNestedColumnPtr(); size_t res_num = res_vec.size(); ASSERT_EQ(res_num, null_res_col->size()); - + Field res_field; for (size_t i = 0; i < res_num; i++) { @@ -214,7 +219,7 @@ try TestCase int_case(ExecutorWindowAgg::type_int, input_int_vec, 0); TestCase decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, scale); TestCase decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, scale); - + executeWindowAggTest(int_case); executeWindowAggTest(decimal128_case); executeWindowAggTest(decimal256_case); From a1598579db80afe33e2b5af27e42c1b1278f402c Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Wed, 18 Dec 2024 13:50:03 +0800 Subject: [PATCH 13/32] fix bugs --- .../AggregateFunctionArray.h | 2 + .../AggregateFunctionCount.h | 8 +- .../AggregateFunctionForEach.h | 11 + .../AggregateFunctions/AggregateFunctionIf.h | 2 + .../AggregateFunctionMerge.h | 2 + .../AggregateFunctionMinMaxAny.h | 179 +++++++++----- .../AggregateFunctionNull.h | 10 + .../AggregateFunctionState.h | 2 + .../AggregateFunctions/AggregateFunctionSum.h | 4 +- .../tests/gtest_window_agg.cpp | 218 ++++++++++++++++-- dbms/src/Flash/Coprocessor/DAGUtils.cpp | 2 +- 11 files changed, 355 insertions(+), 85 deletions(-) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionArray.h b/dbms/src/AggregateFunctions/AggregateFunctionArray.h index 9297f7ad543..144bbfdc378 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionArray.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionArray.h @@ -112,6 +112,8 @@ class AggregateFunctionArray final : public IAggregateFunctionHelperreset(place); } + void prepareWindow(AggregateDataPtr __restrict place) const override { nested_func->prepareWindow(place); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { nested_func->merge(place, rhs, arena); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionCount.h b/dbms/src/AggregateFunctions/AggregateFunctionCount.h index b474c7b5b49..4c9449670f9 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionCount.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionCount.h @@ -61,6 +61,8 @@ class AggregateFunctionCount final void reset(AggregateDataPtr __restrict place) const override { data(place).reset(); } + void prepareWindow(AggregateDataPtr __restrict) const override {} + void addBatchSinglePlace( size_t start_offset, size_t batch_size, @@ -189,6 +191,8 @@ class AggregateFunctionCountNotNullUnary final void reset(AggregateDataPtr __restrict place) const override { data(place).reset(); } + void prepareWindow(AggregateDataPtr __restrict) const override {} + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { data(place).count += data(rhs).count; @@ -241,8 +245,6 @@ class AggregateFunctionCountNotNullVariadic final DataTypePtr getReturnType() const override { return std::make_shared(); } - void prepareWindow(AggregateDataPtr __restrict) const override {} - void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override { for (size_t i = 0; i < number_of_arguments; ++i) @@ -263,6 +265,8 @@ class AggregateFunctionCountNotNullVariadic final void reset(AggregateDataPtr __restrict place) const override { data(place).reset(); } + void prepareWindow(AggregateDataPtr __restrict) const override {} + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { data(place).count += data(rhs).count; diff --git a/dbms/src/AggregateFunctions/AggregateFunctionForEach.h b/dbms/src/AggregateFunctions/AggregateFunctionForEach.h index e718077e2ea..342a855db5b 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionForEach.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionForEach.h @@ -212,6 +212,17 @@ class AggregateFunctionForEach final } } + void prepareWindow(AggregateDataPtr __restrict place) const override + { + AggregateFunctionForEachData & state = ensureAggregateData(place, 0, nullptr); + char * nested_state = state.array_of_aggregate_datas; + for (size_t i = 0; i < state.dynamic_array_size; i++) + { + nested_func->prepareWindow(nested_state); + nested_state += nested_size_of_data; + } + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { const AggregateFunctionForEachData & rhs_state = data(rhs); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionIf.h b/dbms/src/AggregateFunctions/AggregateFunctionIf.h index 45bf8092c21..7d3dec9b6c2 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionIf.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionIf.h @@ -85,6 +85,8 @@ class AggregateFunctionIf final : public IAggregateFunctionHelperreset(place); } + void prepareWindow(AggregateDataPtr __restrict place) const override { nested_func->prepareWindow(place); } + void addBatch( size_t start_offset, size_t batch_size, diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMerge.h b/dbms/src/AggregateFunctions/AggregateFunctionMerge.h index 32c87230f01..306ce6005dc 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMerge.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMerge.h @@ -72,6 +72,8 @@ class AggregateFunctionMerge final : public IAggregateFunctionHelperreset(place); } + void prepareWindow(AggregateDataPtr __restrict place) const override { nested_func->prepareWindow(place); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { nested_func->merge(place, rhs, arena); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h index 63c39ac9613..f9e4df297bf 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h @@ -48,11 +48,13 @@ struct SingleValueDataFixed T value; // It's only used in window aggregation - mutable std::unique_ptr> saved_values; + mutable std::deque * saved_values; using ColumnType = std::conditional_t, ColumnDecimal, ColumnVector>; public: + ~SingleValueDataFixed() { delete saved_values; } + bool has() const { return has_value; } void setCollators(const TiDB::TiDBCollators &) {} @@ -79,31 +81,31 @@ struct SingleValueDataFixed readBinary(value, buf); } - void insertMaxResultInto(IColumn & to) { insertMinOrMaxResultInto(to); } + void insertMaxResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } - void insertMinResultInto(IColumn & to) { insertMinOrMaxResultInto(to); } + void insertMinResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } template - void insertMinOrMaxResultInto(IColumn & to) + void insertMinOrMaxResultInto(IColumn & to) const { if (has()) { auto size = saved_values->size(); - value = (*saved_values)[0]; + T tmp = (*saved_values)[0]; for (size_t i = 1; i < size; i++) { if constexpr (is_min) { - if ((*saved_values)[i] < value) - value = (*saved_values)[i]; + if ((*saved_values)[i] < tmp) + tmp = (*saved_values)[i]; } else { - if (value < (*saved_values)[i]) - value = (*saved_values)[i]; + if (tmp < (*saved_values)[i]) + tmp = (*saved_values)[i]; } } - static_cast(to).getData().push_back(value); + static_cast(to).getData().push_back(tmp); } else { @@ -111,7 +113,11 @@ struct SingleValueDataFixed } } - void prepareWindow() { saved_values = std::make_unique>(); } + void prepareWindow() + { + saved_values = new std::deque(); + std::cout << "saved_values: " << saved_values << std::endl; + } void reset() { @@ -131,9 +137,6 @@ struct SingleValueDataFixed { has_value = true; value = static_cast(column).getData()[row_num]; - - if (saved_values) - saved_values->push_back(value); } /// Assuming to.has() @@ -141,8 +144,6 @@ struct SingleValueDataFixed { has_value = true; value = to.value; - if (saved_values) - saved_values->push_back(value); } bool changeFirstTime(const IColumn & column, size_t row_num, Arena * arena) @@ -186,7 +187,11 @@ struct SingleValueDataFixed bool changeIfLess(const IColumn & column, size_t row_num, Arena * arena) { - if (!has() || static_cast(column).getData()[row_num] < value) + auto to_value = static_cast(column).getData()[row_num]; + if (saved_values != nullptr) + saved_values->push_back(to_value); + + if (!has() || to_value < value) { change(column, row_num, arena); return true; @@ -197,6 +202,9 @@ struct SingleValueDataFixed bool changeIfLess(const Self & to, Arena * arena) { + if (saved_values != nullptr) + saved_values->push_back(to.value); + if (to.has() && (!has() || to.value < value)) { change(to, arena); @@ -208,7 +216,11 @@ struct SingleValueDataFixed bool changeIfGreater(const IColumn & column, size_t row_num, Arena * arena) { - if (!has() || static_cast(column).getData()[row_num] > value) + auto to_value = static_cast(column).getData()[row_num]; + if (saved_values != nullptr) + saved_values->push_back(to_value); + + if (!has() || to_value > value) { change(column, row_num, arena); return true; @@ -219,6 +231,9 @@ struct SingleValueDataFixed bool changeIfGreater(const Self & to, Arena * arena) { + if (saved_values != nullptr) + saved_values->push_back(to.value); + if (to.has() && (!has() || to.value > value)) { change(to, arena); @@ -252,7 +267,7 @@ struct SingleValueDataString // TODO use std::string is inefficient // It's only used in window aggregation - mutable std::unique_ptr> saved_values; + mutable std::deque * saved_values{}; bool less(const StringRef & a, const StringRef & b) const { @@ -276,7 +291,7 @@ struct SingleValueDataString } public: - static constexpr Int32 AUTOMATIC_STORAGE_SIZE = 72; + static constexpr Int32 AUTOMATIC_STORAGE_SIZE = 64; static constexpr Int32 MAX_SMALL_STRING_SIZE = AUTOMATIC_STORAGE_SIZE - sizeof(size) - sizeof(capacity) - sizeof(large_data) - sizeof(TiDB::TiDBCollatorPtr) - sizeof(std::unique_ptr>); @@ -284,6 +299,8 @@ struct SingleValueDataString char small_data[MAX_SMALL_STRING_SIZE]{}; /// Including the terminating zero. public: + ~SingleValueDataString() { delete saved_values; } + bool has() const { return size >= 0; } const char * getData() const { return size <= MAX_SMALL_STRING_SIZE ? small_data : large_data; } @@ -353,12 +370,12 @@ struct SingleValueDataString } } - void insertMaxResultInto(IColumn & to) { insertMinOrMaxResultInto(to); } + void insertMaxResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } - void insertMinResultInto(IColumn & to) { insertMinOrMaxResultInto(to); } + void insertMinResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } template - void insertMinOrMaxResultInto(IColumn & to) + void insertMinOrMaxResultInto(IColumn & to) const { if (has()) { @@ -387,7 +404,7 @@ struct SingleValueDataString } } - void prepareWindow() { saved_values = std::make_unique>(); } + void prepareWindow() { saved_values = new std::deque(); } void reset() { @@ -403,6 +420,8 @@ struct SingleValueDataString size = -1; } + void saveValue(StringRef value) { saved_values->push_back(value.toString()); } + /// Assuming to.has() void changeImpl(StringRef value, Arena * arena) { @@ -415,9 +434,6 @@ struct SingleValueDataString if (size > 0) memcpy(small_data, value.data, size); - - if (saved_values) - saved_values->push_back(std::string(small_data, size)); } else { @@ -430,9 +446,6 @@ struct SingleValueDataString size = value_size; memcpy(large_data, value.data, size); - - if (saved_values) - saved_values->push_back(std::string(large_data, size)); } } @@ -484,6 +497,9 @@ struct SingleValueDataString bool changeIfLess(const IColumn & column, size_t row_num, Arena * arena) { + if (saved_values != nullptr) + saveValue(static_cast(column).getDataAtWithTerminatingZero(row_num)); + if (!has() || less(static_cast(column).getDataAtWithTerminatingZero(row_num), getStringRef())) { @@ -496,6 +512,9 @@ struct SingleValueDataString bool changeIfLess(const Self & to, Arena * arena) { + if (saved_values != nullptr) + saveValue(to.getStringRef()); + // todo should check the collator in `to` and `this` if (to.has() && (!has() || less(to.getStringRef(), getStringRef()))) { @@ -508,6 +527,9 @@ struct SingleValueDataString bool changeIfGreater(const IColumn & column, size_t row_num, Arena * arena) { + if (saved_values != nullptr) + saveValue(static_cast(column).getDataAtWithTerminatingZero(row_num)); + if (!has() || greater(static_cast(column).getDataAtWithTerminatingZero(row_num), getStringRef())) { @@ -520,6 +542,9 @@ struct SingleValueDataString bool changeIfGreater(const Self & to, Arena * arena) { + if (saved_values != nullptr) + saveValue(to.getStringRef()); + if (to.has() && (!has() || greater(to.getStringRef(), getStringRef()))) { change(to, arena); @@ -552,9 +577,11 @@ struct SingleValueDataGeneric Field value; // It's only used in window aggregation - std::unique_ptr> saved_values; + mutable std::deque * saved_values; public: + ~SingleValueDataGeneric() { delete saved_values; } + bool has() const { return !value.isNull(); } void setCollators(const TiDB::TiDBCollators &) {} @@ -587,31 +614,31 @@ struct SingleValueDataGeneric data_type.deserializeBinary(value, buf); } - void insertMaxResultInto(IColumn & to) { insertMinOrMaxResultInto(to); } + void insertMaxResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } - void insertMinResultInto(IColumn & to) { insertMinOrMaxResultInto(to); } + void insertMinResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } template - void insertMinOrMaxResultInto(IColumn & to) + void insertMinOrMaxResultInto(IColumn & to) const { if (has()) { auto size = saved_values->size(); - value = (*saved_values)[0]; + Field tmp = (*saved_values)[0]; for (size_t i = 1; i < size; i++) { if constexpr (is_min) { - if ((*saved_values)[i] < value) - value = (*saved_values)[i]; + if ((*saved_values)[i] < tmp) + tmp = (*saved_values)[i]; } else { - if (value < (*saved_values)[i]) - value = (*saved_values)[i]; + if (tmp < (*saved_values)[i]) + tmp = (*saved_values)[i]; } } - to.insert(value); + to.insert(tmp); } else { @@ -619,7 +646,7 @@ struct SingleValueDataGeneric } } - void prepareWindow() { saved_values = std::make_unique>(); } + void prepareWindow() { saved_values = new std::deque(); } void reset() { @@ -635,19 +662,9 @@ struct SingleValueDataGeneric value = Field(); } - void change(const IColumn & column, size_t row_num, Arena *) - { - column.get(row_num, value); - if (saved_values) - saved_values->push_back(value); - } + void change(const IColumn & column, size_t row_num, Arena *) { column.get(row_num, value); } - void change(const Self & to, Arena *) - { - value = to.value; - if (saved_values) - saved_values->push_back(value); - } + void change(const Self & to, Arena *) { value = to.value; } bool changeFirstTime(const IColumn & column, size_t row_num, Arena * arena) { @@ -693,12 +710,19 @@ struct SingleValueDataGeneric if (!has()) { change(column, row_num, arena); + + if (saved_values != nullptr) + saved_values->push_back(value); return true; } else { Field new_value; column.get(row_num, new_value); + + if (saved_values != nullptr) + saved_values->push_back(new_value); + if (new_value < value) { value = new_value; @@ -711,6 +735,9 @@ struct SingleValueDataGeneric bool changeIfLess(const Self & to, Arena * arena) { + if (saved_values != nullptr) + saved_values->push_back(to.value); + if (to.has() && (!has() || to.value < value)) { change(to, arena); @@ -725,12 +752,19 @@ struct SingleValueDataGeneric if (!has()) { change(column, row_num, arena); + + if (saved_values != nullptr) + saved_values->push_back(value); return true; } else { Field new_value; column.get(row_num, new_value); + + if (saved_values != nullptr) + saved_values->push_back(new_value); + if (new_value > value) { value = new_value; @@ -743,6 +777,9 @@ struct SingleValueDataGeneric bool changeIfGreater(const Self & to, Arena * arena) { + if (saved_values != nullptr) + saved_values->push_back(to.value); + if (to.has() && (!has() || to.value > value)) { change(to, arena); @@ -774,7 +811,23 @@ struct AggregateFunctionMinData : Data } bool changeIfBetter(const Self & to, Arena * arena) { return this->changeIfLess(to, arena); } + void prepareWindow() + { + is_in_window = true; + Data::prepareWindow(); + } + + void insertResultInto(IColumn & to) const + { + if (is_in_window) + Data::insertMinResultInto(to); + else + Data::insertResultInto(to); + } + static const char * name() { return "min"; } + + bool is_in_window = false; }; template @@ -782,6 +835,20 @@ struct AggregateFunctionMaxData : Data { using Self = AggregateFunctionMaxData; + void prepareWindow() + { + is_in_window = true; + Data::prepareWindow(); + } + + void insertResultInto(IColumn & to) const + { + if (is_in_window) + Data::insertMaxResultInto(to); + else + Data::insertResultInto(to); + } + bool changeIfBetter(const IColumn & column, size_t row_num, Arena * arena) { return this->changeIfGreater(column, row_num, arena); @@ -789,6 +856,8 @@ struct AggregateFunctionMaxData : Data bool changeIfBetter(const Self & to, Arena * arena) { return this->changeIfGreater(to, arena); } static const char * name() { return "max"; } + + bool is_in_window = false; }; template @@ -931,8 +1000,6 @@ class AggregateFunctionsSingleValue final this->data(place).changeIfBetter(*columns[0], row_num, arena); } - void prepareWindow(AggregateDataPtr __restrict place) const override { this->data(place).prepareWindow(); } - void decrease(AggregateDataPtr __restrict place, const IColumn **, size_t, Arena *) const override { this->data(place).decrease(); @@ -940,6 +1007,8 @@ class AggregateFunctionsSingleValue final void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } + void prepareWindow(AggregateDataPtr __restrict place) const override { this->data(place).prepareWindow(); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { this->data(place).changeIfBetter(this->data(rhs), arena); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionNull.h b/dbms/src/AggregateFunctions/AggregateFunctionNull.h index 1beadc9634d..3eb1e73c3b5 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionNull.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionNull.h @@ -438,6 +438,11 @@ class AggregateFunctionNullUnary final this->nested_function->reset(this->nestedPlace(place)); } + void prepareWindow(AggregateDataPtr __restrict place) const override + { + this->nested_function->prepareWindow(this->nestedPlace(place)); + } + void addBatchSinglePlace( // NOLINT(google-default-arguments) size_t start_offset, size_t batch_size, @@ -554,6 +559,11 @@ class AggregateFunctionNullVariadic final this->nested_function->reset(this->nestedPlace(place)); } + void prepareWindow(AggregateDataPtr __restrict place) const override + { + this->nested_function->prepareWindow(this->nestedPlace(place)); + } + bool allocatesMemoryInArena() const override { return this->nested_function->allocatesMemoryInArena(); } private: diff --git a/dbms/src/AggregateFunctions/AggregateFunctionState.h b/dbms/src/AggregateFunctions/AggregateFunctionState.h index c535ac39fed..2a463038fc8 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionState.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionState.h @@ -68,6 +68,8 @@ class AggregateFunctionState final : public IAggregateFunctionHelperreset(place); } + void prepareWindow(AggregateDataPtr __restrict place) const override { nested_func->prepareWindow(place); } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { nested_func->merge(place, rhs, arena); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSum.h b/dbms/src/AggregateFunctions/AggregateFunctionSum.h index 4025153008e..f768394ce3e 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSum.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSum.h @@ -359,8 +359,6 @@ class AggregateFunctionSum final new (place) Data; } - void prepareWindow(AggregateDataPtr __restrict) const override {} - void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override { const auto & column = assert_cast(*columns[0]); @@ -375,6 +373,8 @@ class AggregateFunctionSum final void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } + void prepareWindow(AggregateDataPtr __restrict) const override {} + /// Vectorized version when there is no GROUP BY keys. void addBatchSinglePlace( size_t start_offset, diff --git a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp index e22604518f6..4216c512c38 100644 --- a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp +++ b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp @@ -26,6 +26,8 @@ #include #include +#include + namespace DB { namespace tests @@ -50,22 +52,115 @@ struct SumMocker { inline static void add(Int64 & res, Int64 data) noexcept { res += data; } inline static void decrease(Int64 & res, Int64 data) noexcept { res -= data; } + inline static void reset() noexcept {} +}; + +struct CountMocker +{ + inline static void add(Int64 & res, Int64) noexcept { res++; } + inline static void decrease(Int64 & res, Int64) noexcept { res--; } + inline static void reset() noexcept {} +}; + +class AvgMocker +{ +public: + AvgMocker() + : sum(0) + , count(0) + {} + + inline void add(Int64 & res, Int64 data) noexcept + { + sum += data; + count++; + avgImpl(res); + } + + inline void decrease(Int64 & res, Int64 data) noexcept + { + sum -= data; + count--; + avgImpl(res); + } + + inline void reset() noexcept + { + sum = 0; + count = 0; + } + +private: + inline void avgImpl(Int64 & res) const noexcept { res = sum / count; } + + Int64 sum; + Int64 count; +}; + +template +class MinOrMaxMocker +{ +public: + inline void add(Int64 & res, Int64 data) noexcept + { + cmpAndChange(res, data); + saved_values.push_back(data); + } + + inline void decrease(Int64 & res, Int64) noexcept + { + saved_values.pop_front(); + res = is_max ? std::numeric_limits::min() : std::numeric_limits::max(); + + // Inefficient, but it's ok in the ut + for (auto value : saved_values) + cmpAndChange(res, value); + } + + inline void reset() noexcept { saved_values.clear(); } + +private: + static void inline cmpAndChange(Int64 & res, Int64 value) noexcept + { + if constexpr (is_max) + { + if (value > res) + res = value; + } + else + { + if (value < res) + res = value; + } + } + + std::deque saved_values; }; template struct TestCase { - TestCase(DataTypePtr type_, const std::vector & input_vec_, int scale_) + TestCase( + DataTypePtr type_, + const std::vector & input_vec_, + const String & agg_name_, + Int64 init_res_, + int scale_) : type(type_) , input_vec(input_vec_) + , agg_name(agg_name_) + , init_res(init_res_) , scale(scale_) {} inline void addInMock(Int64 & res, Int64 row_idx) noexcept { mocker.add(res, input_vec[row_idx]); } inline void decreaseInMock(Int64 & res, Int64 row_idx) noexcept { mocker.decrease(res, input_vec[row_idx]); } + inline void reset() noexcept { mocker.reset(); } const DataTypePtr type; const std::vector input_vec; + const String agg_name; + Int64 init_res; int scale; // scale is 0 when test type is int OpMocker mocker; }; @@ -105,6 +200,8 @@ class ExecutorWindowAgg : public DB::tests::AggregationTest { case Field::Types::Which::Int64: return std::to_string(field.template get()); + case Field::Types::Which::UInt64: + return std::to_string(field.template get()); case Field::Types::Which::Decimal128: return field.template get>().toString(); case Field::Types::Which::Decimal256: @@ -151,10 +248,13 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) std::deque added_row_idx_queue; added_row_idx_queue.clear(); - auto agg_func = AggregateFunctionFactory::instance().get(*context, "sum", {test_case.type}, {}, 0, true); + auto agg_func + = AggregateFunctionFactory::instance().get(*context, test_case.agg_name, {test_case.type}, {}, 0, true); + auto return_type = agg_func->getReturnType(); AlignedBuffer agg_state; agg_state.reset(agg_func->sizeOfData(), agg_func->alignOfData()); agg_func->create(agg_state.data()); + agg_func->prepareWindow(agg_state.data()); std::vector res_vec; res_vec.reserve(10); @@ -162,19 +262,23 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) const UInt32 col_row_num = test_case.input_vec.size(); UInt32 reset_num = getResetNum(); - auto res_col = test_case.type->createColumn(); - auto null_map_col = ColumnUInt8::create(); - auto null_res_col = ColumnNullable::create(std::move(res_col), std::move(null_map_col)); + auto res_col = return_type->createColumn(); const IColumn * input_col = getInputColumn(test_case.type.get()); + // Start test for (UInt32 i = 0; i < reset_num; i++) { - Int64 res = 0; + Int64 res = test_case.init_res; + test_case.reset(); + std::cout << "----------- reset" << std::endl; agg_func->reset(agg_state.data()); + added_row_idx_queue.clear(); + // Generate a result const UInt32 res_num = getResultNum(); for (UInt32 j = 0; j < res_num; j++) { + // Start to add const UInt32 add_num = getAddNum(); for (UInt32 k = 0; k < add_num; k++) { @@ -182,32 +286,38 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) added_row_idx_queue.push_back(row_idx); agg_func->add(agg_state.data(), &input_col, row_idx, &arena); test_case.addInMock(res, row_idx); + std::cout << "add: " << test_case.input_vec[row_idx] << " res: " << res << std::endl; } - const UInt32 decrease_num = getDecreaseNum(added_row_idx_queue.size()); - for (UInt32 k = 0; k < decrease_num; k++) + if likely (!added_row_idx_queue.empty()) { - const UInt32 row_idx = added_row_idx_queue.front(); - added_row_idx_queue.pop_front(); - agg_func->decrease(agg_state.data(), &input_col, row_idx, &arena); - test_case.decreaseInMock(res, row_idx); + // Start to decrease + const UInt32 decrease_num = getDecreaseNum(added_row_idx_queue.size() - 1); + for (UInt32 k = 0; k < decrease_num; k++) + { + const UInt32 row_idx = added_row_idx_queue.front(); + added_row_idx_queue.pop_front(); + agg_func->decrease(agg_state.data(), &input_col, row_idx, &arena); + test_case.decreaseInMock(res, row_idx); + std::cout << "decrease: " << test_case.input_vec[row_idx] << " res: " << res << std::endl; + } } - agg_func->insertResultInto(agg_state.data(), *null_res_col, &arena); + std::cout << "insert, res: " << res << std::endl; + agg_func->insertResultInto(agg_state.data(), *res_col, &arena); res_vec.push_back(res); } } - const auto nested_res_col = null_res_col->getNestedColumnPtr(); size_t res_num = res_vec.size(); - ASSERT_EQ(res_num, null_res_col->size()); + ASSERT_EQ(res_num, res_col->size()); Field res_field; for (size_t i = 0; i < res_num; i++) { - ASSERT_FALSE(null_res_col->isNullAt(i)); - nested_res_col->get(i, res_field); - + res_col->get(i, res_field); + ASSERT_FALSE(res_field.isNull()); + std::cout << "i: " << i << std::endl; // No matter what type the result is, we always use decimal to convert the result to string so that it's easy to check result ASSERT_EQ(Decimal256(res_vec[i]).toString(test_case.scale), getValue(res_field)); } @@ -216,9 +326,9 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) TEST_F(ExecutorWindowAgg, Sum) try { - TestCase int_case(ExecutorWindowAgg::type_int, input_int_vec, 0); - TestCase decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, scale); - TestCase decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, scale); + TestCase int_case(ExecutorWindowAgg::type_int, input_int_vec, "sum", 0, 0); + TestCase decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, "sum", 0, scale); + TestCase decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, "sum", 0, scale); executeWindowAggTest(int_case); executeWindowAggTest(decimal128_case); @@ -226,24 +336,82 @@ try } CATCH +// TODO add count distinct TEST_F(ExecutorWindowAgg, Count) try -{} +{ + TestCase int_case(ExecutorWindowAgg::type_int, input_int_vec, "count", 0, 0); + TestCase decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, "count", 0, 0); + TestCase decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, "count", 0, 0); + + executeWindowAggTest(int_case); + executeWindowAggTest(decimal128_case); + executeWindowAggTest(decimal256_case); +} CATCH TEST_F(ExecutorWindowAgg, Avg) try -{} +{ + // TestCase int_case(ExecutorWindowAgg::type_int, input_int_vec, "avg", 0, 0); + // TestCase decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, "avg", 0, 0); + // TestCase decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, "avg", 0, 0); + + // executeWindowAggTest(int_case); + // executeWindowAggTest(decimal128_case); + // executeWindowAggTest(decimal256_case); +} CATCH +// TODO use unique_ptr in data TEST_F(ExecutorWindowAgg, Min) try -{} +{ + // TODO add string type etc... in AggregateFunctionMinMaxAny.h + TestCase> + int_case(ExecutorWindowAgg::type_int, input_int_vec, "min", std::numeric_limits::max(), 0); + TestCase> decimal128_case( + ExecutorWindowAgg::type_decimal128, + input_decimal_vec, + "min", + std::numeric_limits::max(), + scale); + TestCase> decimal256_case( + ExecutorWindowAgg::type_decimal256, + input_decimal_vec, + "min", + std::numeric_limits::max(), + scale); + + executeWindowAggTest(int_case); + executeWindowAggTest(decimal128_case); + executeWindowAggTest(decimal256_case); +} CATCH TEST_F(ExecutorWindowAgg, Max) try -{} +{ + // TODO add string type etc... in AggregateFunctionMinMaxAny.h + TestCase> + int_case(ExecutorWindowAgg::type_int, input_int_vec, "max", std::numeric_limits::min(), 0); + TestCase> decimal128_case( + ExecutorWindowAgg::type_decimal128, + input_decimal_vec, + "max", + std::numeric_limits::min(), + scale); + TestCase> decimal256_case( + ExecutorWindowAgg::type_decimal256, + input_decimal_vec, + "max", + std::numeric_limits::min(), + scale); + + executeWindowAggTest(int_case); + executeWindowAggTest(decimal128_case); + executeWindowAggTest(decimal256_case); +} CATCH } // namespace tests diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.cpp b/dbms/src/Flash/Coprocessor/DAGUtils.cpp index 1a8ac4e0167..5cea8f77c78 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.cpp +++ b/dbms/src/Flash/Coprocessor/DAGUtils.cpp @@ -57,7 +57,7 @@ const std::unordered_map agg_func_map({ {tipb::ExprType::First, "first_row"}, {tipb::ExprType::ApproxCountDistinct, uniq_raw_res_name}, {tipb::ExprType::GroupConcat, "groupArray"}, - //{tipb::ExprType::Avg, ""}, + {tipb::ExprType::Avg, "avg"}, //{tipb::ExprType::Agg_BitAnd, ""}, //{tipb::ExprType::Agg_BitOr, ""}, //{tipb::ExprType::Agg_BitXor, ""}, From 0d4401dc70e577e3e0c180dcd34a203476a65d8b Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Wed, 18 Dec 2024 16:01:46 +0800 Subject: [PATCH 14/32] refine test --- .../tests/gtest_window_agg.cpp | 122 ++++++++++++------ 1 file changed, 80 insertions(+), 42 deletions(-) diff --git a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp index 4216c512c38..2abe4d8644a 100644 --- a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp +++ b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp @@ -48,40 +48,73 @@ const std::vector input_decimal_vec{ std::stoi(input_string_vec_aux[6]), std::stoi(input_string_vec_aux[7])}; -struct SumMocker +class MockerBase { - inline static void add(Int64 & res, Int64 data) noexcept { res += data; } - inline static void decrease(Int64 & res, Int64 data) noexcept { res -= data; } - inline static void reset() noexcept {} +public: + explicit MockerBase(Int32 scale_) + : scale(scale_) + {} + + inline const std::vector & getResults() noexcept { return results; } + +protected: + inline String convertResIntToString(Int64 res) const noexcept { return Decimal256(res).toString(scale); } + + std::vector results; + Int32 scale; // scale is 0 when test type is int }; -struct CountMocker +class SumMocker : public MockerBase { - inline static void add(Int64 & res, Int64) noexcept { res++; } - inline static void decrease(Int64 & res, Int64) noexcept { res--; } - inline static void reset() noexcept {} +public: + explicit SumMocker(Int32 scale) + : MockerBase(scale) + {} + + inline void add(Int64 data) noexcept { res += data; } + inline void decrease(Int64 data) noexcept { res -= data; } + inline void reset() noexcept { res = 0; } + inline void saveResult() noexcept { results.push_back(convertResIntToString(res)); } + +private: + Int64 res{}; }; -class AvgMocker +class CountMocker : public MockerBase { public: - AvgMocker() - : sum(0) + explicit CountMocker(Int32 scale) + : MockerBase(scale) + {} + + inline void add(Int64) noexcept { res++; } + inline void decrease(Int64) noexcept { res--; } + inline void reset() noexcept { res = 0; } + inline void saveResult() noexcept { results.push_back(convertResIntToString(res)); } + +private: + Int64 res{}; +}; + +class AvgMocker : public MockerBase +{ +public: + explicit AvgMocker(Int32 scale) + : MockerBase(scale) + , sum(0) , count(0) {} - inline void add(Int64 & res, Int64 data) noexcept + inline void add(Int64 data) noexcept { sum += data; count++; - avgImpl(res); } - inline void decrease(Int64 & res, Int64 data) noexcept + inline void decrease(Int64 data) noexcept { sum -= data; count--; - avgImpl(res); } inline void reset() noexcept @@ -90,34 +123,42 @@ class AvgMocker count = 0; } + void saveResult() noexcept + { + WriteBufferFromOwnString wb; + writeFloatText(avgImpl(), wb); + results.push_back(wb.str()); + } + private: - inline void avgImpl(Int64 & res) const noexcept { res = sum / count; } + inline Float64 avgImpl() const noexcept { return static_cast(sum) / static_cast(count); } Int64 sum; Int64 count; }; template -class MinOrMaxMocker +class MinOrMaxMocker : public MockerBase { public: - inline void add(Int64 & res, Int64 data) noexcept - { - cmpAndChange(res, data); - saved_values.push_back(data); - } + explicit MinOrMaxMocker(Int32 scale) + : MockerBase(scale) + {} + + inline void add(Int64 data) noexcept { saved_values.push_back(data); } + inline void decrease(Int64) noexcept { saved_values.pop_front(); } + inline void reset() noexcept { saved_values.clear(); } - inline void decrease(Int64 & res, Int64) noexcept + inline void saveResult() noexcept { - saved_values.pop_front(); - res = is_max ? std::numeric_limits::min() : std::numeric_limits::max(); + Int64 res = is_max ? std::numeric_limits::min() : std::numeric_limits::max(); // Inefficient, but it's ok in the ut for (auto value : saved_values) cmpAndChange(res, value); - } - inline void reset() noexcept { saved_values.clear(); } + results.push_back(convertResIntToString(res)); + } private: static void inline cmpAndChange(Int64 & res, Int64 value) noexcept @@ -150,18 +191,19 @@ struct TestCase , input_vec(input_vec_) , agg_name(agg_name_) , init_res(init_res_) - , scale(scale_) + , mocker(scale_) {} - inline void addInMock(Int64 & res, Int64 row_idx) noexcept { mocker.add(res, input_vec[row_idx]); } - inline void decreaseInMock(Int64 & res, Int64 row_idx) noexcept { mocker.decrease(res, input_vec[row_idx]); } + inline void addInMock(Int64 row_idx) noexcept { mocker.add(input_vec[row_idx]); } + inline void decreaseInMock(Int64 row_idx) noexcept { mocker.decrease(input_vec[row_idx]); } inline void reset() noexcept { mocker.reset(); } + inline void saveResult() noexcept { mocker.saveResult(); } + inline const std::vector & getResults() noexcept { return mocker.getResults(); } const DataTypePtr type; const std::vector input_vec; const String agg_name; Int64 init_res; - int scale; // scale is 0 when test type is int OpMocker mocker; }; @@ -202,6 +244,8 @@ class ExecutorWindowAgg : public DB::tests::AggregationTest return std::to_string(field.template get()); case Field::Types::Which::UInt64: return std::to_string(field.template get()); + case Field::Types::Which::Float64: + return std::to_string(field.template get()); case Field::Types::Which::Decimal128: return field.template get>().toString(); case Field::Types::Which::Decimal256: @@ -256,9 +300,6 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) agg_func->create(agg_state.data()); agg_func->prepareWindow(agg_state.data()); - std::vector res_vec; - res_vec.reserve(10); - const UInt32 col_row_num = test_case.input_vec.size(); UInt32 reset_num = getResetNum(); @@ -268,7 +309,6 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) // Start test for (UInt32 i = 0; i < reset_num; i++) { - Int64 res = test_case.init_res; test_case.reset(); std::cout << "----------- reset" << std::endl; agg_func->reset(agg_state.data()); @@ -285,8 +325,7 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) const UInt32 row_idx = getRowIdx(0, col_row_num - 1); added_row_idx_queue.push_back(row_idx); agg_func->add(agg_state.data(), &input_col, row_idx, &arena); - test_case.addInMock(res, row_idx); - std::cout << "add: " << test_case.input_vec[row_idx] << " res: " << res << std::endl; + test_case.addInMock(row_idx); } if likely (!added_row_idx_queue.empty()) @@ -298,17 +337,16 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) const UInt32 row_idx = added_row_idx_queue.front(); added_row_idx_queue.pop_front(); agg_func->decrease(agg_state.data(), &input_col, row_idx, &arena); - test_case.decreaseInMock(res, row_idx); - std::cout << "decrease: " << test_case.input_vec[row_idx] << " res: " << res << std::endl; + test_case.decreaseInMock(row_idx); } } - std::cout << "insert, res: " << res << std::endl; agg_func->insertResultInto(agg_state.data(), *res_col, &arena); - res_vec.push_back(res); + test_case.saveResult(); } } + const std::vector res_vec = test_case.getResults(); size_t res_num = res_vec.size(); ASSERT_EQ(res_num, res_col->size()); @@ -319,7 +357,7 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) ASSERT_FALSE(res_field.isNull()); std::cout << "i: " << i << std::endl; // No matter what type the result is, we always use decimal to convert the result to string so that it's easy to check result - ASSERT_EQ(Decimal256(res_vec[i]).toString(test_case.scale), getValue(res_field)); + ASSERT_EQ(res_vec[i], getValue(res_field)); } } From 9ff2bd3a78e7945de5957aa72ee6d09ee0fa210d Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Wed, 18 Dec 2024 17:45:28 +0800 Subject: [PATCH 15/32] fix tests --- .../tests/gtest_window_agg.cpp | 97 ++++++++++--------- 1 file changed, 51 insertions(+), 46 deletions(-) diff --git a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp index 2abe4d8644a..e2172da4cbe 100644 --- a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp +++ b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp @@ -48,6 +48,28 @@ const std::vector input_decimal_vec{ std::stoi(input_string_vec_aux[6]), std::stoi(input_string_vec_aux[7])}; +String eliminateTailing(String str) +{ + int point_idx = -1; + size_t size = str.size(); + for (size_t i = 0; i < size; i++) + { + if (str[i] == '.') + { + point_idx = i; + break; + } + } + + // Can't find point + if (point_idx == -1) + return str; + + if (static_cast(size) > point_idx + 3) + return String(str.c_str(), point_idx + 3); + return str; +} + class MockerBase { public: @@ -125,13 +147,15 @@ class AvgMocker : public MockerBase void saveResult() noexcept { - WriteBufferFromOwnString wb; - writeFloatText(avgImpl(), wb); - results.push_back(wb.str()); + if (scale == 0) + results.push_back(std::to_string(avgIntImpl())); + else + results.push_back(convertResIntToString(avgDecimalImpl())); } private: - inline Float64 avgImpl() const noexcept { return static_cast(sum) / static_cast(count); } + inline Float64 avgIntImpl() const noexcept { return static_cast(sum) / static_cast(count); } + inline Int64 avgDecimalImpl() const noexcept { return sum / count; } Int64 sum; Int64 count; @@ -181,16 +205,10 @@ class MinOrMaxMocker : public MockerBase template struct TestCase { - TestCase( - DataTypePtr type_, - const std::vector & input_vec_, - const String & agg_name_, - Int64 init_res_, - int scale_) + TestCase(DataTypePtr type_, const std::vector & input_vec_, const String & agg_name_, int scale_) : type(type_) , input_vec(input_vec_) , agg_name(agg_name_) - , init_res(init_res_) , mocker(scale_) {} @@ -203,7 +221,6 @@ struct TestCase const DataTypePtr type; const std::vector input_vec; const String agg_name; - Int64 init_res; OpMocker mocker; }; @@ -246,10 +263,12 @@ class ExecutorWindowAgg : public DB::tests::AggregationTest return std::to_string(field.template get()); case Field::Types::Which::Float64: return std::to_string(field.template get()); + case Field::Types::Which::Decimal64: + return eliminateTailing(field.template get>().toString()); case Field::Types::Which::Decimal128: - return field.template get>().toString(); + return eliminateTailing(field.template get>().toString()); case Field::Types::Which::Decimal256: - return field.template get>().toString(); + return eliminateTailing(field.template get>().toString()); default: throw Exception("Invalid data type"); } @@ -310,7 +329,6 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) for (UInt32 i = 0; i < reset_num; i++) { test_case.reset(); - std::cout << "----------- reset" << std::endl; agg_func->reset(agg_state.data()); added_row_idx_queue.clear(); @@ -355,7 +373,7 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) { res_col->get(i, res_field); ASSERT_FALSE(res_field.isNull()); - std::cout << "i: " << i << std::endl; + // No matter what type the result is, we always use decimal to convert the result to string so that it's easy to check result ASSERT_EQ(res_vec[i], getValue(res_field)); } @@ -364,9 +382,9 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) TEST_F(ExecutorWindowAgg, Sum) try { - TestCase int_case(ExecutorWindowAgg::type_int, input_int_vec, "sum", 0, 0); - TestCase decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, "sum", 0, scale); - TestCase decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, "sum", 0, scale); + TestCase int_case(ExecutorWindowAgg::type_int, input_int_vec, "sum", 0); + TestCase decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, "sum", scale); + TestCase decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, "sum", scale); executeWindowAggTest(int_case); executeWindowAggTest(decimal128_case); @@ -378,9 +396,9 @@ CATCH TEST_F(ExecutorWindowAgg, Count) try { - TestCase int_case(ExecutorWindowAgg::type_int, input_int_vec, "count", 0, 0); - TestCase decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, "count", 0, 0); - TestCase decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, "count", 0, 0); + TestCase int_case(ExecutorWindowAgg::type_int, input_int_vec, "count", 0); + TestCase decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, "count", 0); + TestCase decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, "count", 0); executeWindowAggTest(int_case); executeWindowAggTest(decimal128_case); @@ -391,34 +409,32 @@ CATCH TEST_F(ExecutorWindowAgg, Avg) try { - // TestCase int_case(ExecutorWindowAgg::type_int, input_int_vec, "avg", 0, 0); - // TestCase decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, "avg", 0, 0); - // TestCase decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, "avg", 0, 0); + TestCase int_case(ExecutorWindowAgg::type_int, input_int_vec, "avg", 0); + TestCase decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, "avg", scale); + TestCase decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, "avg", scale); - // executeWindowAggTest(int_case); - // executeWindowAggTest(decimal128_case); - // executeWindowAggTest(decimal256_case); + executeWindowAggTest(int_case); + executeWindowAggTest(decimal128_case); + executeWindowAggTest(decimal256_case); } CATCH // TODO use unique_ptr in data +// TODO ensure that if data will be called destructor TEST_F(ExecutorWindowAgg, Min) try { // TODO add string type etc... in AggregateFunctionMinMaxAny.h - TestCase> - int_case(ExecutorWindowAgg::type_int, input_int_vec, "min", std::numeric_limits::max(), 0); + TestCase> int_case(ExecutorWindowAgg::type_int, input_int_vec, "min", 0); TestCase> decimal128_case( ExecutorWindowAgg::type_decimal128, input_decimal_vec, "min", - std::numeric_limits::max(), scale); TestCase> decimal256_case( ExecutorWindowAgg::type_decimal256, input_decimal_vec, "min", - std::numeric_limits::max(), scale); executeWindowAggTest(int_case); @@ -431,20 +447,9 @@ TEST_F(ExecutorWindowAgg, Max) try { // TODO add string type etc... in AggregateFunctionMinMaxAny.h - TestCase> - int_case(ExecutorWindowAgg::type_int, input_int_vec, "max", std::numeric_limits::min(), 0); - TestCase> decimal128_case( - ExecutorWindowAgg::type_decimal128, - input_decimal_vec, - "max", - std::numeric_limits::min(), - scale); - TestCase> decimal256_case( - ExecutorWindowAgg::type_decimal256, - input_decimal_vec, - "max", - std::numeric_limits::min(), - scale); + TestCase> int_case(ExecutorWindowAgg::type_int, input_int_vec, "max", 0); + TestCase> decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, "max", scale); + TestCase> decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, "max", scale); executeWindowAggTest(int_case); executeWindowAggTest(decimal128_case); From a9aa87995dbd3689776caae6fdbd99a141fa858e Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Thu, 19 Dec 2024 16:05:29 +0800 Subject: [PATCH 16/32] add test for string type --- .../AggregateFunctionMinMaxAny.h | 8 +- .../tests/gtest_window_agg.cpp | 233 +++++++++++++----- 2 files changed, 168 insertions(+), 73 deletions(-) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h index f9e4df297bf..c7bd4b0ddb8 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h @@ -113,11 +113,7 @@ struct SingleValueDataFixed } } - void prepareWindow() - { - saved_values = new std::deque(); - std::cout << "saved_values: " << saved_values << std::endl; - } + void prepareWindow() { saved_values = new std::deque(); } void reset() { @@ -396,7 +392,7 @@ struct SingleValueDataString } } - static_cast(to).insertDataWithTerminatingZero(getData(), size); + static_cast(to).insertDataWithTerminatingZero(value.data, value.size); } else { diff --git a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp index e2172da4cbe..911549d6e9a 100644 --- a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp +++ b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp @@ -26,27 +26,39 @@ #include #include -#include +#include "DataTypes/DataTypeString.h" namespace DB { namespace tests { -constexpr int scale = 2; -const std::vector - input_string_vec{"0", "71.94", "12.34", "-34.26", "80.02", "-84.39", "28.41", "45.32", "11.11", "-10.32"}; -const std::vector - input_string_vec_aux{"0", "7194", "1234", "-3426", "8002", "-8439", "2841", "4532", "1111", "-1032"}; -const std::vector input_int_vec{1, -2, 7, 4, 0, -3, -1, 0, 0, 9, 2, 0, -4, 2, 6, -3, 5}; -const std::vector input_decimal_vec{ - std::stoi(input_string_vec_aux[0]), - std::stoi(input_string_vec_aux[1]), - std::stoi(input_string_vec_aux[2]), - std::stoi(input_string_vec_aux[3]), - std::stoi(input_string_vec_aux[4]), - std::stoi(input_string_vec_aux[5]), - std::stoi(input_string_vec_aux[6]), - std::stoi(input_string_vec_aux[7])}; +const String CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*()_+"; +const UInt32 CHARACTERS_LEN = CHARACTERS.size(); +constexpr int SCALE = 2; +std::vector input_decimal_in_string_vec{ + "0", + "71.94", + "12.34", + "-34.26", + "80.02", + "-84.39", + "28.41", + "45.32", + "11.11", + "-10.32"}; +std::vector + input_decimal_in_string_vec_aux{"0", "7194", "1234", "-3426", "8002", "-8439", "2841", "4532", "1111", "-1032"}; +std::vector input_int_vec{1, -2, 7, 4, 0, -3, -1, 0, 0, 9, 2, 0, -4, 2, 6, -3, 5}; +std::vector input_decimal_vec{ + std::stoi(input_decimal_in_string_vec_aux[0]), + std::stoi(input_decimal_in_string_vec_aux[1]), + std::stoi(input_decimal_in_string_vec_aux[2]), + std::stoi(input_decimal_in_string_vec_aux[3]), + std::stoi(input_decimal_in_string_vec_aux[4]), + std::stoi(input_decimal_in_string_vec_aux[5]), + std::stoi(input_decimal_in_string_vec_aux[6]), + std::stoi(input_decimal_in_string_vec_aux[7])}; +std::vector input_string_vec; String eliminateTailing(String str) { @@ -78,6 +90,8 @@ class MockerBase {} inline const std::vector & getResults() noexcept { return results; } + inline static void addString(const String &) { throw Exception("Not implemented yet"); } + inline static void decreaseString() { throw Exception("Not implemented yet"); } protected: inline String convertResIntToString(Int64 res) const noexcept { return Decimal256(res).toString(scale); } @@ -171,21 +185,39 @@ class MinOrMaxMocker : public MockerBase inline void add(Int64 data) noexcept { saved_values.push_back(data); } inline void decrease(Int64) noexcept { saved_values.pop_front(); } - inline void reset() noexcept { saved_values.clear(); } + inline void reset() noexcept + { + saved_values.clear(); + saved_string_values.clear(); + } + + inline void addString(const String & data) noexcept { saved_string_values.push_back(data); } + inline void decreaseString() noexcept { saved_string_values.pop_front(); } inline void saveResult() noexcept { - Int64 res = is_max ? std::numeric_limits::min() : std::numeric_limits::max(); - // Inefficient, but it's ok in the ut - for (auto value : saved_values) - cmpAndChange(res, value); - - results.push_back(convertResIntToString(res)); + if (saved_string_values.empty()) + { + Int64 res = saved_values[0]; + auto size = saved_values.size(); + for (size_t i = 1; i < size; i++) + cmpAndChange(res, saved_values[i]); + results.push_back(convertResIntToString(res)); + } + else + { + String res = saved_string_values[0]; + auto size = saved_string_values.size(); + for (size_t i = 1; i < size; i++) + cmpAndChange(res, saved_string_values[i]); + results.push_back(res); + } } private: - static void inline cmpAndChange(Int64 & res, Int64 value) noexcept + template + static void inline cmpAndChange(T & res, T value) noexcept { if constexpr (is_max) { @@ -200,26 +232,56 @@ class MinOrMaxMocker : public MockerBase } std::deque saved_values; + std::deque saved_string_values; }; template struct TestCase { - TestCase(DataTypePtr type_, const std::vector & input_vec_, const String & agg_name_, int scale_) + TestCase( + DataTypePtr type_, + const std::vector & input_int_vec_, + const std::vector & input_string_vec_, + const String & agg_name_, + int scale_) : type(type_) - , input_vec(input_vec_) + , input_int_vec(input_int_vec_) + , input_string_vec(input_string_vec_) , agg_name(agg_name_) , mocker(scale_) {} - inline void addInMock(Int64 row_idx) noexcept { mocker.add(input_vec[row_idx]); } - inline void decreaseInMock(Int64 row_idx) noexcept { mocker.decrease(input_vec[row_idx]); } + inline void addInMock(Int64 row_idx) noexcept + { + if (input_string_vec.empty()) + mocker.add(input_int_vec[row_idx]); + else + mocker.addString(input_string_vec[row_idx]); + } + + inline void decreaseInMock(Int64 row_idx) noexcept + { + if (input_string_vec.empty()) + mocker.decrease(input_int_vec[row_idx]); + else + mocker.decreaseString(); + } + + inline UInt32 getRowNum() noexcept + { + if (input_string_vec.empty()) + return input_int_vec.size(); + else + return input_string_vec.size(); + } + inline void reset() noexcept { mocker.reset(); } inline void saveResult() noexcept { mocker.saveResult(); } inline const std::vector & getResults() noexcept { return mocker.getResults(); } const DataTypePtr type; - const std::vector input_vec; + const std::vector input_int_vec; + const std::vector input_string_vec; const String agg_name; OpMocker mocker; }; @@ -229,6 +291,37 @@ class ExecutorWindowAgg : public DB::tests::AggregationTest public: void SetUp() override { dre = std::default_random_engine(r()); } + static void SetUpTestCase() + { + AggregationTest::SetUpTestCase(); + + std::random_device r; + std::default_random_engine dre(r()); + std::uniform_int_distribution di; + + di.param(std::uniform_int_distribution::param_type{5, 15}); + auto elem_num = di(dre); + for (UInt32 i = 0; i < elem_num; i++) + { + di.param(std::uniform_int_distribution::param_type{0, 64}); + auto len = di(dre); + di.param(std::uniform_int_distribution::param_type{0, CHARACTERS_LEN - 1}); + + String str; + for (UInt32 j = 0; j < len; j++) + { + auto idx = di(dre); + str += CHARACTERS[idx]; + } + input_string_vec.push_back(str); + } + + input_int_col = createColumn(input_int_vec).column; + input_decimal128_col = createColumn(std::make_tuple(10, SCALE), input_decimal_in_string_vec).column; + input_decimal256_col = createColumn(std::make_tuple(30, SCALE), input_decimal_in_string_vec).column; + input_string_col = createColumn(input_string_vec).column; + } + private: // range: [begin, end] inline UInt32 rand(UInt32 begin, UInt32 end) noexcept @@ -249,6 +342,8 @@ class ExecutorWindowAgg : public DB::tests::AggregationTest return &(*ExecutorWindowAgg::input_decimal128_col); else if (const auto * tmp = dynamic_cast(type); tmp != nullptr) return &(*ExecutorWindowAgg::input_decimal256_col); + else if (const auto * tmp = dynamic_cast(type); tmp != nullptr) + return &(*ExecutorWindowAgg::input_string_col); else throw Exception("Invalid data type"); } @@ -263,6 +358,8 @@ class ExecutorWindowAgg : public DB::tests::AggregationTest return std::to_string(field.template get()); case Field::Types::Which::Float64: return std::to_string(field.template get()); + case Field::Types::Which::String: + return field.template get(); case Field::Types::Which::Decimal64: return eliminateTailing(field.template get>().toString()); case Field::Types::Which::Decimal128: @@ -284,24 +381,26 @@ class ExecutorWindowAgg : public DB::tests::AggregationTest std::default_random_engine dre; std::uniform_int_distribution di; - static const ColumnPtr input_int_col; - static const ColumnPtr input_decimal128_col; - static const ColumnPtr input_decimal256_col; + static ColumnPtr input_int_col; + static ColumnPtr input_decimal128_col; + static ColumnPtr input_decimal256_col; + static ColumnPtr input_string_col; static DataTypePtr type_int; static DataTypePtr type_decimal128; static DataTypePtr type_decimal256; + static DataTypePtr type_string; }; -const ColumnPtr ExecutorWindowAgg::input_int_col = createColumn(input_int_vec).column; -const ColumnPtr ExecutorWindowAgg::input_decimal128_col - = createColumn(std::make_tuple(10, scale), input_string_vec).column; -const ColumnPtr ExecutorWindowAgg::input_decimal256_col - = createColumn(std::make_tuple(30, scale), input_string_vec).column; +ColumnPtr ExecutorWindowAgg::input_int_col; +ColumnPtr ExecutorWindowAgg::input_decimal128_col; +ColumnPtr ExecutorWindowAgg::input_decimal256_col; +ColumnPtr ExecutorWindowAgg::input_string_col; DataTypePtr ExecutorWindowAgg::type_int = std::make_shared(); -DataTypePtr ExecutorWindowAgg::type_decimal128 = std::make_shared(10, scale); -DataTypePtr ExecutorWindowAgg::type_decimal256 = std::make_shared(30, scale); +DataTypePtr ExecutorWindowAgg::type_decimal128 = std::make_shared(10, SCALE); +DataTypePtr ExecutorWindowAgg::type_decimal256 = std::make_shared(30, SCALE); +DataTypePtr ExecutorWindowAgg::type_string = std::make_shared(); template void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) @@ -319,7 +418,7 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) agg_func->create(agg_state.data()); agg_func->prepareWindow(agg_state.data()); - const UInt32 col_row_num = test_case.input_vec.size(); + const UInt32 col_row_num = test_case.getRowNum(); UInt32 reset_num = getResetNum(); auto res_col = return_type->createColumn(); @@ -328,6 +427,7 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) // Start test for (UInt32 i = 0; i < reset_num; i++) { + std::cout << "reset ---------" << std::endl; // TODO delete test_case.reset(); agg_func->reset(agg_state.data()); added_row_idx_queue.clear(); @@ -371,9 +471,9 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) Field res_field; for (size_t i = 0; i < res_num; i++) { + std::cout << "i: " << i << std::endl; // TODO delete res_col->get(i, res_field); ASSERT_FALSE(res_field.isNull()); - // No matter what type the result is, we always use decimal to convert the result to string so that it's easy to check result ASSERT_EQ(res_vec[i], getValue(res_field)); } @@ -382,9 +482,9 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) TEST_F(ExecutorWindowAgg, Sum) try { - TestCase int_case(ExecutorWindowAgg::type_int, input_int_vec, "sum", 0); - TestCase decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, "sum", scale); - TestCase decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, "sum", scale); + TestCase int_case(ExecutorWindowAgg::type_int, input_int_vec, {}, "sum", 0); + TestCase decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, {}, "sum", SCALE); + TestCase decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, {}, "sum", SCALE); executeWindowAggTest(int_case); executeWindowAggTest(decimal128_case); @@ -396,9 +496,9 @@ CATCH TEST_F(ExecutorWindowAgg, Count) try { - TestCase int_case(ExecutorWindowAgg::type_int, input_int_vec, "count", 0); - TestCase decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, "count", 0); - TestCase decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, "count", 0); + TestCase int_case(ExecutorWindowAgg::type_int, input_int_vec, {}, "count", 0); + TestCase decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, {}, "count", 0); + TestCase decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, {}, "count", 0); executeWindowAggTest(int_case); executeWindowAggTest(decimal128_case); @@ -409,9 +509,9 @@ CATCH TEST_F(ExecutorWindowAgg, Avg) try { - TestCase int_case(ExecutorWindowAgg::type_int, input_int_vec, "avg", 0); - TestCase decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, "avg", scale); - TestCase decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, "avg", scale); + TestCase int_case(ExecutorWindowAgg::type_int, input_int_vec, {}, "avg", 0); + TestCase decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, {}, "avg", SCALE); + TestCase decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, {}, "avg", SCALE); executeWindowAggTest(int_case); executeWindowAggTest(decimal128_case); @@ -419,41 +519,40 @@ try } CATCH -// TODO use unique_ptr in data // TODO ensure that if data will be called destructor +// TODO DataTypeMyDuration SingleValueDataGeneric and check it indeed be tested TEST_F(ExecutorWindowAgg, Min) try { - // TODO add string type etc... in AggregateFunctionMinMaxAny.h - TestCase> int_case(ExecutorWindowAgg::type_int, input_int_vec, "min", 0); - TestCase> decimal128_case( - ExecutorWindowAgg::type_decimal128, - input_decimal_vec, - "min", - scale); - TestCase> decimal256_case( - ExecutorWindowAgg::type_decimal256, - input_decimal_vec, - "min", - scale); + TestCase> int_case(ExecutorWindowAgg::type_int, input_int_vec, {}, "min", 0); + TestCase> + decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, {}, "min", SCALE); + TestCase> + decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, {}, "min", SCALE); + TestCase> string_case(ExecutorWindowAgg::type_string, {}, input_string_vec, "min", 0); executeWindowAggTest(int_case); executeWindowAggTest(decimal128_case); executeWindowAggTest(decimal256_case); + executeWindowAggTest(string_case); } CATCH +// TODO DataTypeMyDuration SingleValueDataGeneric and check it indeed be tested TEST_F(ExecutorWindowAgg, Max) try { - // TODO add string type etc... in AggregateFunctionMinMaxAny.h - TestCase> int_case(ExecutorWindowAgg::type_int, input_int_vec, "max", 0); - TestCase> decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, "max", scale); - TestCase> decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, "max", scale); + TestCase> int_case(ExecutorWindowAgg::type_int, input_int_vec, {}, "max", 0); + TestCase> + decimal128_case(ExecutorWindowAgg::type_decimal128, input_decimal_vec, {}, "max", SCALE); + TestCase> + decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, {}, "max", SCALE); + TestCase> string_case(ExecutorWindowAgg::type_string, {}, input_string_vec, "max", 0); executeWindowAggTest(int_case); executeWindowAggTest(decimal128_case); executeWindowAggTest(decimal256_case); + executeWindowAggTest(string_case); } CATCH From bf652bf0b618a4cf23835dd5668f63255d6e50d8 Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Thu, 19 Dec 2024 16:31:28 +0800 Subject: [PATCH 17/32] add test for SingleValueDataGeneric type --- .../tests/gtest_window_agg.cpp | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp index 911549d6e9a..f78aeebe959 100644 --- a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp +++ b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp @@ -20,14 +20,14 @@ #include #include #include +#include #include +#include #include #include #include #include -#include "DataTypes/DataTypeString.h" - namespace DB { namespace tests @@ -59,6 +59,7 @@ std::vector input_decimal_vec{ std::stoi(input_decimal_in_string_vec_aux[6]), std::stoi(input_decimal_in_string_vec_aux[7])}; std::vector input_string_vec; +std::vector input_duration_vec{12, 43, 2, 0, 54, 23, 65, 76, 23, 12, 43, 56, 2, 2}; String eliminateTailing(String str) { @@ -320,6 +321,7 @@ class ExecutorWindowAgg : public DB::tests::AggregationTest input_decimal128_col = createColumn(std::make_tuple(10, SCALE), input_decimal_in_string_vec).column; input_decimal256_col = createColumn(std::make_tuple(30, SCALE), input_decimal_in_string_vec).column; input_string_col = createColumn(input_string_vec).column; + input_duration_col = createColumn(input_duration_vec).column; } private: @@ -344,6 +346,8 @@ class ExecutorWindowAgg : public DB::tests::AggregationTest return &(*ExecutorWindowAgg::input_decimal256_col); else if (const auto * tmp = dynamic_cast(type); tmp != nullptr) return &(*ExecutorWindowAgg::input_string_col); + else if (const auto * tmp = dynamic_cast(type); tmp != nullptr) + return &(*ExecutorWindowAgg::input_duration_col); else throw Exception("Invalid data type"); } @@ -385,22 +389,26 @@ class ExecutorWindowAgg : public DB::tests::AggregationTest static ColumnPtr input_decimal128_col; static ColumnPtr input_decimal256_col; static ColumnPtr input_string_col; + static ColumnPtr input_duration_col; static DataTypePtr type_int; static DataTypePtr type_decimal128; static DataTypePtr type_decimal256; static DataTypePtr type_string; + static DataTypePtr type_duration; }; ColumnPtr ExecutorWindowAgg::input_int_col; ColumnPtr ExecutorWindowAgg::input_decimal128_col; ColumnPtr ExecutorWindowAgg::input_decimal256_col; ColumnPtr ExecutorWindowAgg::input_string_col; +ColumnPtr ExecutorWindowAgg::input_duration_col; DataTypePtr ExecutorWindowAgg::type_int = std::make_shared(); DataTypePtr ExecutorWindowAgg::type_decimal128 = std::make_shared(10, SCALE); DataTypePtr ExecutorWindowAgg::type_decimal256 = std::make_shared(30, SCALE); DataTypePtr ExecutorWindowAgg::type_string = std::make_shared(); +DataTypePtr ExecutorWindowAgg::type_duration = std::make_shared(); template void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) @@ -520,7 +528,6 @@ try CATCH // TODO ensure that if data will be called destructor -// TODO DataTypeMyDuration SingleValueDataGeneric and check it indeed be tested TEST_F(ExecutorWindowAgg, Min) try { @@ -530,15 +537,16 @@ try TestCase> decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, {}, "min", SCALE); TestCase> string_case(ExecutorWindowAgg::type_string, {}, input_string_vec, "min", 0); + TestCase> duration_case(ExecutorWindowAgg::type_duration, input_duration_vec, {}, "min", 0); executeWindowAggTest(int_case); executeWindowAggTest(decimal128_case); executeWindowAggTest(decimal256_case); executeWindowAggTest(string_case); + executeWindowAggTest(duration_case); } CATCH -// TODO DataTypeMyDuration SingleValueDataGeneric and check it indeed be tested TEST_F(ExecutorWindowAgg, Max) try { @@ -548,11 +556,13 @@ try TestCase> decimal256_case(ExecutorWindowAgg::type_decimal256, input_decimal_vec, {}, "max", SCALE); TestCase> string_case(ExecutorWindowAgg::type_string, {}, input_string_vec, "max", 0); + TestCase> duration_case(ExecutorWindowAgg::type_duration, input_duration_vec, {}, "max", 0); executeWindowAggTest(int_case); executeWindowAggTest(decimal128_case); executeWindowAggTest(decimal256_case); executeWindowAggTest(string_case); + executeWindowAggTest(duration_case); } CATCH From 61eaef025c780bed3dae01d91f7f3ecd983c1690 Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Fri, 20 Dec 2024 10:40:15 +0800 Subject: [PATCH 18/32] tweaking --- dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp | 4 ---- dbms/src/DataStreams/WindowBlockInputStream.cpp | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp index f78aeebe959..29c761e3e10 100644 --- a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp +++ b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp @@ -435,7 +435,6 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) // Start test for (UInt32 i = 0; i < reset_num; i++) { - std::cout << "reset ---------" << std::endl; // TODO delete test_case.reset(); agg_func->reset(agg_state.data()); added_row_idx_queue.clear(); @@ -479,7 +478,6 @@ void ExecutorWindowAgg::executeWindowAggTest(TestCase & test_case) Field res_field; for (size_t i = 0; i < res_num; i++) { - std::cout << "i: " << i << std::endl; // TODO delete res_col->get(i, res_field); ASSERT_FALSE(res_field.isNull()); // No matter what type the result is, we always use decimal to convert the result to string so that it's easy to check result @@ -500,7 +498,6 @@ try } CATCH -// TODO add count distinct TEST_F(ExecutorWindowAgg, Count) try { @@ -527,7 +524,6 @@ try } CATCH -// TODO ensure that if data will be called destructor TEST_F(ExecutorWindowAgg, Min) try { diff --git a/dbms/src/DataStreams/WindowBlockInputStream.cpp b/dbms/src/DataStreams/WindowBlockInputStream.cpp index de80f7d3a4c..d5f7da0a72f 100644 --- a/dbms/src/DataStreams/WindowBlockInputStream.cpp +++ b/dbms/src/DataStreams/WindowBlockInputStream.cpp @@ -1360,7 +1360,7 @@ void WindowTransformAction::appendBlock(Block & current_block) window_block.input_columns = current_block.getColumns(); } - +// TODO ensure and check that if data's destructor will be called and the allocated memory could be released // Update the aggregation states after the frame has changed. void WindowTransformAction::updateAggregationState() { From 2358efbdf807485199971cda01363a8db0534ac6 Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Fri, 20 Dec 2024 11:41:33 +0800 Subject: [PATCH 19/32] remove something --- dbms/src/Common/AlignedBuffer.cpp | 60 -------- dbms/src/Common/AlignedBuffer.h | 48 ------ .../DataStreams/WindowBlockInputStream.cpp | 73 +-------- dbms/src/DataStreams/WindowBlockInputStream.h | 70 +-------- dbms/src/WindowFunctions/tests/gtest_agg.cpp | 127 ---------------- tests/fullstack-test/mpp/window_agg.test | 139 ------------------ 6 files changed, 5 insertions(+), 512 deletions(-) delete mode 100644 dbms/src/Common/AlignedBuffer.cpp delete mode 100644 dbms/src/Common/AlignedBuffer.h delete mode 100644 dbms/src/WindowFunctions/tests/gtest_agg.cpp delete mode 100644 tests/fullstack-test/mpp/window_agg.test diff --git a/dbms/src/Common/AlignedBuffer.cpp b/dbms/src/Common/AlignedBuffer.cpp deleted file mode 100644 index c9783bc3dfc..00000000000 --- a/dbms/src/Common/AlignedBuffer.cpp +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2023 PingCAP, Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -namespace DB -{ - -namespace ErrorCodes -{ -extern const int CANNOT_ALLOCATE_MEMORY; -} - -void AlignedBuffer::alloc(size_t size, size_t alignment) -{ - void * new_buf; - int res = ::posix_memalign(&new_buf, std::max(alignment, sizeof(void *)), size); - if (0 != res) - throwFromErrno( - fmt::format("Cannot allocate memory (posix_memalign), size: {}, alignment: {}.", size, alignment), - ErrorCodes::CANNOT_ALLOCATE_MEMORY, - res); - buf = new_buf; -} - -void AlignedBuffer::dealloc() -{ - if (buf) - ::free(buf); -} - -void AlignedBuffer::reset(size_t size, size_t alignment) -{ - dealloc(); - alloc(size, alignment); -} - -AlignedBuffer::AlignedBuffer(size_t size, size_t alignment) -{ - alloc(size, alignment); -} - -AlignedBuffer::~AlignedBuffer() -{ - dealloc(); -} - -} // namespace DB diff --git a/dbms/src/Common/AlignedBuffer.h b/dbms/src/Common/AlignedBuffer.h deleted file mode 100644 index cebf596cece..00000000000 --- a/dbms/src/Common/AlignedBuffer.h +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2024 PingCAP, Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include - -namespace DB -{ - -/** Aligned piece of memory. - * It can only be allocated and destroyed. - * MemoryTracker is not used. AlignedBuffer is intended for small pieces of memory. - */ -class AlignedBuffer : private boost::noncopyable -{ -private: - void * buf = nullptr; - - void alloc(size_t size, size_t alignment); - void dealloc(); - -public: - AlignedBuffer() = default; - AlignedBuffer(size_t size, size_t alignment); - AlignedBuffer(AlignedBuffer && old) noexcept { std::swap(buf, old.buf); } - ~AlignedBuffer(); - - void reset(size_t size, size_t alignment); - - char * data() { return static_cast(buf); } - const char * data() const { return static_cast(buf); } -}; - -} // namespace DB diff --git a/dbms/src/DataStreams/WindowBlockInputStream.cpp b/dbms/src/DataStreams/WindowBlockInputStream.cpp index c10071583ec..49a5676b4fa 100644 --- a/dbms/src/DataStreams/WindowBlockInputStream.cpp +++ b/dbms/src/DataStreams/WindowBlockInputStream.cpp @@ -300,32 +300,11 @@ void WindowTransformAction::initialWorkspaces() WindowFunctionWorkspace workspace; workspace.window_function = window_function_description.window_function; workspace.arguments = window_function_description.arguments; - workspace.argument_column_indices = window_function_description.arguments; - workspace.argument_columns.assign(workspace.argument_column_indices.size(), nullptr); - initialAggregateFunction(workspace, window_function_description); workspaces.push_back(std::move(workspace)); } only_have_row_number = onlyHaveRowNumber(); } -void WindowTransformAction::initialAggregateFunction( - WindowFunctionWorkspace & workspace, - const WindowFunctionDescription & window_function_description) -{ - if (window_function_description.aggregate_function == nullptr) - return; - - has_agg = true; - - workspace.aggregate_function = window_function_description.aggregate_function; - const auto & aggregate_function = workspace.aggregate_function; - if (!arena && aggregate_function->allocatesMemoryInArena()) - arena = std::make_unique(); - - workspace.aggregate_function_state.reset(aggregate_function->sizeOfData(), aggregate_function->alignOfData()); - aggregate_function->create(workspace.aggregate_function_state.data()); -} - bool WindowBlockInputStream::returnIfCancelledOrKilled() { if (isCancelledOrThrowIfKilled()) @@ -1258,18 +1237,6 @@ void WindowTransformAction::writeOutCurrentRow() for (size_t wi = 0; wi < workspaces.size(); ++wi) { auto & ws = workspaces[wi]; - if (ws.window_function) - ws.window_function->windowInsertResultInto(*this, wi, ws.arguments); - else - { - const auto & block = blockAt(current_row); - IColumn * result_column = block.output_columns[wi].get(); - const auto * agg_func = ws.aggregate_function.get(); - auto * buf = ws.aggregate_function_state.data(); - - // TODO add `insertMergeResultInto` function? - agg_func->insertResultInto(buf, *result_column, arena.get()); - } } } @@ -1307,7 +1274,7 @@ bool WindowTransformAction::onlyHaveRowNumber() { for (const auto & workspace : workspaces) { - if (workspace.window_function != nullptr && workspace.window_function->getName() != "row_number") + if (workspace.window_function->getName() != "row_number") return false; } return true; @@ -1356,11 +1323,7 @@ void WindowTransformAction::appendBlock(Block & current_block) // Initialize output columns and add new columns to output block. for (auto & ws : workspaces) { - MutableColumnPtr res; - if (ws.window_function != nullptr) - res = ws.window_function->getReturnType()->createColumn(); - else - res = ws.aggregate_function->getReturnType()->createColumn(); + MutableColumnPtr res = ws.window_function->getReturnType()->createColumn(); res->reserve(window_block.rows); window_block.output_columns.push_back(std::move(res)); } @@ -1368,36 +1331,6 @@ void WindowTransformAction::appendBlock(Block & current_block) window_block.input_columns = current_block.getColumns(); } -// TODO ensure and check that if data's destructor will be called and the allocated memory could be released -// Update the aggregation states after the frame has changed. -void WindowTransformAction::updateAggregationState() -{ - if (!has_agg) - return; - - assert(frame_started); - assert(frame_ended); - assert(frame_start <= frame_end); - assert(prev_frame_start <= prev_frame_end); - assert(prev_frame_start <= frame_start); - assert(prev_frame_end <= frame_end); - assert(partition_start <= frame_start); - assert(frame_end <= partition_end); - - for (auto & ws : workspaces) - { - if (ws.window_function) - continue; - - // TODO compare the decrease and add number in previous frame - // when decrease > add, we create a new agg data to recalculate from start. - const RowNumber & end = frame_start <= prev_frame_end ? frame_start : prev_frame_end; - decreaseAggregationState(ws, prev_frame_start, end); - const RowNumber & start = frame_start <= prev_frame_end ? prev_frame_end : frame_start; - addAggregationState(ws, start, end); - } -} - void WindowTransformAction::tryCalculate() { // if there is no input data, we don't need to calculate @@ -1460,8 +1393,6 @@ void WindowTransformAction::tryCalculate() assert(frame_started); assert(frame_ended); - updateAggregationState(); - // Write out the results. // TODO execute the window function by block instead of row. writeOutCurrentRow(); diff --git a/dbms/src/DataStreams/WindowBlockInputStream.h b/dbms/src/DataStreams/WindowBlockInputStream.h index f4b7f9c3daa..b3c398b8356 100644 --- a/dbms/src/DataStreams/WindowBlockInputStream.h +++ b/dbms/src/DataStreams/WindowBlockInputStream.h @@ -52,6 +52,9 @@ struct WindowTransformAction Block tryGetOutputBlock(); void releaseAlreadyOutputWindowBlock(); + void initialWorkspaces(); + void initialPartitionAndOrderColumnIndices(); + Columns & inputAt(const RowNumber & x) { assert(x.block >= first_block_number); @@ -162,69 +165,6 @@ struct WindowTransformAction // distance is left - right. UInt64 distance(RowNumber left, RowNumber right); - void initialWorkspaces(); - void initialPartitionAndOrderColumnIndices(); - void initialAggregateFunction( - WindowFunctionWorkspace & workspace, - const WindowFunctionDescription & window_function_description); - - void addAggregationState(WindowFunctionWorkspace & ws, const RowNumber & start, const RowNumber & end) - { - addOrDecreaseAggregationState(ws, start, end); - } - - void decreaseAggregationState(WindowFunctionWorkspace & ws, const RowNumber & start, const RowNumber & end) - { - addOrDecreaseAggregationState(ws, start, end); - } - - template - void addOrDecreaseAggregationState(WindowFunctionWorkspace & ws, const RowNumber & start, const RowNumber & end) - { - if unlikely (start == end) - return; - - const auto * agg_func = ws.aggregate_function.get(); - auto * buf = ws.aggregate_function_state.data(); - - // Used for aggregate function. - // To achieve better performance, we will have to loop over blocks and - // rows manually, instead of using advanceRowNumber(). - // For this purpose, the end block can be different than the - // block of the end row (it's usually the next block). - const auto past_the_end_block = end.row == 0 ? end.block : end.block + 1; - - for (auto block_number = frame_start.block; block_number < past_the_end_block; ++block_number) - { - auto & block = blockAt(block_number); - - if (ws.cached_block_number != block_number) - { - for (size_t i = 0; i < ws.argument_column_indices.size(); ++i) - ws.argument_columns[i] = block.input_columns[ws.argument_column_indices[i]].get(); - ws.cached_block_number = block_number; - } - - // First and last blocks may be processed partially, and other blocks are processed in full. - const auto start_row = block_number == start.block ? start.row : 0; - const auto end_row = block_number == end.block ? end.row : block.rows; - auto * columns = ws.argument_columns.data(); - - auto * arena_ptr = arena.get(); - for (auto row = start_row; row < end_row; ++row) - { - if constexpr (is_add) - agg_func->add(buf, columns, row, arena_ptr); - else - agg_func->decrease(buf, columns, row, arena_ptr); - } - } - } - - void updateAggregationState(); - - void reinitializeAggFuncBeforeNextPartition(); - public: LoggerPtr log; @@ -242,8 +182,6 @@ struct WindowTransformAction // Per-window-function scratch spaces. std::vector workspaces; - bool has_agg; - // A sliding window of blocks we currently need. We add the input blocks as // they arrive, and discard the blocks we don't need anymore. The blocks // have an always-incrementing index. The index of the first block is in @@ -306,8 +244,6 @@ struct WindowTransformAction // Auxiliary variable for range frame type when calculating frame_end RowNumber prev_frame_end; - std::unique_ptr arena; - //TODO: used as template parameters bool only_have_row_number = false; }; diff --git a/dbms/src/WindowFunctions/tests/gtest_agg.cpp b/dbms/src/WindowFunctions/tests/gtest_agg.cpp deleted file mode 100644 index 84611237967..00000000000 --- a/dbms/src/WindowFunctions/tests/gtest_agg.cpp +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright 2024 PingCAP, Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include - - -namespace DB::tests -{ -class WindowAggFuncTest : public DB::tests::WindowTest -{ -public: - const ASTPtr value_col = col(VALUE_COL_NAME); - - void initializeContext() override { ExecutorTest::initializeContext(); } -}; - -TEST_F(WindowAggFuncTest, windowAggSumTests) -try -{ - { - // rows frame - MockWindowFrame frame; - frame.type = tipb::WindowFrameType::Rows; - frame.start = mock::MockWindowFrameBound(tipb::WindowBoundType::Preceding, false, 0); - frame.end = mock::MockWindowFrameBound(tipb::WindowBoundType::Following, false, 3); - std::vector frame_start_offset{0, 1, 3, 10}; - - std::vector> res{ - {0, 15, 14, 12, 8, 26, 41, 38, 28, 15, 18, 32, 49, 75, 66, 51, 31}, - {0, 15, 15, 14, 12, 26, 41, 41, 38, 28, 18, 33, 52, 80, 75, 66, 51}, - {0, 15, 15, 15, 15, 26, 41, 41, 41, 41, 18, 33, 53, 84, 83, 80, 75}, - {0, 15, 15, 15, 15, 26, 41, 41, 41, 41, 18, 33, 53, 84, 84, 84, 84}}; - - for (size_t i = 0; i < frame_start_offset.size(); ++i) - { - frame.start = mock::MockWindowFrameBound(tipb::WindowBoundType::Preceding, false, frame_start_offset[i]); - - executeFunctionAndAssert( - toVec(res[i]), - Sum(value_col), - {toVec(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3}), - toVec(/*order*/ {0, 1, 2, 4, 8, 0, 3, 10, 13, 15, 1, 3, 5, 9, 15, 20, 31}), - toVec(/*value*/ {0, 1, 2, 4, 8, 0, 3, 10, 13, 15, 1, 3, 5, 9, 15, 20, 31})}, - frame); - } - } - - // TODO uncomment these test after range frame is merged - // { - // // range frame - // MockWindowFrame frame; - // frame.type = tipb::WindowFrameType::Rows; - // frame.start = buildRangeFrameBound(tipb::WindowBoundType::Preceding, tipb::RangeCmpDataType::Int, ORDER_COL_NAME, false, 0); - // frame.end = buildRangeFrameBound(tipb::WindowBoundType::Following, tipb::RangeCmpDataType::Int, ORDER_COL_NAME, true, 3); - // std::vector frame_start_offset{0, 1, 3, 10}; - - // std::vector> res_not_null{ - // {0, 7, 6, 4, 8, 3, 3, 23, 28, 15, 4, 8, 5, 9, 15, 20, 31}, - // {0, 7, 7, 4, 8, 3, 3, 23, 28, 15, 4, 8, 5, 9, 15, 20, 31}, - // {0, 7, 7, 7, 8, 3, 3, 23, 38, 28, 4, 9, 8, 9, 15, 20, 31}, - // {0, 7, 7, 7, 15, 3, 3, 26, 41, 38, 4, 9, 9, 18, 29, 35, 31}}; - - // for (size_t i = 0; i < frame_start_offset.size(); ++i) - // { - // frame.start = buildRangeFrameBound(tipb::WindowBoundType::Preceding, tipb::RangeCmpDataType::Int, ORDER_COL_NAME, false, 0); - - // executeFunctionAndAssert( - // toVec(res_not_null[i]), - // Sum(value_col), - // {toVec(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3}), - // toVec(/*order*/ {0, 1, 2, 4, 8, 0, 3, 10, 13, 15, 1, 3, 5, 9, 15, 20, 31}), - // toVec(/*value*/ {0, 1, 2, 4, 8, 0, 3, 10, 13, 15, 1, 3, 5, 9, 15, 20, 31})}, - // frame); - // } - // } -} -CATCH - -TEST_F(WindowAggFuncTest, windowAggCountTests) -try -{ - { - // rows frame - MockWindowFrame frame; - frame.type = tipb::WindowFrameType::Rows; - frame.start = mock::MockWindowFrameBound(tipb::WindowBoundType::Preceding, false, 0); - frame.end = mock::MockWindowFrameBound(tipb::WindowBoundType::Following, false, 3); - std::vector frame_start_offset{0, 1, 3, 10}; - - std::vector> res{ - {1, 4, 3, 2, 1, 4, 4, 3, 2, 1, 4, 4, 4, 4, 3, 2, 1}, - {1, 4, 4, 3, 2, 4, 5, 4, 3, 2, 4, 5, 5, 5, 4, 3, 2}, - {1, 4, 4, 4, 4, 4, 5, 5, 5, 4, 4, 5, 6, 7, 6, 5, 4}, - {1, 4, 4, 4, 4, 4, 5, 5, 5, 5, 4, 5, 6, 7, 7, 7, 7}}; - - for (size_t i = 0; i < frame_start_offset.size(); ++i) - { - frame.start = mock::MockWindowFrameBound(tipb::WindowBoundType::Preceding, false, frame_start_offset[i]); - - executeFunctionAndAssert( - toVec(res[i]), - Count(value_col), - {toVec(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3}), - toVec(/*order*/ {0, 1, 2, 4, 8, 0, 3, 10, 13, 15, 1, 3, 5, 9, 15, 20, 31}), - toVec(/*value*/ {0, 1, 2, 4, 8, 0, 3, 10, 13, 15, 1, 3, 5, 9, 15, 20, 31})}, - frame); - } - } - // TODO add range frame tests after that is merged -} -CATCH -} // namespace DB::tests diff --git a/tests/fullstack-test/mpp/window_agg.test b/tests/fullstack-test/mpp/window_agg.test deleted file mode 100644 index 701eb4916f2..00000000000 --- a/tests/fullstack-test/mpp/window_agg.test +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright 2024 PingCAP, Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -mysql> drop table if exists test.agg; -mysql> create table test.agg(p int not null, o int not null, v int not null); -mysql> insert into test.agg (p, o, v) values (0, 0, 0), (1, 1, 1), (1, 2, 2), (1, 4, 4), (1, 8, 8), (2, 0, 0), (2, 3, 3), (2, 10, 10), (2, 13, 13), (2, 15, 15), (3, 1, 1), (3, 3, 3), (3, 5, 5), (3, 9, 9), (3, 15, 15), (3, 20, 20), (3, 31, 31); -mysql> alter table agg set tiflash replica 1; - -func> wait_table test test.agg - -mysql> use test; set tidb_enforce_mpp=1; - -//TODO ast.AggFuncSum, ast.AggFuncCount, ast.AggFuncAvg, ast.AggFuncMax, ast.AggFuncMin ... - -mysql> use test; set tidb_enforce_mpp=1; select *, sum(v) over (partition by p order by o rows between 1 preceding and 1 following) as a from test.agg; -+---+----+----+------+ -| p | o | v | a | -+---+----+----+------+ -| 0 | 0 | 0 | 0 | -| 1 | 1 | 1 | 3 | -| 1 | 2 | 2 | 7 | -| 1 | 4 | 4 | 14 | -| 1 | 8 | 8 | 12 | -| 2 | 0 | 0 | 3 | -| 2 | 3 | 3 | 13 | -| 2 | 10 | 10 | 26 | -| 2 | 13 | 13 | 38 | -| 2 | 15 | 15 | 28 | -| 3 | 1 | 1 | 4 | -| 3 | 3 | 3 | 9 | -| 3 | 5 | 5 | 17 | -| 3 | 9 | 9 | 29 | -| 3 | 15 | 15 | 44 | -| 3 | 20 | 20 | 66 | -| 3 | 31 | 31 | 51 | -+---+----+----+------+ - -mysql> use test; set tidb_enforce_mpp=1; select *, count(v) over (partition by p order by o rows between 1 preceding and 1 following) as a from test.agg; -+---+----+----+---+ -| p | o | v | a | -+---+----+----+---+ -| 0 | 0 | 0 | 1 | -| 1 | 1 | 1 | 2 | -| 1 | 2 | 2 | 3 | -| 1 | 4 | 4 | 3 | -| 1 | 8 | 8 | 2 | -| 2 | 0 | 0 | 2 | -| 2 | 3 | 3 | 3 | -| 2 | 10 | 10 | 3 | -| 2 | 13 | 13 | 3 | -| 2 | 15 | 15 | 2 | -| 3 | 1 | 1 | 2 | -| 3 | 3 | 3 | 3 | -| 3 | 5 | 5 | 3 | -| 3 | 9 | 9 | 3 | -| 3 | 15 | 15 | 3 | -| 3 | 20 | 20 | 3 | -| 3 | 31 | 31 | 2 | -+---+----+----+---+ - -mysql> use test; set tidb_enforce_mpp=1; select *, avg(v) over (partition by p order by o rows between 1 preceding and 1 following) as a from test.agg; -+---+----+----+---------+ -| p | o | v | a | -+---+----+----+---------+ -| 0 | 0 | 0 | 0.0000 | -| 1 | 1 | 1 | 1.5000 | -| 1 | 2 | 2 | 2.3333 | -| 1 | 4 | 4 | 4.6666 | -| 1 | 8 | 8 | 6.0000 | -| 2 | 0 | 0 | 1.5000 | -| 2 | 3 | 3 | 4.3333 | -| 2 | 10 | 10 | 8.6666 | -| 2 | 13 | 13 | 12.6666 | -| 2 | 15 | 15 | 14.0000 | -| 3 | 1 | 1 | 2.0000 | -| 3 | 3 | 3 | 3.0000 | -| 3 | 5 | 5 | 5.6666 | -| 3 | 9 | 9 | 9.6666 | -| 3 | 15 | 15 | 14.6666 | -| 3 | 20 | 20 | 22.0000 | -| 3 | 31 | 31 | 25.5000 | -+---+----+----+---------+ - -mysql> use test; set tidb_enforce_mpp=1; select *, min(v) over (partition by p order by o rows between 1 preceding and 1 following) as a from test.agg; -+---+----+----+------+ -| p | o | v | a | -+---+----+----+------+ -| 0 | 0 | 0 | 0 | -| 1 | 1 | 1 | 1 | -| 1 | 2 | 2 | 1 | -| 1 | 4 | 4 | 2 | -| 1 | 8 | 8 | 4 | -| 2 | 0 | 0 | 0 | -| 2 | 3 | 3 | 0 | -| 2 | 10 | 10 | 3 | -| 2 | 13 | 13 | 10 | -| 2 | 15 | 15 | 13 | -| 3 | 1 | 1 | 1 | -| 3 | 3 | 3 | 1 | -| 3 | 5 | 5 | 3 | -| 3 | 9 | 9 | 5 | -| 3 | 15 | 15 | 9 | -| 3 | 20 | 20 | 15 | -| 3 | 31 | 31 | 20 | -+---+----+----+------+ - -mysql> use test; set tidb_enforce_mpp=1; select *, max(v) over (partition by p order by o rows between 1 preceding and 1 following) as a from test.agg; -+---+----+----+------+ -| p | o | v | a | -+---+----+----+------+ -| 0 | 0 | 0 | 0 | -| 1 | 1 | 1 | 2 | -| 1 | 2 | 2 | 4 | -| 1 | 4 | 4 | 8 | -| 1 | 8 | 8 | 8 | -| 2 | 0 | 0 | 3 | -| 2 | 3 | 3 | 10 | -| 2 | 10 | 10 | 13 | -| 2 | 13 | 13 | 15 | -| 2 | 15 | 15 | 15 | -| 3 | 1 | 1 | 3 | -| 3 | 3 | 3 | 5 | -| 3 | 5 | 5 | 9 | -| 3 | 9 | 9 | 15 | -| 3 | 15 | 15 | 20 | -| 3 | 20 | 20 | 31 | -| 3 | 31 | 31 | 31 | -+---+----+----+------+ From 3227b509754edd247b422a500a9f178cb07a33fa Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Fri, 20 Dec 2024 11:47:59 +0800 Subject: [PATCH 20/32] revoke --- .../DataStreams/WindowBlockInputStream.cpp | 2 +- dbms/src/Debug/MockExecutor/WindowBinder.cpp | 184 +++++------------- dbms/src/Debug/MockExecutor/WindowBinder.h | 4 - .../Coprocessor/DAGExpressionAnalyzer.cpp | 75 ++----- .../Flash/Coprocessor/DAGExpressionAnalyzer.h | 5 +- dbms/src/Flash/Coprocessor/DAGUtils.cpp | 2 +- dbms/src/Interpreters/WindowDescription.h | 5 +- dbms/src/WindowFunctions/IWindowFunction.h | 14 +- 8 files changed, 73 insertions(+), 218 deletions(-) diff --git a/dbms/src/DataStreams/WindowBlockInputStream.cpp b/dbms/src/DataStreams/WindowBlockInputStream.cpp index 49a5676b4fa..136812d31d0 100644 --- a/dbms/src/DataStreams/WindowBlockInputStream.cpp +++ b/dbms/src/DataStreams/WindowBlockInputStream.cpp @@ -246,7 +246,6 @@ WindowTransformAction::WindowTransformAction( const String & req_id) : log(Logger::get(req_id)) , window_description(window_description_) - , has_agg(false) { output_header = input_header; for (const auto & add_column : window_description_.add_columns) @@ -1237,6 +1236,7 @@ void WindowTransformAction::writeOutCurrentRow() for (size_t wi = 0; wi < workspaces.size(); ++wi) { auto & ws = workspaces[wi]; + ws.window_function->windowInsertResultInto(*this, wi, ws.arguments); } } diff --git a/dbms/src/Debug/MockExecutor/WindowBinder.cpp b/dbms/src/Debug/MockExecutor/WindowBinder.cpp index 844cdf2298b..32f2e6b4d18 100644 --- a/dbms/src/Debug/MockExecutor/WindowBinder.cpp +++ b/dbms/src/Debug/MockExecutor/WindowBinder.cpp @@ -32,10 +32,12 @@ namespace // // Other window or aggregation functions always return nullable column and we need to // remove the not null flag for them. -void setFieldTypeForWindowFunc(tipb::Expr * window_expr, const tipb::ExprType window_sig, const int32_t collator_id) +void setWindowFieldType( + const tipb::ExprType window_sig, + tipb::FieldType * window_field_type, + tipb::Expr * window_expr, + const int32_t collator_id) { - window_expr->set_tp(window_sig); - auto * window_field_type = window_expr->mutable_field_type(); switch (window_sig) { case tipb::ExprType::Lead: @@ -84,51 +86,13 @@ void setFieldTypeForWindowFunc(tipb::Expr * window_expr, const tipb::ExprType wi } } -void setFieldTypeForAggFunc( - const DB::ASTFunction * func, - tipb::Expr * expr, - const tipb::ExprType agg_sig, - int32_t collator_id) +tipb::ExprType getWindowSig(const String & window_func_name) { - expr->set_tp(agg_sig); - if (agg_sig == tipb::ExprType::Count || agg_sig == tipb::ExprType::Sum) - { - auto * ft = expr->mutable_field_type(); - ft->set_tp(TiDB::TypeLongLong); - } - else if (agg_sig == tipb::ExprType::Min || agg_sig == tipb::ExprType::Max) - { - if (expr->children_size() != 1) - throw Exception(fmt::format("Agg function({}) only accept 1 argument", func->name)); - - auto * ft = expr->mutable_field_type(); - ft->set_tp(expr->children(0).field_type().tp()); - ft->set_decimal(expr->children(0).field_type().decimal()); - ft->set_collate(collator_id); - } - else - { - throw Exception("Window does not support this agg function"); - } - - expr->set_aggfuncmode(tipb::AggFunctionMode::FinalMode); -} - -void setFieldType(const DB::ASTFunction * func, tipb::Expr * expr, int32_t collator_id) -{ - auto window_sig_it = tests::window_func_name_to_sig.find(func->name); - if (window_sig_it != tests::window_func_name_to_sig.end()) - { - setFieldTypeForWindowFunc(expr, window_sig_it->second, collator_id); - return; - } - - auto agg_sig_it = tests::agg_func_name_to_sig.find(func->name); - if (agg_sig_it == tests::agg_func_name_to_sig.end()) - throw Exception("Unsupported agg function: " + func->name, ErrorCodes::LOGICAL_ERROR); + auto window_sig_it = tests::window_func_name_to_sig.find(window_func_name); + if (window_sig_it == tests::window_func_name_to_sig.end()) + throw Exception(fmt::format("Unsupported window function {}", window_func_name), ErrorCodes::LOGICAL_ERROR); - auto agg_sig = agg_sig_it->second; - setFieldTypeForAggFunc(func, expr, agg_sig, collator_id); + return window_sig_it->second; } void setWindowFrame(MockWindowFrame & frame, tipb::Window * window) @@ -189,7 +153,9 @@ bool WindowBinder::toTiPBExecutor( astToPB(input_schema, arg, window_expr->add_children(), collator_id, context); } - setFieldType(window_func, window_expr, collator_id); + auto window_sig = getWindowSig(window_func->name); + window_expr->set_tp(window_sig); + setWindowFieldType(window_sig, window_expr->mutable_field_type(), window_expr, collator_id); } for (const auto & child : order_by_exprs) @@ -217,89 +183,6 @@ bool WindowBinder::toTiPBExecutor( return children[0]->toTiPBExecutor(window->mutable_child(), collator_id, mpp_info, context); } -// This function can only be used in window agg -void setColumnInfoForAggInWindow( - TiDB::ColumnInfo & ci, - const DB::ASTFunction * func, - const std::vector & children_ci) -{ - // TODO: Other agg func. - if (func->name == "count") - { - ci.tp = TiDB::TypeLongLong; - ci.flag = TiDB::ColumnFlagUnsigned; - } - else if (func->name == "max" || func->name == "min" || func->name == "sum") - { - ci = children_ci[0]; - ci.flag &= ~TiDB::ColumnFlagNotNull; - } - else - { - throw Exception("Unsupported agg function: " + func->name, ErrorCodes::LOGICAL_ERROR); - } -} - -void setColumnInfoForWindowFunc( - TiDB::ColumnInfo & ci, - const DB::ASTFunction * func, - const std::vector & children_ci, - tipb::ExprType expr_type) -{ - // TODO: add more window functions - switch (expr_type) - { - case tipb::ExprType::RowNumber: - case tipb::ExprType::Rank: - case tipb::ExprType::DenseRank: - { - ci.tp = TiDB::TypeLongLong; - ci.flag = TiDB::ColumnFlagBinary; - break; - } - case tipb::ExprType::Lead: - case tipb::ExprType::Lag: - { - // TODO handling complex situations - // like lead(col, offset, NULL), lead(data_type1, offset, data_type2) - assert(!children_ci.empty() && children_ci.size() <= 3); - if (children_ci.size() < 3) - { - ci = children_ci[0]; - ci.clearNotNullFlag(); - } - else - { - assert(children_ci[0].tp == children_ci[2].tp); - ci = children_ci[0].hasNotNullFlag() ? children_ci[2] : children_ci[0]; - } - break; - } - case tipb::ExprType::FirstValue: - case tipb::ExprType::LastValue: - { - ci = children_ci[0]; - break; - } - default: - throw Exception(fmt::format("Unsupported window function {}", func->name), ErrorCodes::LOGICAL_ERROR); - } -} - -TiDB::ColumnInfo createColumnInfo(const DB::ASTFunction * func, const std::vector & children_ci) -{ - TiDB::ColumnInfo ci; - auto iter = tests::window_func_name_to_sig.find(func->name); - if (iter != tests::window_func_name_to_sig.end()) - { - setColumnInfoForWindowFunc(ci, func, children_ci, iter->second); - return ci; - } - - setColumnInfoForAggInWindow(ci, func, children_ci); - return ci; -} - ExecutorBinderPtr compileWindow( ExecutorBinderPtr input, size_t & executor_index, @@ -358,7 +241,46 @@ ExecutorBinderPtr compileWindow( { children_ci.push_back(compileExpr(input->output_schema, arg)); } - TiDB::ColumnInfo ci = createColumnInfo(func, children_ci); + // TODO: add more window functions + TiDB::ColumnInfo ci; + switch (tests::window_func_name_to_sig[func->name]) + { + case tipb::ExprType::RowNumber: + case tipb::ExprType::Rank: + case tipb::ExprType::DenseRank: + { + ci.tp = TiDB::TypeLongLong; + ci.flag = TiDB::ColumnFlagBinary; + break; + } + case tipb::ExprType::Lead: + case tipb::ExprType::Lag: + { + // TODO handling complex situations + // like lead(col, offset, NULL), lead(data_type1, offset, data_type2) + assert(!children_ci.empty() && children_ci.size() <= 3); + if (children_ci.size() < 3) + { + ci = children_ci[0]; + ci.clearNotNullFlag(); + } + else + { + assert(children_ci[0].tp == children_ci[2].tp); + ci = children_ci[0].hasNotNullFlag() ? children_ci[2] : children_ci[0]; + } + break; + } + case tipb::ExprType::FirstValue: + case tipb::ExprType::LastValue: + { + ci = children_ci[0]; + ci.clearNotNullFlag(); + break; + } + default: + throw Exception(fmt::format("Unsupported window function {}", func->name), ErrorCodes::LOGICAL_ERROR); + } output_schema.emplace_back(std::make_pair(func->getColumnName(), ci)); } } diff --git a/dbms/src/Debug/MockExecutor/WindowBinder.h b/dbms/src/Debug/MockExecutor/WindowBinder.h index 87b340fb834..11324167d18 100644 --- a/dbms/src/Debug/MockExecutor/WindowBinder.h +++ b/dbms/src/Debug/MockExecutor/WindowBinder.h @@ -165,10 +165,6 @@ class WindowBinder : public ExecutorBinder const MPPInfo & mpp_info, const Context & context) override; -private: - void buildWindowFunc(); - void buildAggFunc(); - private: std::vector func_descs; std::vector partition_by_exprs; diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp index 83810a680d2..f419a693ff0 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp @@ -177,19 +177,17 @@ void appendAggDescription( /// Generate WindowFunctionDescription and append it to WindowDescription if need. void appendWindowDescription( - const Context & context, const Names & arg_names, const DataTypes & arg_types, TiDB::TiDBCollators & arg_collators, - const String & func_name, + const String & window_func_name, WindowDescription & window_description, NamesAndTypes & source_columns, - NamesAndTypes & window_columns, - bool is_agg) + NamesAndTypes & window_columns) { assert(arg_names.size() == arg_collators.size() && arg_names.size() == arg_types.size()); - String func_string = genFuncString(func_name, arg_names, arg_collators); + String func_string = genFuncString(window_func_name, arg_names, arg_collators); if (auto duplicated_return_type = findDuplicateAggWindowFunc(func_string, window_description.window_functions_descriptions)) { @@ -201,44 +199,13 @@ void appendWindowDescription( WindowFunctionDescription window_function_description; window_function_description.argument_names = arg_names; window_function_description.column_name = func_string; - - DataTypePtr result_type; - if (is_agg) - { - window_function_description.aggregate_function - = AggregateFunctionFactory::instance().get(context, func_name, arg_types, {}, 0, true); - result_type = window_function_description.aggregate_function->getReturnType(); - } - else - { - window_function_description.window_function = WindowFunctionFactory::instance().get(func_name, arg_types); - result_type = window_function_description.window_function->getReturnType(); - } - + window_function_description.window_function = WindowFunctionFactory::instance().get(window_func_name, arg_types); + DataTypePtr result_type = window_function_description.window_function->getReturnType(); window_description.window_functions_descriptions.emplace_back(std::move(window_function_description)); window_columns.emplace_back(func_string, result_type); source_columns.emplace_back(func_string, result_type); } -bool isWindowFunction(const tipb::ExprType expr_type) -{ - switch (expr_type) - { - case tipb::ExprType::FirstValue: - case tipb::ExprType::LastValue: - case tipb::ExprType::RowNumber: - case tipb::ExprType::Rank: - case tipb::ExprType::DenseRank: - case tipb::ExprType::CumeDist: - case tipb::ExprType::PercentRank: - case tipb::ExprType::Ntile: - case tipb::ExprType::NthValue: - return true; - default: - return false; - } -} - void setAuxiliaryColumnInfoImpl( const String & aux_col_name, const Block & tmp_block, @@ -864,25 +831,22 @@ void DAGExpressionAnalyzer::buildLeadLag( } appendWindowDescription( - context, arg_names, arg_types, arg_collators, window_func_name, window_description, source_columns, - window_columns, - false); + window_columns); } -void DAGExpressionAnalyzer::buildWindowOrAggFuncImpl( +void DAGExpressionAnalyzer::buildCommonWindowFunc( const tipb::Expr & expr, const ExpressionActionsPtr & actions, const String & window_func_name, WindowDescription & window_description, NamesAndTypes & source_columns, - NamesAndTypes & window_columns, - bool is_agg) + NamesAndTypes & window_columns) { auto child_size = expr.children_size(); Names arg_names; @@ -894,15 +858,13 @@ void DAGExpressionAnalyzer::buildWindowOrAggFuncImpl( } appendWindowDescription( - context, arg_names, arg_types, arg_collators, window_func_name, window_description, source_columns, - window_columns, - is_agg); + window_columns); } // This function will add new window function culumns to source_column @@ -917,6 +879,7 @@ void DAGExpressionAnalyzer::appendWindowColumns( NamesAndTypes window_columns; for (const tipb::Expr & expr : window.func_desc()) { + RUNTIME_CHECK_MSG(isWindowFunctionExpr(expr), "Now Window Operator only support window function."); if (expr.tp() == tipb::ExprType::Lead || expr.tp() == tipb::ExprType::Lag) { buildLeadLag( @@ -927,27 +890,15 @@ void DAGExpressionAnalyzer::appendWindowColumns( source_columns, window_columns); } - else if (isWindowFunction(expr.tp())) - { - buildWindowOrAggFuncImpl( - expr, - actions, - getWindowFunctionName(expr), - window_description, - source_columns, - window_columns, - false); - } else { - buildWindowOrAggFuncImpl( + buildCommonWindowFunc( expr, actions, - getAggFunctionName(expr), + getWindowFunctionName(expr), window_description, source_columns, - window_columns, - true); + window_columns); } } window_description.add_columns = window_columns; diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h index b79a2743f9c..8ef4dbc0b78 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h @@ -228,14 +228,13 @@ class DAGExpressionAnalyzer : private boost::noncopyable NamesAndTypes & source_columns, NamesAndTypes & window_columns); - void buildWindowOrAggFuncImpl( + void buildCommonWindowFunc( const tipb::Expr & expr, const ExpressionActionsPtr & actions, const String & window_func_name, WindowDescription & window_description, NamesAndTypes & source_columns, - NamesAndTypes & window_columns, - bool is_agg); + NamesAndTypes & window_columns); void fillArgumentDetail( const ExpressionActionsPtr & actions, diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.cpp b/dbms/src/Flash/Coprocessor/DAGUtils.cpp index 5cea8f77c78..1a8ac4e0167 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.cpp +++ b/dbms/src/Flash/Coprocessor/DAGUtils.cpp @@ -57,7 +57,7 @@ const std::unordered_map agg_func_map({ {tipb::ExprType::First, "first_row"}, {tipb::ExprType::ApproxCountDistinct, uniq_raw_res_name}, {tipb::ExprType::GroupConcat, "groupArray"}, - {tipb::ExprType::Avg, "avg"}, + //{tipb::ExprType::Avg, ""}, //{tipb::ExprType::Agg_BitAnd, ""}, //{tipb::ExprType::Agg_BitOr, ""}, //{tipb::ExprType::Agg_BitXor, ""}, diff --git a/dbms/src/Interpreters/WindowDescription.h b/dbms/src/Interpreters/WindowDescription.h index 6fd1021c5e5..96270416bb2 100644 --- a/dbms/src/Interpreters/WindowDescription.h +++ b/dbms/src/Interpreters/WindowDescription.h @@ -14,7 +14,6 @@ #pragma once -#include #include #include #include @@ -25,16 +24,16 @@ #include #include + namespace DB { struct WindowFunctionDescription { WindowFunctionPtr window_function; - AggregateFunctionPtr aggregate_function; Array parameters; ColumnNumbers arguments; Names argument_names; - String column_name; + std::string column_name; }; using WindowFunctionDescriptions = std::vector; diff --git a/dbms/src/WindowFunctions/IWindowFunction.h b/dbms/src/WindowFunctions/IWindowFunction.h index d75dfd9f2a8..e912efa809f 100644 --- a/dbms/src/WindowFunctions/IWindowFunction.h +++ b/dbms/src/WindowFunctions/IWindowFunction.h @@ -14,8 +14,6 @@ #pragma once -#include -#include #include #include #include @@ -54,18 +52,8 @@ using WindowFunctionPtr = std::shared_ptr; // Runtime data for computing one window function. struct WindowFunctionWorkspace { + // TODO add aggregation function WindowFunctionPtr window_function = nullptr; - AggregateFunctionPtr aggregate_function; - - // Will not be initialized for a pure window function. - mutable AlignedBuffer aggregate_function_state; - - // Argument columns. Be careful, this is a per-block cache. - std::vector argument_columns; - - UInt64 cached_block_number = std::numeric_limits::max(); - - ColumnNumbers argument_column_indices; ColumnNumbers arguments; }; From b2772eaa28e4aaa59bebb23fc4789f053c38a49b Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Fri, 20 Dec 2024 14:21:26 +0800 Subject: [PATCH 21/32] add AlignedBuffer --- dbms/src/Common/AlignedBuffer.cpp | 60 +++++++++++++++++++++++++++++++ dbms/src/Common/AlignedBuffer.h | 48 +++++++++++++++++++++++++ 2 files changed, 108 insertions(+) create mode 100644 dbms/src/Common/AlignedBuffer.cpp create mode 100644 dbms/src/Common/AlignedBuffer.h diff --git a/dbms/src/Common/AlignedBuffer.cpp b/dbms/src/Common/AlignedBuffer.cpp new file mode 100644 index 00000000000..b2391383463 --- /dev/null +++ b/dbms/src/Common/AlignedBuffer.cpp @@ -0,0 +1,60 @@ +// Copyright 2024 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int CANNOT_ALLOCATE_MEMORY; +} + +void AlignedBuffer::alloc(size_t size, size_t alignment) +{ + void * new_buf; + int res = ::posix_memalign(&new_buf, std::max(alignment, sizeof(void *)), size); + if (0 != res) + throwFromErrno( + fmt::format("Cannot allocate memory (posix_memalign), size: {}, alignment: {}.", size, alignment), + ErrorCodes::CANNOT_ALLOCATE_MEMORY, + res); + buf = new_buf; +} + +void AlignedBuffer::dealloc() +{ + if (buf) + ::free(buf); +} + +void AlignedBuffer::reset(size_t size, size_t alignment) +{ + dealloc(); + alloc(size, alignment); +} + +AlignedBuffer::AlignedBuffer(size_t size, size_t alignment) +{ + alloc(size, alignment); +} + +AlignedBuffer::~AlignedBuffer() +{ + dealloc(); +} + +} // namespace DB diff --git a/dbms/src/Common/AlignedBuffer.h b/dbms/src/Common/AlignedBuffer.h new file mode 100644 index 00000000000..cebf596cece --- /dev/null +++ b/dbms/src/Common/AlignedBuffer.h @@ -0,0 +1,48 @@ +// Copyright 2024 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace DB +{ + +/** Aligned piece of memory. + * It can only be allocated and destroyed. + * MemoryTracker is not used. AlignedBuffer is intended for small pieces of memory. + */ +class AlignedBuffer : private boost::noncopyable +{ +private: + void * buf = nullptr; + + void alloc(size_t size, size_t alignment); + void dealloc(); + +public: + AlignedBuffer() = default; + AlignedBuffer(size_t size, size_t alignment); + AlignedBuffer(AlignedBuffer && old) noexcept { std::swap(buf, old.buf); } + ~AlignedBuffer(); + + void reset(size_t size, size_t alignment); + + char * data() { return static_cast(buf); } + const char * data() const { return static_cast(buf); } +}; + +} // namespace DB From 37f4cc5de6d6df6bafde3b593b521bba8b14aed7 Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Fri, 20 Dec 2024 15:30:51 +0800 Subject: [PATCH 22/32] remove something --- .../AggregateFunctions/AggregateFunctionArgMinMax.h | 6 ------ .../AggregateFunctions/AggregateFunctionBitwise.h | 6 ------ .../AggregateFunctions/AggregateFunctionGroupArray.h | 6 ------ .../AggregateFunctionMaxIntersections.h | 6 ------ .../AggregateFunctions/AggregateFunctionQuantile.h | 6 ------ .../AggregateFunctionSequenceMatch.h | 6 ------ .../AggregateFunctions/AggregateFunctionStatistics.h | 6 ------ .../AggregateFunctionStatisticsSimple.h | 6 ------ .../src/AggregateFunctions/AggregateFunctionSumMap.h | 6 ------ dbms/src/AggregateFunctions/AggregateFunctionTopK.h | 6 ------ dbms/src/AggregateFunctions/AggregateFunctionUniq.h | 12 ------------ .../AggregateFunctions/AggregateFunctionUniqUpTo.h | 6 ------ dbms/src/Flash/Coprocessor/DAGUtils.cpp | 2 +- 13 files changed, 1 insertion(+), 79 deletions(-) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionArgMinMax.h b/dbms/src/AggregateFunctions/AggregateFunctionArgMinMax.h index cba7bb9facd..f252c61b577 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionArgMinMax.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionArgMinMax.h @@ -72,12 +72,6 @@ class AggregateFunctionArgMinMax final : public IAggregateFunctionDataHelperdata(place).result.change(*columns[0], row_num, arena); } - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override - { - // TODO move to helper - throw Exception(""); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { if (this->data(place).value.changeIfBetter(this->data(rhs).value, arena)) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionBitwise.h b/dbms/src/AggregateFunctions/AggregateFunctionBitwise.h index 28508aaf570..0c37c524770 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionBitwise.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionBitwise.h @@ -62,12 +62,6 @@ class AggregateFunctionBitwise final : public IAggregateFunctionDataHelperdata(place).update(static_cast &>(*columns[0]).getData()[row_num]); } - // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override - { - throw Exception(""); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).update(this->data(rhs).value); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h index 7910783e143..c80babc0502 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h @@ -80,12 +80,6 @@ class GroupArrayNumericImpl final this->data(place).value.push_back(static_cast &>(*columns[0]).getData()[row_num], arena); } - // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override - { - throw Exception(""); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { auto & cur_elems = this->data(place); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h b/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h index 36df63b35de..5455d6674e3 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h @@ -109,12 +109,6 @@ class AggregateFunctionIntersectionsMax final this->data(place).value.push_back(std::make_pair(right, static_cast(-1)), arena); } - // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override - { - throw Exception(""); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { auto & cur_elems = this->data(place); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionQuantile.h b/dbms/src/AggregateFunctions/AggregateFunctionQuantile.h index cf0aea645f2..619eeeeab32 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionQuantile.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionQuantile.h @@ -110,12 +110,6 @@ class AggregateFunctionQuantile final this->data(place).add(static_cast &>(*columns[0]).getData()[row_num]); } - // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override - { - throw Exception(""); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).merge(this->data(rhs)); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h b/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h index bbd8e9fe686..d7e87fe6a45 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h @@ -206,12 +206,6 @@ class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelperdata(place).add(timestamp, events); } - // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, const size_t, Arena *) const override - { - throw Exception("Not implemented yet"); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).merge(this->data(rhs)); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h b/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h index 87610e3e7eb..664dd66a142 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h @@ -129,12 +129,6 @@ class AggregateFunctionVariance final this->data(place).update(*columns[0], row_num); } - // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override - { - throw Exception(""); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).mergeWith(this->data(rhs)); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h b/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h index 8db4e6c3133..cf387284d2d 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h @@ -221,12 +221,6 @@ class AggregateFunctionVarianceSimple final this->data(place).add(static_cast &>(*columns[0]).getData()[row_num]); } - // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override - { - throw Exception(""); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).merge(this->data(rhs)); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h index ba4e7d08db2..8e3ef88f2cc 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h @@ -134,12 +134,6 @@ class AggregateFunctionSumMap final } } - // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, const size_t, Arena *) const override - { - throw Exception(""); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { auto & merged_maps = this->data(place).merged_maps; diff --git a/dbms/src/AggregateFunctions/AggregateFunctionTopK.h b/dbms/src/AggregateFunctions/AggregateFunctionTopK.h index cd6ad86e08f..21af95500e4 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionTopK.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionTopK.h @@ -71,12 +71,6 @@ class AggregateFunctionTopK set.insert(static_cast &>(*columns[0]).getData()[row_num]); } - // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override - { - throw Exception(""); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).value.merge(this->data(rhs).value); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionUniq.h b/dbms/src/AggregateFunctions/AggregateFunctionUniq.h index b3c1c5c6be7..a54ef4165fa 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionUniq.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionUniq.h @@ -438,12 +438,6 @@ class AggregateFunctionUniq final : public IAggregateFunctionDataHelper::add(this->data(place), *columns[0], row_num); } - // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override - { - throw Exception(""); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).set.merge(this->data(rhs).set); @@ -509,12 +503,6 @@ class AggregateFunctionUniqVariadic final UniqVariadicHash::apply(this->data(place), num_args, columns, row_num)); } - // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override - { - throw Exception(""); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).set.merge(this->data(rhs).set); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h b/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h index 6fa62a366be..1f66538e6fb 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h @@ -231,12 +231,6 @@ class AggregateFunctionUniqUpToVariadic final threshold); } - // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override - { - throw Exception(""); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).merge(this->data(rhs), threshold); diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.cpp b/dbms/src/Flash/Coprocessor/DAGUtils.cpp index 1a8ac4e0167..77acc88fa3b 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.cpp +++ b/dbms/src/Flash/Coprocessor/DAGUtils.cpp @@ -57,7 +57,7 @@ const std::unordered_map agg_func_map({ {tipb::ExprType::First, "first_row"}, {tipb::ExprType::ApproxCountDistinct, uniq_raw_res_name}, {tipb::ExprType::GroupConcat, "groupArray"}, - //{tipb::ExprType::Avg, ""}, + {tipb::ExprType::Avg, "avg"}, // Only used in aggregation in window function //{tipb::ExprType::Agg_BitAnd, ""}, //{tipb::ExprType::Agg_BitOr, ""}, //{tipb::ExprType::Agg_BitXor, ""}, From dd52c9fbac602a3dceeceaef92e783728531370c Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Fri, 20 Dec 2024 16:09:41 +0800 Subject: [PATCH 23/32] tweaking --- .../AggregateFunctionGroupArrayInsertAt.h | 6 ------ .../AggregateFunctionGroupConcat.h | 1 - .../AggregateFunctionMinMaxAny.cpp | 12 ++++++------ .../AggregateFunctions/AggregateFunctionMinMaxAny.h | 4 +--- .../AggregateFunctionSequenceMatch.h | 3 --- .../AggregateFunctions/AggregateFunctionStatistics.h | 6 ------ dbms/src/AggregateFunctions/AggregateFunctionTopK.h | 5 ----- .../AggregateFunctions/AggregateFunctionUniqUpTo.h | 5 ----- .../AggregateFunctions/tests/gtest_window_agg.cpp | 5 ++++- 9 files changed, 11 insertions(+), 36 deletions(-) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupArrayInsertAt.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupArrayInsertAt.h index 2ae55afea21..1a71df32e41 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupArrayInsertAt.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupArrayInsertAt.h @@ -150,12 +150,6 @@ class AggregateFunctionGroupArrayInsertAtGeneric final columns[0]->get(row_num, arr[position]); } - // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override - { - throw Exception(""); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { Array & arr_lhs = data(place).value; diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h index 0dafef50a57..29a32b58cc5 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h @@ -116,7 +116,6 @@ class AggregateFunctionGroupConcat final DataTypePtr getReturnType() const override { return result_is_nullable ? makeNullable(ret_type) : ret_type; } - /// reject nulls before add()/decrease() of nested agg template void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const { diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.cpp b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.cpp index 84ba1a12853..16bc94aee1b 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.cpp @@ -123,14 +123,14 @@ AggregateFunctionPtr createAggregateFunctionArgMax( void registerAggregateFunctionsMinMaxAny(AggregateFunctionFactory & factory) { - factory.registerFunction("any", createAggregateFunctionAny); // TODO no use - factory.registerFunction("first_row", createAggregateFunctionFirstRow); // TODO not used in window agg - factory.registerFunction("anyLast", createAggregateFunctionAnyLast); // TODO no use - factory.registerFunction("anyHeavy", createAggregateFunctionAnyHeavy); // TODO no use + factory.registerFunction("any", createAggregateFunctionAny); + factory.registerFunction("first_row", createAggregateFunctionFirstRow); + factory.registerFunction("anyLast", createAggregateFunctionAnyLast); + factory.registerFunction("anyHeavy", createAggregateFunctionAnyHeavy); factory.registerFunction("min", createAggregateFunctionMin, AggregateFunctionFactory::CaseInsensitive); factory.registerFunction("max", createAggregateFunctionMax, AggregateFunctionFactory::CaseInsensitive); - factory.registerFunction("argMin", createAggregateFunctionArgMin); // TODO no use - factory.registerFunction("argMax", createAggregateFunctionArgMax); // TODO no use + factory.registerFunction("argMin", createAggregateFunctionArgMin); + factory.registerFunction("argMax", createAggregateFunctionArgMax); } } // namespace DB diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h index c7bd4b0ddb8..5559efda00a 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h @@ -34,8 +34,6 @@ namespace DB * For example: min, max, any, anyLast. */ -// TODO maybe we can create a new class to be inherited by SingleValueDataFixed, SingleValueDataString and SingleValueDataGeneric - /// For numeric values. template struct SingleValueDataFixed @@ -289,7 +287,7 @@ struct SingleValueDataString public: static constexpr Int32 AUTOMATIC_STORAGE_SIZE = 64; static constexpr Int32 MAX_SMALL_STRING_SIZE = AUTOMATIC_STORAGE_SIZE - sizeof(size) - sizeof(capacity) - - sizeof(large_data) - sizeof(TiDB::TiDBCollatorPtr) - sizeof(std::unique_ptr>); + - sizeof(large_data) - sizeof(TiDB::TiDBCollatorPtr) - sizeof(std::deque *); private: char small_data[MAX_SMALL_STRING_SIZE]{}; /// Including the terminating zero. diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h b/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h index d7e87fe6a45..7c3935861b5 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h @@ -27,9 +27,6 @@ #include #include -#include "Common/Exception.h" - - namespace DB { namespace ErrorCodes diff --git a/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h b/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h index 664dd66a142..12d640180ff 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h @@ -377,12 +377,6 @@ class AggregateFunctionCovariance final this->data(place).update(*columns[0], *columns[1], row_num); } - // TODO move to helper - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override - { - throw Exception(""); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).mergeWith(this->data(rhs)); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionTopK.h b/dbms/src/AggregateFunctions/AggregateFunctionTopK.h index 21af95500e4..e1e4b02e6ef 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionTopK.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionTopK.h @@ -197,11 +197,6 @@ class AggregateFunctionTopKGeneric } } - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override - { - throw Exception(""); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).value.merge(this->data(rhs).value); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h b/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h index 1f66538e6fb..371d5eece87 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h @@ -158,11 +158,6 @@ class AggregateFunctionUniqUpTo final this->data(place).add(*columns[0], row_num, threshold); } - void decrease(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override - { - throw Exception(""); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).merge(this->data(rhs), threshold); diff --git a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp index 29c761e3e10..597dabcb07b 100644 --- a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp +++ b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. + +#include #include #include #include @@ -304,7 +306,8 @@ class ExecutorWindowAgg : public DB::tests::AggregationTest auto elem_num = di(dre); for (UInt32 i = 0; i < elem_num; i++) { - di.param(std::uniform_int_distribution::param_type{0, 64}); + di.param( + std::uniform_int_distribution::param_type{0, SingleValueDataString::MAX_SMALL_STRING_SIZE * 2}); auto len = di(dre); di.param(std::uniform_int_distribution::param_type{0, CHARACTERS_LEN - 1}); From 6379b0f0318f0540548d347a79471d2fbda4774c Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Fri, 20 Dec 2024 16:48:20 +0800 Subject: [PATCH 24/32] remove useless change --- .../AggregateFunctionGroupUniqArray.h | 21 ++----------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h index 872176d3f08..06dd57edf66 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h @@ -61,13 +61,6 @@ class AggregateFunctionGroupUniqArray this->data(place).value.insert(assert_cast &>(*columns[0]).getData()[row_num]); } - void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override - { - const auto & key = AggregateFunctionGroupUniqArrayData::Set::Cell::getKey( - assert_cast &>(*columns[0]).getData()[row_num]); - this->data(place).value.erase(key); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).value.merge(this->data(rhs).value); @@ -89,7 +82,7 @@ class AggregateFunctionGroupUniqArray void insertResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena *) const override { - auto & arr_to = assert_cast(to); + ColumnArray & arr_to = assert_cast(to); ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); const typename State::Set & set = this->data(place).value; @@ -178,16 +171,6 @@ class AggregateFunctionGroupUniqArrayGeneric set.emplace(key_holder, it, inserted); } - void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) - const override - { - auto & set = this->data(place).value; - - auto key_holder = getKeyHolder(*columns[0], row_num, *arena); - auto key = keyHolderGetKey(key_holder); - set.erase(key); - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { auto & cur_set = this->data(place).value; @@ -205,7 +188,7 @@ class AggregateFunctionGroupUniqArrayGeneric void insertResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena *) const override { - auto & arr_to = assert_cast(to); + ColumnArray & arr_to = assert_cast(to); ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); IColumn & data_to = arr_to.getData(); From 9e41b3a426850128b595222fb1d55895bf57386f Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Mon, 23 Dec 2024 11:13:40 +0800 Subject: [PATCH 25/32] fix ci --- dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp | 2 ++ dbms/src/Common/AlignedBuffer.cpp | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp index 597dabcb07b..7296aaf7448 100644 --- a/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp +++ b/dbms/src/AggregateFunctions/tests/gtest_window_agg.cpp @@ -30,6 +30,8 @@ #include #include +#include + namespace DB { namespace tests diff --git a/dbms/src/Common/AlignedBuffer.cpp b/dbms/src/Common/AlignedBuffer.cpp index b2391383463..f6798603419 100644 --- a/dbms/src/Common/AlignedBuffer.cpp +++ b/dbms/src/Common/AlignedBuffer.cpp @@ -38,7 +38,7 @@ void AlignedBuffer::alloc(size_t size, size_t alignment) void AlignedBuffer::dealloc() { if (buf) - ::free(buf); + ::free(buf); //NOLINT } void AlignedBuffer::reset(size_t size, size_t alignment) From e6adcbf3f8acffdb5b841fd90546c9af3990748f Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Mon, 23 Dec 2024 16:39:51 +0800 Subject: [PATCH 26/32] fix ci --- .../AggregateFunctions/AggregateFunctionMinMaxAny.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h index 5559efda00a..dd1af2373b7 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h @@ -51,6 +51,10 @@ struct SingleValueDataFixed using ColumnType = std::conditional_t, ColumnDecimal, ColumnVector>; public: + SingleValueDataFixed() + : saved_values(nullptr) + {} + ~SingleValueDataFixed() { delete saved_values; } bool has() const { return has_value; } @@ -293,6 +297,9 @@ struct SingleValueDataString char small_data[MAX_SMALL_STRING_SIZE]{}; /// Including the terminating zero. public: + SingleValueDataString() + : saved_values(nullptr) + {} ~SingleValueDataString() { delete saved_values; } bool has() const { return size >= 0; } @@ -574,6 +581,9 @@ struct SingleValueDataGeneric mutable std::deque * saved_values; public: + SingleValueDataGeneric() + : saved_values(nullptr) + {} ~SingleValueDataGeneric() { delete saved_values; } bool has() const { return !value.isNull(); } From b84c9112b8d7a2699b4ccad523b2938756deb8be Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Tue, 24 Dec 2024 10:14:16 +0800 Subject: [PATCH 27/32] fix ut --- dbms/src/Core/tests/gtest_block.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dbms/src/Core/tests/gtest_block.cpp b/dbms/src/Core/tests/gtest_block.cpp index 0f5e06b82d1..84dfe899b20 100644 --- a/dbms/src/Core/tests/gtest_block.cpp +++ b/dbms/src/Core/tests/gtest_block.cpp @@ -86,10 +86,10 @@ try "Nullable(Int64)", DataTypeString::getNullableDefaultName()}; std::vector data_size{ - 16, - ColumnString::APPROX_STRING_SIZE * 2, - 24, - ColumnString::APPROX_STRING_SIZE * 2 + 8}; + 32, + ColumnString::APPROX_STRING_SIZE * 2 + 8, + 40, + ColumnString::APPROX_STRING_SIZE * 2 + 16}; for (size_t i = 0; i < types.size(); ++i) { auto agg_data_type = DataTypeFactory::instance().get(fmt::format("AggregateFunction(Min, {})", types[i])); From aa122215cde73d79a3f6b07a7e4ddb62fef34626 Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Wed, 25 Dec 2024 16:41:04 +0800 Subject: [PATCH 28/32] address some comments --- dbms/src/AggregateFunctions/AggregateFunctionAvg.h | 4 ++-- dbms/src/AggregateFunctions/AggregateFunctionMerge.h | 10 ---------- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionAvg.h b/dbms/src/AggregateFunctions/AggregateFunctionAvg.h index 1e13affca86..b528f9f10f1 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionAvg.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionAvg.h @@ -91,10 +91,10 @@ class AggregateFunctionAvg final if constexpr (IsDecimal) this->data(place).sum -= static_cast &>(*columns[0]).getData()[row_num]; else - { this->data(place).sum -= static_cast &>(*columns[0]).getData()[row_num]; - } + --this->data(place).count; + assert(this->data(place).count >= 0); } void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMerge.h b/dbms/src/AggregateFunctions/AggregateFunctionMerge.h index 306ce6005dc..224ed534aff 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMerge.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMerge.h @@ -64,16 +64,6 @@ class AggregateFunctionMerge final : public IAggregateFunctionHelpermerge(place, static_cast(*columns[0]).getData()[row_num], arena); } - void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) - const override - { - nested_func->merge(place, static_cast(*columns[0]).getData()[row_num], arena); - } - - void reset(AggregateDataPtr __restrict place) const override { nested_func->reset(place); } - - void prepareWindow(AggregateDataPtr __restrict place) const override { nested_func->prepareWindow(place); } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override { nested_func->merge(place, rhs, arena); From 2b9e991cdef86376ff51ad228adc95cde23412fd Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Thu, 26 Dec 2024 18:20:26 +0800 Subject: [PATCH 29/32] create new class for window --- .../AggregateFunctionFactory.cpp | 20 + .../AggregateFunctionFactory.h | 8 + .../AggregateFunctionMinMaxAny.h | 288 ++----------- .../AggregateFunctionMinMaxWindow.cpp | 56 +++ .../AggregateFunctionMinMaxWindow.h | 388 ++++++++++++++++++ .../AggregateFunctionNull.cpp | 67 ++- .../AggregateFunctionNull.h | 186 +++++---- .../src/AggregateFunctions/HelpersMinMaxAny.h | 35 ++ .../registerAggregateFunctions.cpp | 2 + .../tests/gtest_window_agg.cpp | 26 +- 10 files changed, 729 insertions(+), 347 deletions(-) create mode 100644 dbms/src/AggregateFunctions/AggregateFunctionMinMaxWindow.cpp create mode 100644 dbms/src/AggregateFunctions/AggregateFunctionMinMaxWindow.h diff --git a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp index a62db765523..2c8624d8e90 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp @@ -107,6 +107,26 @@ AggregateFunctionPtr AggregateFunctionFactory::get( return res; } +AggregateFunctionPtr AggregateFunctionFactory::getForWindow( + const Context & context, + const String & name, + const DataTypes & argument_types, + const Array & parameters, + int recursion_level) const +{ + AggregateFunctionCombinatorPtr combinator + = AggregateFunctionCombinatorFactory::instance().tryFindSuffix("NullForWindow"); + if (!combinator) + throw Exception( + "Logical error: cannot find aggregate function combinator to apply a function to Nullable for window " + "arguments.", + ErrorCodes::LOGICAL_ERROR); + + DataTypes nested_types = combinator->transformArguments(argument_types); + AggregateFunctionPtr nested_function = getImpl(context, name, nested_types, parameters, recursion_level); + return combinator->transformAggregateFunction(nested_function, argument_types, parameters); +} + AggregateFunctionPtr AggregateFunctionFactory::getImpl( const Context & context, const String & name, diff --git a/dbms/src/AggregateFunctions/AggregateFunctionFactory.h b/dbms/src/AggregateFunctions/AggregateFunctionFactory.h index 18d48296dee..756decb0692 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionFactory.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionFactory.h @@ -67,6 +67,14 @@ class AggregateFunctionFactory final : public ext::Singleton #include -#include -#include - - namespace DB { /** Aggregate functions that store one of passed values. * For example: min, max, any, anyLast. */ +struct CommonImpl +{ + static void decrease() { throw Exception("Not implemented yet"); } + static void reset() { throw Exception("Not implemented yet"); } + static void prepareWindow() { throw Exception("Not implemented yet"); } +}; + /// For numeric values. template -struct SingleValueDataFixed +struct SingleValueDataFixed : public CommonImpl { -private: +protected: using Self = SingleValueDataFixed; bool has_value = false; /// We need to remember if at least one value has been passed. This is necessary for AggregateFunctionIf. T value; - // It's only used in window aggregation - mutable std::deque * saved_values; - using ColumnType = std::conditional_t, ColumnDecimal, ColumnVector>; public: - SingleValueDataFixed() - : saved_values(nullptr) - {} - - ~SingleValueDataFixed() { delete saved_values; } - bool has() const { return has_value; } void setCollators(const TiDB::TiDBCollators &) {} @@ -83,54 +77,6 @@ struct SingleValueDataFixed readBinary(value, buf); } - void insertMaxResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } - - void insertMinResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } - - template - void insertMinOrMaxResultInto(IColumn & to) const - { - if (has()) - { - auto size = saved_values->size(); - T tmp = (*saved_values)[0]; - for (size_t i = 1; i < size; i++) - { - if constexpr (is_min) - { - if ((*saved_values)[i] < tmp) - tmp = (*saved_values)[i]; - } - else - { - if (tmp < (*saved_values)[i]) - tmp = (*saved_values)[i]; - } - } - static_cast(to).getData().push_back(tmp); - } - else - { - static_cast(to).insertDefault(); - } - } - - void prepareWindow() { saved_values = new std::deque(); } - - void reset() - { - has_value = false; - saved_values->clear(); - } - - // Only used for window aggregation - void decrease() - { - saved_values->pop_front(); - if unlikely (saved_values->empty()) - has_value = false; - } - void change(const IColumn & column, size_t row_num, Arena *) { has_value = true; @@ -185,11 +131,7 @@ struct SingleValueDataFixed bool changeIfLess(const IColumn & column, size_t row_num, Arena * arena) { - auto to_value = static_cast(column).getData()[row_num]; - if (saved_values != nullptr) - saved_values->push_back(to_value); - - if (!has() || to_value < value) + if (!has() || static_cast(column).getData()[row_num] < value) { change(column, row_num, arena); return true; @@ -200,9 +142,6 @@ struct SingleValueDataFixed bool changeIfLess(const Self & to, Arena * arena) { - if (saved_values != nullptr) - saved_values->push_back(to.value); - if (to.has() && (!has() || to.value < value)) { change(to, arena); @@ -214,11 +153,7 @@ struct SingleValueDataFixed bool changeIfGreater(const IColumn & column, size_t row_num, Arena * arena) { - auto to_value = static_cast(column).getData()[row_num]; - if (saved_values != nullptr) - saved_values->push_back(to_value); - - if (!has() || to_value > value) + if (!has() || static_cast(column).getData()[row_num] > value) { change(column, row_num, arena); return true; @@ -229,9 +164,6 @@ struct SingleValueDataFixed bool changeIfGreater(const Self & to, Arena * arena) { - if (saved_values != nullptr) - saved_values->push_back(to.value); - if (to.has() && (!has() || to.value > value)) { change(to, arena); @@ -253,9 +185,9 @@ struct SingleValueDataFixed /** For strings. Short strings are stored in the object itself, and long strings are allocated separately. * NOTE It could also be suitable for arrays of numbers. */ -struct SingleValueDataString +struct SingleValueDataString : public CommonImpl { -private: +protected: using Self = SingleValueDataString; Int32 size = -1; /// -1 indicates that there is no value. @@ -263,10 +195,6 @@ struct SingleValueDataString char * large_data{}; TiDB::TiDBCollatorPtr collator{}; - // TODO use std::string is inefficient - // It's only used in window aggregation - mutable std::deque * saved_values{}; - bool less(const StringRef & a, const StringRef & b) const { if (unlikely(collator == nullptr)) @@ -290,18 +218,13 @@ struct SingleValueDataString public: static constexpr Int32 AUTOMATIC_STORAGE_SIZE = 64; - static constexpr Int32 MAX_SMALL_STRING_SIZE = AUTOMATIC_STORAGE_SIZE - sizeof(size) - sizeof(capacity) - - sizeof(large_data) - sizeof(TiDB::TiDBCollatorPtr) - sizeof(std::deque *); + static constexpr Int32 MAX_SMALL_STRING_SIZE + = AUTOMATIC_STORAGE_SIZE - sizeof(size) - sizeof(capacity) - sizeof(large_data) - sizeof(TiDB::TiDBCollatorPtr); -private: +protected: char small_data[MAX_SMALL_STRING_SIZE]{}; /// Including the terminating zero. public: - SingleValueDataString() - : saved_values(nullptr) - {} - ~SingleValueDataString() { delete saved_values; } - bool has() const { return size >= 0; } const char * getData() const { return size <= MAX_SMALL_STRING_SIZE ? small_data : large_data; } @@ -371,58 +294,6 @@ struct SingleValueDataString } } - void insertMaxResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } - - void insertMinResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } - - template - void insertMinOrMaxResultInto(IColumn & to) const - { - if (has()) - { - auto elem_num = saved_values->size(); - StringRef value((*saved_values)[0].c_str(), (*saved_values)[0].size()); - for (size_t i = 1; i < elem_num; i++) - { - String cmp_value((*saved_values)[i].c_str(), (*saved_values)[i].size()); - if constexpr (is_min) - { - if (less(cmp_value, value)) - value = (*saved_values)[i]; - } - else - { - if (less(value, cmp_value)) - value = (*saved_values)[i]; - } - } - - static_cast(to).insertDataWithTerminatingZero(value.data, value.size); - } - else - { - static_cast(to).insertDefault(); - } - } - - void prepareWindow() { saved_values = new std::deque(); } - - void reset() - { - size = -1; - saved_values->clear(); - } - - // Only used for window aggregation - void decrease() - { - saved_values->pop_front(); - if unlikely (saved_values->empty()) - size = -1; - } - - void saveValue(StringRef value) { saved_values->push_back(value.toString()); } - /// Assuming to.has() void changeImpl(StringRef value, Arena * arena) { @@ -498,9 +369,6 @@ struct SingleValueDataString bool changeIfLess(const IColumn & column, size_t row_num, Arena * arena) { - if (saved_values != nullptr) - saveValue(static_cast(column).getDataAtWithTerminatingZero(row_num)); - if (!has() || less(static_cast(column).getDataAtWithTerminatingZero(row_num), getStringRef())) { @@ -513,9 +381,6 @@ struct SingleValueDataString bool changeIfLess(const Self & to, Arena * arena) { - if (saved_values != nullptr) - saveValue(to.getStringRef()); - // todo should check the collator in `to` and `this` if (to.has() && (!has() || less(to.getStringRef(), getStringRef()))) { @@ -528,9 +393,6 @@ struct SingleValueDataString bool changeIfGreater(const IColumn & column, size_t row_num, Arena * arena) { - if (saved_values != nullptr) - saveValue(static_cast(column).getDataAtWithTerminatingZero(row_num)); - if (!has() || greater(static_cast(column).getDataAtWithTerminatingZero(row_num), getStringRef())) { @@ -543,9 +405,6 @@ struct SingleValueDataString bool changeIfGreater(const Self & to, Arena * arena) { - if (saved_values != nullptr) - saveValue(to.getStringRef()); - if (to.has() && (!has() || greater(to.getStringRef(), getStringRef()))) { change(to, arena); @@ -570,22 +429,14 @@ static_assert( /// For any other value types. -struct SingleValueDataGeneric +struct SingleValueDataGeneric : public CommonImpl { -private: +protected: using Self = SingleValueDataGeneric; Field value; - // It's only used in window aggregation - mutable std::deque * saved_values; - public: - SingleValueDataGeneric() - : saved_values(nullptr) - {} - ~SingleValueDataGeneric() { delete saved_values; } - bool has() const { return !value.isNull(); } void setCollators(const TiDB::TiDBCollators &) {} @@ -618,54 +469,6 @@ struct SingleValueDataGeneric data_type.deserializeBinary(value, buf); } - void insertMaxResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } - - void insertMinResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } - - template - void insertMinOrMaxResultInto(IColumn & to) const - { - if (has()) - { - auto size = saved_values->size(); - Field tmp = (*saved_values)[0]; - for (size_t i = 1; i < size; i++) - { - if constexpr (is_min) - { - if ((*saved_values)[i] < tmp) - tmp = (*saved_values)[i]; - } - else - { - if (tmp < (*saved_values)[i]) - tmp = (*saved_values)[i]; - } - } - to.insert(tmp); - } - else - { - to.insertDefault(); - } - } - - void prepareWindow() { saved_values = new std::deque(); } - - void reset() - { - value = Field(); - saved_values->clear(); - } - - // Only used for window aggregation - void decrease() - { - saved_values->pop_front(); - if unlikely (saved_values->empty()) - value = Field(); - } - void change(const IColumn & column, size_t row_num, Arena *) { column.get(row_num, value); } void change(const Self & to, Arena *) { value = to.value; } @@ -714,9 +517,6 @@ struct SingleValueDataGeneric if (!has()) { change(column, row_num, arena); - - if (saved_values != nullptr) - saved_values->push_back(value); return true; } else @@ -724,9 +524,6 @@ struct SingleValueDataGeneric Field new_value; column.get(row_num, new_value); - if (saved_values != nullptr) - saved_values->push_back(new_value); - if (new_value < value) { value = new_value; @@ -739,9 +536,6 @@ struct SingleValueDataGeneric bool changeIfLess(const Self & to, Arena * arena) { - if (saved_values != nullptr) - saved_values->push_back(to.value); - if (to.has() && (!has() || to.value < value)) { change(to, arena); @@ -756,9 +550,6 @@ struct SingleValueDataGeneric if (!has()) { change(column, row_num, arena); - - if (saved_values != nullptr) - saved_values->push_back(value); return true; } else @@ -766,9 +557,6 @@ struct SingleValueDataGeneric Field new_value; column.get(row_num, new_value); - if (saved_values != nullptr) - saved_values->push_back(new_value); - if (new_value > value) { value = new_value; @@ -781,9 +569,6 @@ struct SingleValueDataGeneric bool changeIfGreater(const Self & to, Arena * arena) { - if (saved_values != nullptr) - saved_values->push_back(to.value); - if (to.has() && (!has() || to.value > value)) { change(to, arena); @@ -815,23 +600,11 @@ struct AggregateFunctionMinData : Data } bool changeIfBetter(const Self & to, Arena * arena) { return this->changeIfLess(to, arena); } - void prepareWindow() - { - is_in_window = true; - Data::prepareWindow(); - } + void prepareWindow() { throw Exception("Not implemented yet"); } - void insertResultInto(IColumn & to) const - { - if (is_in_window) - Data::insertMinResultInto(to); - else - Data::insertResultInto(to); - } + void insertResultInto(IColumn & to) const { Data::insertResultInto(to); } static const char * name() { return "min"; } - - bool is_in_window = false; }; template @@ -839,29 +612,16 @@ struct AggregateFunctionMaxData : Data { using Self = AggregateFunctionMaxData; - void prepareWindow() - { - is_in_window = true; - Data::prepareWindow(); - } - - void insertResultInto(IColumn & to) const - { - if (is_in_window) - Data::insertMaxResultInto(to); - else - Data::insertResultInto(to); - } + void insertResultInto(IColumn & to) const { Data::insertResultInto(to); } bool changeIfBetter(const IColumn & column, size_t row_num, Arena * arena) { return this->changeIfGreater(column, row_num, arena); } + bool changeIfBetter(const Self & to, Arena * arena) { return this->changeIfGreater(to, arena); } static const char * name() { return "max"; } - - bool is_in_window = false; }; template @@ -985,7 +745,9 @@ class AggregateFunctionsSingleValue final explicit AggregateFunctionsSingleValue(const DataTypePtr & type) : type(type) { - if (StringRef(Data::name()) == StringRef("min") || StringRef(Data::name()) == StringRef("max")) + if (StringRef(Data::name()) == StringRef("min") || StringRef(Data::name()) == StringRef("max") + || StringRef(Data::name()) == StringRef("max_for_window") + || StringRef(Data::name()) == StringRef("min_for_window")) { if (!type->isComparable()) throw Exception( diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxWindow.cpp b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxWindow.cpp new file mode 100644 index 00000000000..3eaf9434036 --- /dev/null +++ b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxWindow.cpp @@ -0,0 +1,56 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +namespace DB +{ +namespace +{ +AggregateFunctionPtr createAggregateFunctionMinForWindow( + const Context & /* context not used */, + const std::string & name, + const DataTypes & argument_types, + const Array & parameters) +{ + return AggregateFunctionPtr( + createAggregateFunctionSingleValueForWindow( + name, + argument_types, + parameters)); +} + +AggregateFunctionPtr createAggregateFunctionMaxForWindow( + const Context & /* context not used */, + const std::string & name, + const DataTypes & argument_types, + const Array & parameters) +{ + return AggregateFunctionPtr( + createAggregateFunctionSingleValueForWindow( + name, + argument_types, + parameters)); +} +} // namespace + +void registerAggregateFunctionsMinMaxForWindow(AggregateFunctionFactory & factory) +{ + factory.registerFunction("min_for_window", createAggregateFunctionMinForWindow, AggregateFunctionFactory::CaseInsensitive); + factory.registerFunction("max_for_window", createAggregateFunctionMaxForWindow, AggregateFunctionFactory::CaseInsensitive); +} +} // namespace DB diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxWindow.h b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxWindow.h new file mode 100644 index 00000000000..9c086f7bd27 --- /dev/null +++ b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxWindow.h @@ -0,0 +1,388 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +template +struct SingleValueDataFixedForWindow : public SingleValueDataFixed +{ +private: + using Self = SingleValueDataFixedForWindow; + using ColumnType = std::conditional_t, ColumnDecimal, ColumnVector>; + + mutable std::deque * saved_values; +public: + SingleValueDataFixedForWindow() + : saved_values(nullptr) + {} + + ~SingleValueDataFixedForWindow() { delete saved_values; } + + void insertMaxResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } + + void insertMinResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } + + template + void insertMinOrMaxResultInto(IColumn & to) const + { + if (this->has()) + { + auto size = saved_values->size(); + T tmp = (*saved_values)[0]; + for (size_t i = 1; i < size; i++) + { + if constexpr (is_min) + { + if ((*saved_values)[i] < tmp) + tmp = (*saved_values)[i]; + } + else + { + if (tmp < (*saved_values)[i]) + tmp = (*saved_values)[i]; + } + } + static_cast(to).getData().push_back(tmp); + } + else + { + static_cast(to).insertDefault(); + } + } + + void prepareWindow() { saved_values = new std::deque(); } + + void reset() + { + this->has_value = false; + saved_values->clear(); + } + + void decrease() + { + saved_values->pop_front(); + if unlikely (saved_values->empty()) + this->has_value = false; + } + + bool changeIfLess(const IColumn & column, size_t row_num, Arena * arena) + { + auto to_value = static_cast(column).getData()[row_num]; + if (saved_values != nullptr) + saved_values->push_back(to_value); + + return SingleValueDataFixed::changeIfLess(column, row_num, arena); + } + + bool changeIfLess(const Self & to, Arena * arena) + { + if (saved_values != nullptr) + saved_values->push_back(to.value); + + return SingleValueDataFixed::changeIfLess(to, arena); + } + + bool changeIfGreater(const IColumn & column, size_t row_num, Arena * arena) + { + auto to_value = static_cast(column).getData()[row_num]; + if (saved_values != nullptr) + saved_values->push_back(to_value); + + return SingleValueDataFixed::changeIfGreater(column, row_num, arena); + } + + bool changeIfGreater(const Self & to, Arena * arena) + { + if (saved_values != nullptr) + saved_values->push_back(to.value); + + return SingleValueDataFixed::changeIfGreater(to, arena); + } +}; + +struct SingleValueDataStringForWindow : public SingleValueDataString +{ +private: + using Self = SingleValueDataStringForWindow; + + // TODO use std::string is inefficient + mutable std::deque * saved_values{}; + +public: + SingleValueDataStringForWindow() + : saved_values(nullptr) + {} + ~SingleValueDataStringForWindow() { delete saved_values; } + + void insertMaxResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } + + void insertMinResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } + + template + void insertMinOrMaxResultInto(IColumn & to) const + { + if (has()) + { + auto elem_num = saved_values->size(); + StringRef value((*saved_values)[0].c_str(), (*saved_values)[0].size()); + for (size_t i = 1; i < elem_num; i++) + { + String cmp_value((*saved_values)[i].c_str(), (*saved_values)[i].size()); + if constexpr (is_min) + { + if (less(cmp_value, value)) + value = (*saved_values)[i]; + } + else + { + if (less(value, cmp_value)) + value = (*saved_values)[i]; + } + } + + static_cast(to).insertDataWithTerminatingZero(value.data, value.size); + } + else + { + static_cast(to).insertDefault(); + } + } + + void prepareWindow() { saved_values = new std::deque(); } + + void reset() + { + size = -1; + saved_values->clear(); + } + + void decrease() + { + saved_values->pop_front(); + if unlikely (saved_values->empty()) + size = -1; + } + + void saveValue(StringRef value) { saved_values->push_back(value.toString()); } + + bool changeIfLess(const IColumn & column, size_t row_num, Arena * arena) + { + if (saved_values != nullptr) + saveValue(static_cast(column).getDataAtWithTerminatingZero(row_num)); + + return SingleValueDataString::changeIfLess(column, row_num, arena); + } + + bool changeIfLess(const Self & to, Arena * arena) + { + if (saved_values != nullptr) + saveValue(to.getStringRef()); + + return SingleValueDataString::changeIfLess(to, arena); + } + + bool changeIfGreater(const IColumn & column, size_t row_num, Arena * arena) + { + if (saved_values != nullptr) + saveValue(static_cast(column).getDataAtWithTerminatingZero(row_num)); + + + return SingleValueDataString::changeIfGreater(column, row_num, arena); + } + + bool changeIfGreater(const Self & to, Arena * arena) + { + if (saved_values != nullptr) + saveValue(to.getStringRef()); + + return SingleValueDataString::changeIfGreater(to, arena); + } +}; + +struct SingleValueDataGenericForWindow : public SingleValueDataGeneric +{ +private: + using Self = SingleValueDataGenericForWindow; + mutable std::deque * saved_values; +public: + SingleValueDataGenericForWindow() + : saved_values(nullptr) + {} + ~SingleValueDataGenericForWindow() { delete saved_values; } + + void insertMaxResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } + + void insertMinResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } + + template + void insertMinOrMaxResultInto(IColumn & to) const + { + if (has()) + { + auto size = saved_values->size(); + Field tmp = (*saved_values)[0]; + for (size_t i = 1; i < size; i++) + { + if constexpr (is_min) + { + if ((*saved_values)[i] < tmp) + tmp = (*saved_values)[i]; + } + else + { + if (tmp < (*saved_values)[i]) + tmp = (*saved_values)[i]; + } + } + to.insert(tmp); + } + else + { + to.insertDefault(); + } + } + + void prepareWindow() { saved_values = new std::deque(); } + + void reset() + { + value = Field(); + saved_values->clear(); + } + + // Only used for window aggregation + void decrease() + { + saved_values->pop_front(); + if unlikely (saved_values->empty()) + value = Field(); + } + + bool changeIfLess(const IColumn & column, size_t row_num, Arena * arena) + { + if (!has()) + { + change(column, row_num, arena); + + if (saved_values != nullptr) + saved_values->push_back(value); + return true; + } + else + { + Field new_value; + column.get(row_num, new_value); + + if (saved_values != nullptr) + saved_values->push_back(new_value); + + if (new_value < value) + { + value = new_value; + return true; + } + else + return false; + } + } + + bool changeIfLess(const Self & to, Arena * arena) + { + if (saved_values != nullptr) + saved_values->push_back(to.value); + + return SingleValueDataGeneric::changeIfLess(to, arena); + } + + bool changeIfGreater(const IColumn & column, size_t row_num, Arena * arena) + { + if (!has()) + { + change(column, row_num, arena); + + if (saved_values != nullptr) + saved_values->push_back(value); + return true; + } + else + { + Field new_value; + column.get(row_num, new_value); + + if (saved_values != nullptr) + saved_values->push_back(new_value); + + if (new_value > value) + { + value = new_value; + return true; + } + else + return false; + } + } + + bool changeIfGreater(const Self & to, Arena * arena) + { + if (saved_values != nullptr) + saved_values->push_back(to.value); + + return SingleValueDataGeneric::changeIfGreater(to, arena); + } +}; + +template +struct AggregateFunctionMinDataForWindow : Data +{ + using Self = AggregateFunctionMinDataForWindow; + + bool changeIfBetter(const IColumn & column, size_t row_num, Arena * arena) + { + return this->changeIfLess(column, row_num, arena); + } + bool changeIfBetter(const Self & to, Arena * arena) { return this->changeIfLess(to, arena); } + + void insertResultInto(IColumn & to) const { Data::insertMinResultInto(to); } + + static const char * name() { return "min_for_window"; } +}; + +template +struct AggregateFunctionMaxDataForWindow : Data +{ + using Self = AggregateFunctionMaxDataForWindow; + + void insertResultInto(IColumn & to) const { Data::insertMaxResultInto(to); } + + bool changeIfBetter(const IColumn & column, size_t row_num, Arena * arena) { return this->changeIfGreater(column, row_num, arena); } + + bool changeIfBetter(const Self & to, Arena * arena) { return this->changeIfGreater(to, arena); } + + static const char * name() { return "max_for_window"; } +}; + +} // namespace DB diff --git a/dbms/src/AggregateFunctions/AggregateFunctionNull.cpp b/dbms/src/AggregateFunctions/AggregateFunctionNull.cpp index f639fd563a2..42da6365692 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionNull.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionNull.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include @@ -33,7 +34,7 @@ extern const std::unordered_set hacking_return_non_null_agg_func_names; class AggregateFunctionCombinatorNull final : public IAggregateFunctionCombinator { public: - String getName() const override { return "Null"; }; + String getName() const override { return "Null"; } DataTypes transformArguments(const DataTypes & arguments) const override { @@ -141,9 +142,73 @@ class AggregateFunctionCombinatorNull final : public IAggregateFunctionCombinato } }; +class AggregateFunctionCombinatorNullForWindow final : public IAggregateFunctionCombinator +{ +public: + String getName() const override { return "NullForWindow"; } + + DataTypes transformArguments(const DataTypes & arguments) const override + { + size_t size = arguments.size(); + if (size != 1) + throw Exception(fmt::format("Aggregation in window accepts exact 1 argument, but gets {} arguments", size)); + DataTypes res(size); + res[0] = removeNullable(arguments[0]); + return res; + } + + AggregateFunctionPtr transformAggregateFunction( + const AggregateFunctionPtr & nested_function, + const DataTypes & arguments, + const Array &) const override + { + bool has_nullable_types = false; + bool has_null_types = false; + for (const auto & arg_type : arguments) + { + if (arg_type->isNullable()) + { + has_nullable_types = true; + if (arg_type->onlyNull()) + { + has_null_types = true; + break; + } + } + } + + if (nested_function && nested_function->getName() == "count") + { + if (has_nullable_types) + return std::make_shared(arguments[0]); + else + return std::make_shared(); + } + + if (has_null_types) + return std::make_shared(); + + if (nested_function->getReturnType()->canBeInsideNullable()) + { + if (has_nullable_types) + return std::make_shared>(nested_function); + else + return std::make_shared>(nested_function); + } + else + { + if (has_nullable_types) + return std::make_shared>(nested_function); + else + return std::make_shared>(nested_function); + } + } +}; + void registerAggregateFunctionCombinatorNull(AggregateFunctionCombinatorFactory & factory) { factory.registerCombinator(std::make_shared()); + factory.registerCombinator(std::make_shared()); } } // namespace DB diff --git a/dbms/src/AggregateFunctions/AggregateFunctionNull.h b/dbms/src/AggregateFunctions/AggregateFunctionNull.h index 6432028739c..9ea4882434a 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionNull.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionNull.h @@ -90,13 +90,23 @@ class AggregateFunctionNullBase : public IAggregateFunctionHelper } public: - explicit AggregateFunctionNullBase(AggregateFunctionPtr nested_function_) + explicit AggregateFunctionNullBase(AggregateFunctionPtr nested_function_, bool need_counter = false) : nested_function{nested_function_} { if (result_is_nullable) prefix_size = nested_function->alignOfData(); else prefix_size = 0; + + if (result_is_nullable && need_counter) + { + auto tmp = prefix_size; + prefix_size += sizeof(Int64); + if ((prefix_size % tmp) != 0) + prefix_size += tmp - (prefix_size % tmp); + assert((prefix_size % tmp) == 0); + assert(prefix_size >= (tmp + sizeof(Int64))); + } } String getName() const override @@ -266,18 +276,6 @@ class AggregateFunctionFirstRowNull size_t alignOfData() const override { return nested_function->alignOfData(); } void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override - { - addOrDecrease(place, columns, row_num, arena); - } - - void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) - const override - { - addOrDecrease(place, columns, row_num, arena); - } - - template - void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const { if constexpr (input_is_nullable) { @@ -289,20 +287,14 @@ class AggregateFunctionFirstRowNull if (!is_null) { const IColumn * nested_column = &column->getNestedColumn(); - if constexpr (is_add) - this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); - else - this->nested_function->decrease(this->nestedPlace(place), &nested_column, row_num, arena); + this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); } } } else { this->setFlag(place, 1); - if constexpr (is_add) - this->nested_function->add(this->nestedPlace(place), columns, row_num, arena); - else - this->nested_function->decrease(this->nestedPlace(place), columns, row_num, arena); + this->nested_function->add(this->nestedPlace(place), columns, row_num, arena); } } @@ -397,18 +389,6 @@ class AggregateFunctionNullUnary final {} void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override - { - addOrDecrease(place, columns, row_num, arena); - } - - void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) - const override - { - addOrDecrease(place, columns, row_num, arena); - } - - template - void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const { if constexpr (input_is_nullable) { @@ -417,32 +397,16 @@ class AggregateFunctionNullUnary final { this->setFlag(place); const IColumn * nested_column = &column->getNestedColumn(); - if constexpr (is_add) - this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); - else - this->nested_function->decrease(this->nestedPlace(place), &nested_column, row_num, arena); + this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); } } else { this->setFlag(place); - if constexpr (is_add) - this->nested_function->add(this->nestedPlace(place), columns, row_num, arena); - else - this->nested_function->decrease(this->nestedPlace(place), columns, row_num, arena); + this->nested_function->add(this->nestedPlace(place), columns, row_num, arena); } } - void reset(AggregateDataPtr __restrict place) const override - { - this->nested_function->reset(this->nestedPlace(place)); - } - - void prepareWindow(AggregateDataPtr __restrict place) const override - { - this->nested_function->prepareWindow(this->nestedPlace(place)); - } - void addBatchSinglePlace( // NOLINT(google-default-arguments) size_t start_offset, size_t batch_size, @@ -514,18 +478,6 @@ class AggregateFunctionNullVariadic final } void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override - { - addOrDecrease(place, columns, row_num, arena); - } - - void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) - const override - { - addOrDecrease(place, columns, row_num, arena); - } - - template - void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const { /// This container stores the columns we really pass to the nested function. const IColumn * nested_columns[number_of_arguments]; @@ -548,31 +500,121 @@ class AggregateFunctionNullVariadic final } this->setFlag(place); + this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena); + } + + bool allocatesMemoryInArena() const override { return this->nested_function->allocatesMemoryInArena(); } + +private: + enum + { + MAX_ARGS = 8 + }; + size_t number_of_arguments = 0; + std::array is_nullable; /// Plain array is better than std::vector due to one indirection less. +}; + +template +class AggregateFunctionNullUnaryForWindow final + : public AggregateFunctionNullBase< + result_is_nullable, + AggregateFunctionNullUnaryForWindow> +{ +public: + explicit AggregateFunctionNullUnaryForWindow(AggregateFunctionPtr nested_function) + : AggregateFunctionNullBase< + result_is_nullable, + AggregateFunctionNullUnaryForWindow>(nested_function, true) + {} + + void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + addOrDecrease(place, columns, row_num, arena); + } + + void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const override + { + addOrDecrease(place, columns, row_num, arena); + } + + template + void addOrDecreaseImpl(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) + const + { if constexpr (is_add) - this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena); + { + this->addCounter(place); + this->setFlag(place); + this->nested_function->add(this->nestedPlace(place), columns, row_num, arena); + } else - this->nested_function->decrease(this->nestedPlace(place), nested_columns, row_num, arena); + { + this->decreaseCounter(place); + if (this->getCounter(place) == 0) + this->resetFlag(place); + this->nested_function->decrease(this->nestedPlace(place), columns, row_num, arena); + } + } + + template + void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const + { + if constexpr (input_is_nullable) + { + const auto * column = static_cast(columns[0]); + if (!column->isNullAt(row_num)) + { + const IColumn * nested_column = &column->getNestedColumn(); + addOrDecreaseImpl(place, &nested_column, row_num, arena); + } + } + else + { + addOrDecreaseImpl(place, columns, row_num, arena); + } } void reset(AggregateDataPtr __restrict place) const override { + this->resetFlag(place); + this->resetCounter(place); this->nested_function->reset(this->nestedPlace(place)); } void prepareWindow(AggregateDataPtr __restrict place) const override { this->nested_function->prepareWindow(this->nestedPlace(place)); + reset(place); } - bool allocatesMemoryInArena() const override { return this->nested_function->allocatesMemoryInArena(); } - private: - enum + inline void resetFlag(AggregateDataPtr __restrict place) const noexcept { this->initFlag(place); } + + inline void addCounter(AggregateDataPtr __restrict place) const noexcept { - MAX_ARGS = 8 - }; - size_t number_of_arguments = 0; - std::array is_nullable; /// Plain array is better than std::vector due to one indirection less. + auto * counter = reinterpret_cast(place + 1); + ++(*counter); + } + + inline void decreaseCounter(AggregateDataPtr __restrict place) const noexcept + { + auto * counter = reinterpret_cast(place + 1); + --(*counter); + assert((*counter) >= 0); + } + + inline Int64 getCounter(AggregateDataPtr __restrict place) const noexcept + { + auto * counter = reinterpret_cast(place + 1); + return *counter; + } + + inline void resetCounter(AggregateDataPtr __restrict place) const noexcept + { + auto * counter = reinterpret_cast(place + 1); + (*counter) = 0; + } }; } // namespace DB diff --git a/dbms/src/AggregateFunctions/HelpersMinMaxAny.h b/dbms/src/AggregateFunctions/HelpersMinMaxAny.h index db96f948610..ad34a592e6d 100644 --- a/dbms/src/AggregateFunctions/HelpersMinMaxAny.h +++ b/dbms/src/AggregateFunctions/HelpersMinMaxAny.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -61,6 +62,40 @@ static IAggregateFunction * createAggregateFunctionSingleValue( return new AggregateFunctionTemplate>(argument_type); } +template