From 4c87257fd2d84da7cc161edb402de875c3cba4db Mon Sep 17 00:00:00 2001 From: Liqi Geng Date: Thu, 10 Oct 2024 19:22:31 +0800 Subject: [PATCH 1/4] Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507) Signed-off-by: gengliqi --- dbms/src/Functions/FunctionsString.cpp | 411 +++++++++++------- .../src/Functions/tests/gtest_string_left.cpp | 58 +-- .../Functions/tests/gtest_strings_right.cpp | 58 +-- dbms/src/Functions/tests/gtest_substring.cpp | 163 ++++++- dbms/src/TestUtils/FunctionTestUtils.h | 79 ++++ tests/fullstack-test/expr/substring_utf8.test | 14 +- 6 files changed, 530 insertions(+), 253 deletions(-) diff --git a/dbms/src/Functions/FunctionsString.cpp b/dbms/src/Functions/FunctionsString.cpp index d856872bfe4..2773dad7343 100644 --- a/dbms/src/Functions/FunctionsString.cpp +++ b/dbms/src/Functions/FunctionsString.cpp @@ -1673,25 +1673,46 @@ class FunctionSubstringUTF8 : public IFunction bool implicit_length = (arguments.size() == 2); - bool is_start_type_valid = getNumberType(block.getByPosition(arguments[1]).type, [&](const auto & start_type, bool) { - using StartType = std::decay_t; - // Int64 / UInt64 - using StartFieldType = typename StartType::FieldType; - - // vector const const - if (!column_string->isColumnConst() && column_start->isColumnConst() && (implicit_length || block.getByPosition(arguments[2]).column->isColumnConst())) - { - auto [is_positive, start_abs] = getValueFromStartField((*block.getByPosition(arguments[1]).column)[0]); - UInt64 length = 0; - if (!implicit_length) - { - bool is_length_type_valid = getNumberType(block.getByPosition(arguments[2]).type, [&](const auto & length_type, bool) { - using LengthType = std::decay_t; - // Int64 / UInt64 - using LengthFieldType = typename LengthType::FieldType; - length = getValueFromLengthField((*block.getByPosition(arguments[2]).column)[0]); - return true; - }); + bool is_start_type_valid + = getNumberType(block.getByPosition(arguments[1]).type, [&](const auto & start_type, bool) { + using StartType = std::decay_t; + using StartFieldType = typename StartType::FieldType; + const ColumnVector * column_vector_start + = getInnerColumnVector(column_start); + if unlikely (!column_vector_start) + throw Exception( + fmt::format( + "Illegal type {} of argument 2 of function {}", + block.getByPosition(arguments[1]).type->getName(), + getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + // vector const const + if (!column_string->isColumnConst() && column_start->isColumnConst() + && (implicit_length || block.getByPosition(arguments[2]).column->isColumnConst())) + { + auto [is_positive, start_abs] = getValueFromStartColumn(*column_vector_start, 0); + UInt64 length = 0; + if (!implicit_length) + { + bool is_length_type_valid = getNumberType( + block.getByPosition(arguments[2]).type, + [&](const auto & length_type, bool) { + using LengthType = std::decay_t; + using LengthFieldType = typename LengthType::FieldType; + const ColumnVector * column_vector_length + = getInnerColumnVector(block.getByPosition(arguments[2]).column); + if unlikely (!column_vector_length) + throw Exception( + fmt::format( + "Illegal type {} of argument 3 of function {}", + block.getByPosition(arguments[2]).type->getName(), + getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + length = getValueFromLengthColumn(*column_vector_length, 0); + return true; + }); if (!is_length_type_valid) throw Exception(fmt::format("3nd argument of function {} must have UInt/Int type.", getName())); @@ -1704,59 +1725,78 @@ class FunctionSubstringUTF8 : public IFunction return true; } - const auto * col = checkAndGetColumn(column_string.get()); - assert(col); - auto col_res = ColumnString::create(); - getVectorConstConstFunc(implicit_length, is_positive)(col->getChars(), col->getOffsets(), start_abs, length, col_res->getChars(), col_res->getOffsets()); - block.getByPosition(result).column = std::move(col_res); - } - else // all other cases are converted to vector vector vector - { - std::function(size_t)> get_start_func; - if (column_start->isColumnConst()) - { - // func always return const value - auto start_const = getValueFromStartField((*column_start)[0]); - get_start_func = [start_const](size_t) { - return start_const; - }; - } - else - { - get_start_func = [&column_start](size_t i) { - return getValueFromStartField((*column_start)[i]); - }; - } - - // if implicit_length, get_length_func be nil is ok. - std::function get_length_func; - if (!implicit_length) - { - const ColumnPtr & column_length = block.getByPosition(arguments[2]).column; - bool is_length_type_valid = getNumberType(block.getByPosition(arguments[2]).type, [&](const auto & length_type, bool) { - using LengthType = std::decay_t; - // Int64 / UInt64 - using LengthFieldType = typename LengthType::FieldType; - if (column_length->isColumnConst()) - { - // func always return const value - auto length_const = getValueFromLengthField((*column_length)[0]); - get_length_func = [length_const](size_t) { - return length_const; - }; - } - else - { - get_length_func = [column_length](size_t i) { - return getValueFromLengthField((*column_length)[i]); - }; - } - return true; - }); - - if (!is_length_type_valid) - throw Exception(fmt::format("3nd argument of function {} must have UInt/Int type.", getName())); - } + const auto * col = checkAndGetColumn(column_string.get()); + assert(col); + auto col_res = ColumnString::create(); + getVectorConstConstFunc(implicit_length, is_positive)( + col->getChars(), + col->getOffsets(), + start_abs, + length, + col_res->getChars(), + col_res->getOffsets()); + block.getByPosition(result).column = std::move(col_res); + } + else // all other cases are converted to vector vector vector + { + std::function(size_t)> get_start_func; + if (column_start->isColumnConst()) + { + // func always return const value + auto start_const = getValueFromStartColumn(*column_vector_start, 0); + get_start_func = [start_const](size_t) { + return start_const; + }; + } + else + { + get_start_func = [column_vector_start](size_t i) { + return getValueFromStartColumn(*column_vector_start, i); + }; + } + + // if implicit_length, get_length_func be nil is ok. + std::function get_length_func; + if (!implicit_length) + { + const ColumnPtr & column_length = block.getByPosition(arguments[2]).column; + bool is_length_type_valid = getNumberType( + block.getByPosition(arguments[2]).type, + [&](const auto & length_type, bool) { + using LengthType = std::decay_t; + using LengthFieldType = typename LengthType::FieldType; + const ColumnVector * column_vector_length + = getInnerColumnVector(column_length); + if unlikely (!column_vector_length) + throw Exception( + fmt::format( + "Illegal type {} of argument 3 of function {}", + block.getByPosition(arguments[2]).type->getName(), + getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + if (column_length->isColumnConst()) + { + // func always return const value + auto length_const + = getValueFromLengthColumn(*column_vector_length, 0); + get_length_func = [length_const](size_t) { + return length_const; + }; + } + else + { + get_length_func = [column_vector_length](size_t i) { + return getValueFromLengthColumn(*column_vector_length, i); + }; + } + return true; + }); + + if unlikely (!is_length_type_valid) + throw Exception( + fmt::format("3nd argument of function {} must have UInt/Int type.", getName())); + } // convert to vector if string is const. ColumnPtr full_column_string = column_string->isColumnConst() ? column_string->convertToFullColumnIfConst() : column_string; @@ -1777,10 +1817,38 @@ class FunctionSubstringUTF8 : public IFunction return true; }); - if (!is_start_type_valid) + if unlikely (!is_start_type_valid) throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName())); } + template + static const ColumnVector * getInnerColumnVector(const ColumnPtr & column) + { + if (column->isColumnConst()) + return checkAndGetColumn>( + checkAndGetColumn(column.get())->getDataColumnPtr().get()); + return checkAndGetColumn>(column.get()); + } + + template + static size_t getValueFromLengthColumn(const ColumnVector & column, size_t index) + { + Integer val = column.getElement(index); + if constexpr ( + std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v) + { + return val < 0 ? 0 : val; + } + else + { + static_assert( + std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v); + return val; + } + } + private: using VectorConstConstFunc = std::function - static size_t getValueFromLengthField(const Field & length_field) - { - if constexpr (std::is_same_v) - { - Int64 signed_length = length_field.get(); - return signed_length < 0 ? 0 : signed_length; - } - else - { - static_assert(std::is_same_v); - return length_field.get(); - } - } - // return {is_positive, abs} template - static std::pair getValueFromStartField(const Field & start_field) + static std::pair getValueFromStartColumn(const ColumnVector & column, size_t index) { - if constexpr (std::is_same_v) + Integer val = column.getElement(index); + if constexpr ( + std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v) { - Int64 signed_length = start_field.get(); - - if (signed_length < 0) - { - return {false, static_cast(-signed_length)}; - } - else - { - return {true, static_cast(signed_length)}; - } + if (val < 0) + return {false, static_cast(-val)}; + return {true, static_cast(val)}; } else { - static_assert(std::is_same_v); - return {true, start_field.get()}; + static_assert( + std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v); + return {true, val}; } } @@ -1845,8 +1896,14 @@ class FunctionSubstringUTF8 : public IFunction static bool getNumberType(DataTypePtr type, F && f) { return castTypeToEither< - DataTypeInt64, - DataTypeUInt64>(type.get(), std::forward(f)); + DataTypeUInt8, + DataTypeUInt16, + DataTypeUInt32, + DataTypeUInt64, + DataTypeInt8, + DataTypeInt16, + DataTypeInt32, + DataTypeInt64>(type.get(), std::forward(f)); } }; @@ -1891,18 +1948,31 @@ class FunctionRightUTF8 : public IFunction const ColumnPtr column_string = block.getByPosition(arguments[0]).column; const ColumnPtr column_length = block.getByPosition(arguments[1]).column; - bool is_length_type_valid = getLengthType(block.getByPosition(arguments[1]).type, [&](const auto & length_type, bool) { - using LengthType = std::decay_t; - // Int64 / UInt64 - using LengthFieldType = typename LengthType::FieldType; - - auto col_res = ColumnString::create(); - if (const auto * col_string = checkAndGetColumn(column_string.get())) - { - if (column_length->isColumnConst()) - { - // vector const - size_t length = getValueFromLengthField((*column_length)[0]); + bool is_length_type_valid + = getLengthType(block.getByPosition(arguments[1]).type, [&](const auto & length_type, bool) { + using LengthType = std::decay_t; + using LengthFieldType = typename LengthType::FieldType; + + const ColumnVector * column_vector_length + = FunctionSubstringUTF8::getInnerColumnVector(column_length); + if unlikely (!column_vector_length) + throw Exception( + fmt::format( + "Illegal type {} of argument 2 of function {}", + block.getByPosition(arguments[1]).type->getName(), + getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + + auto col_res = ColumnString::create(); + if (const auto * col_string = checkAndGetColumn(column_string.get())) + { + if (column_length->isColumnConst()) + { + // vector const + size_t length = FunctionSubstringUTF8::getValueFromLengthColumn( + *column_vector_length, + 0); // for const 0, return const blank string. if (0 == length) @@ -1911,37 +1981,59 @@ class FunctionRightUTF8 : public IFunction return true; } - RightUTF8Impl::vectorConst(col_string->getChars(), col_string->getOffsets(), length, col_res->getChars(), col_res->getOffsets()); - } - else - { - // vector vector - auto get_length_func = [&column_length](size_t i) { - return getValueFromLengthField((*column_length)[i]); - }; - RightUTF8Impl::vectorVector(col_string->getChars(), col_string->getOffsets(), get_length_func, col_res->getChars(), col_res->getOffsets()); - } - } - else if (const ColumnConst * col_const_string = checkAndGetColumnConst(column_string.get())) - { - // const vector - const auto * col_string_from_const = checkAndGetColumn(col_const_string->getDataColumnPtr().get()); - assert(col_string_from_const); - // When useDefaultImplementationForConstants is true, string and length are not both constants - assert(!column_length->isColumnConst()); - auto get_length_func = [&column_length](size_t i) { - return getValueFromLengthField((*column_length)[i]); - }; - RightUTF8Impl::constVector(column_length->size(), col_string_from_const->getChars(), col_string_from_const->getOffsets(), get_length_func, col_res->getChars(), col_res->getOffsets()); - } - else - { - // Impossible to reach here - return false; - } - block.getByPosition(result).column = std::move(col_res); - return true; - }); + RightUTF8Impl::vectorConst( + col_string->getChars(), + col_string->getOffsets(), + length, + col_res->getChars(), + col_res->getOffsets()); + } + else + { + // vector vector + auto get_length_func = [column_vector_length](size_t i) { + return FunctionSubstringUTF8::getValueFromLengthColumn( + *column_vector_length, + i); + }; + RightUTF8Impl::vectorVector( + col_string->getChars(), + col_string->getOffsets(), + get_length_func, + col_res->getChars(), + col_res->getOffsets()); + } + } + else if ( + const ColumnConst * col_const_string = checkAndGetColumnConst(column_string.get())) + { + // const vector + const auto * col_string_from_const + = checkAndGetColumn(col_const_string->getDataColumnPtr().get()); + assert(col_string_from_const); + // When useDefaultImplementationForConstants is true, string and length are not both constants + assert(!column_length->isColumnConst()); + auto get_length_func = [column_vector_length](size_t i) { + return FunctionSubstringUTF8::getValueFromLengthColumn( + *column_vector_length, + i); + }; + RightUTF8Impl::constVector( + column_length->size(), + col_string_from_const->getChars(), + col_string_from_const->getOffsets(), + get_length_func, + col_res->getChars(), + col_res->getOffsets()); + } + else + { + // Impossible to reach here + return false; + } + block.getByPosition(result).column = std::move(col_res); + return true; + }); if (!is_length_type_valid) throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName())); @@ -1953,23 +2045,14 @@ class FunctionRightUTF8 : public IFunction getLengthType(DataTypePtr type, F && f) { return castTypeToEither< - DataTypeInt64, - DataTypeUInt64>(type.get(), std::forward(f)); - } - - template - static size_t getValueFromLengthField(const Field & length_field) - { - if constexpr (std::is_same_v) - { - Int64 signed_length = length_field.get(); - return signed_length < 0 ? 0 : signed_length; - } - else - { - static_assert(std::is_same_v); - return length_field.get(); - } + DataTypeUInt8, + DataTypeUInt16, + DataTypeUInt32, + DataTypeUInt64, + DataTypeInt8, + DataTypeInt16, + DataTypeInt32, + DataTypeInt64>(type.get(), std::forward(f)); } }; diff --git a/dbms/src/Functions/tests/gtest_string_left.cpp b/dbms/src/Functions/tests/gtest_string_left.cpp index 7ebf05be47f..b64abbc8150 100644 --- a/dbms/src/Functions/tests/gtest_string_left.cpp +++ b/dbms/src/Functions/tests/gtest_string_left.cpp @@ -66,30 +66,18 @@ class StringLeftTest : public DB::tests::FunctionTest for (bool is_length_const : is_consts) inner_test(is_str_const, is_length_const); } - - template - void testInvalidLengthType() - { - static_assert(!std::is_same_v && !std::is_same_v); - auto inner_test = [&](bool is_str_const, bool is_length_const) { - ASSERT_THROW( - executeFunction( - func_name, - is_str_const ? createConstColumn>(1, "") : createColumn>({""}), - is_length_const ? createConstColumn>(1, 0) : createColumn>({0})), - Exception); - }; - std::vector is_consts = {true, false}; - for (bool is_str_const : is_consts) - for (bool is_length_const : is_consts) - inner_test(is_str_const, is_length_const); - } }; TEST_F(StringLeftTest, testBoundary) try { + testBoundary(); + testBoundary(); + testBoundary(); testBoundary(); + testBoundary(); + testBoundary(); + testBoundary(); testBoundary(); } CATCH @@ -97,6 +85,16 @@ CATCH TEST_F(StringLeftTest, testMoreCases) try { +#define CALL(A, B, C) \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); + // test big string // big_string.size() > length String big_string; @@ -104,23 +102,19 @@ try String unit_string = "big string is 我!!!!!!!"; for (size_t i = 0; i < 1000; ++i) big_string += unit_string; - test(big_string, 22, unit_string); - test(big_string, 22, unit_string); + CALL(big_string, 22, unit_string); // test origin_str.size() == length String origin_str = "我的 size = 12"; - test(origin_str, 12, origin_str); - test(origin_str, 12, origin_str); + CALL(origin_str, 12, origin_str); // test origin_str.size() < length - test(origin_str, 22, origin_str); - test(origin_str, 22, origin_str); + CALL(origin_str, 22, origin_str); // Mixed language String english_str = "This is English"; String mixed_language_str = english_str + ",这是中文,C'est français,これが日本の"; - test(mixed_language_str, english_str.size(), english_str); - test(mixed_language_str, english_str.size(), english_str); + CALL(mixed_language_str, english_str.size(), english_str); // column size != 1 // case 1 @@ -144,18 +138,8 @@ try func_name, createConstColumn>(8, second_case_string), createColumn>({0, 1, 0, 1, 0, 0, 1, 1}))); -} -CATCH -TEST_F(StringLeftTest, testInvalidLengthType) -try -{ - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); +#undef CALL } CATCH diff --git a/dbms/src/Functions/tests/gtest_strings_right.cpp b/dbms/src/Functions/tests/gtest_strings_right.cpp index 9ff8c33a4ed..e9e64bb18c4 100644 --- a/dbms/src/Functions/tests/gtest_strings_right.cpp +++ b/dbms/src/Functions/tests/gtest_strings_right.cpp @@ -65,30 +65,18 @@ class StringRightTest : public DB::tests::FunctionTest for (bool is_length_const : is_consts) inner_test(is_str_const, is_length_const); } - - template - void testInvalidLengthType() - { - static_assert(!std::is_same_v && !std::is_same_v); - auto inner_test = [&](bool is_str_const, bool is_length_const) { - ASSERT_THROW( - executeFunction( - func_name, - is_str_const ? createConstColumn>(1, "") : createColumn>({""}), - is_length_const ? createConstColumn>(1, 0) : createColumn>({0})), - Exception); - }; - std::vector is_consts = {true, false}; - for (bool is_str_const : is_consts) - for (bool is_length_const : is_consts) - inner_test(is_str_const, is_length_const); - } }; TEST_F(StringRightTest, testBoundary) try { + testBoundary(); + testBoundary(); + testBoundary(); testBoundary(); + testBoundary(); + testBoundary(); + testBoundary(); testBoundary(); } CATCH @@ -96,6 +84,16 @@ CATCH TEST_F(StringRightTest, testMoreCases) try { +#define CALL(A, B, C) \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); + // test big string // big_string.size() > length String big_string; @@ -103,23 +101,19 @@ try String unit_string = "big string is 我!!!!!!!"; for (size_t i = 0; i < 1000; ++i) big_string += unit_string; - test(big_string, 22, unit_string); - test(big_string, 22, unit_string); + CALL(big_string, 22, unit_string); // test origin_str.size() == length String origin_str = "我的 size = 12"; - test(origin_str, 12, origin_str); - test(origin_str, 12, origin_str); + CALL(origin_str, 12, origin_str); // test origin_str.size() < length - test(origin_str, 22, origin_str); - test(origin_str, 22, origin_str); + CALL(origin_str, 22, origin_str); // Mixed language String english_str = "This is English"; String mixed_language_str = "这是中文,C'est français,これが日本の," + english_str; - test(mixed_language_str, english_str.size(), english_str); - test(mixed_language_str, english_str.size(), english_str); + CALL(mixed_language_str, english_str.size(), english_str); // column size != 1 // case 1 @@ -143,18 +137,8 @@ try func_name, createConstColumn>(8, second_case_string), createColumn>({0, 1, 0, 1, 0, 0, 1, 1}))); -} -CATCH -TEST_F(StringRightTest, testInvalidLengthType) -try -{ - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); +#undef CALL } CATCH diff --git a/dbms/src/Functions/tests/gtest_substring.cpp b/dbms/src/Functions/tests/gtest_substring.cpp index 9fb1273c6ee..397d9910177 100644 --- a/dbms/src/Functions/tests/gtest_substring.cpp +++ b/dbms/src/Functions/tests/gtest_substring.cpp @@ -27,9 +27,160 @@ class SubString : public DB::tests::FunctionTest { }; +template +class TestNullableSigned +{ +public: + static void run(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"p.co", "ww.p", "pingcap", "com", ".com", "", "", "", {}, {}, {}}), + sub_string.executeFunction( + "substringUTF8", + createColumn>( + {"www.pingcap.com", + "ww.pingcap.com", + "w.pingcap.com", + ".pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + {}, + "pingcap", + "pingcap"}), + createColumn({-5, 1, 3, -3, 8, 2, -100, 0, 2, {}, -3}), + createColumn({4, 4, 7, 4, 5, -5, 2, 3, 6, 4, {}}))); + } +}; + +template +class TestSigned +{ +public: + static void run(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"p.co", "ww.p", "pingcap", "com", ".com", "", "", "", {}}), + sub_string.executeFunction( + "substringUTF8", + createColumn>( + {"www.pingcap.com", + "ww.pingcap.com", + "w.pingcap.com", + ".pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + {}}), + createColumn({-5, 1, 3, -3, 8, 2, -100, 0, 2}), + createColumn({4, 4, 7, 4, 5, -5, 2, 3, 6}))); + } +}; + +template +class TestNullableUnsigned +{ +public: + static void run(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"p.co", "ww.p", "pingcap", "com", ".com", "", "", {}, {}, {}}), + sub_string.executeFunction( + "substringUTF8", + createColumn>( + {"www.pingcap.com", + "ww.pingcap.com", + "w.pingcap.com", + ".pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + {}, + "pingcap", + "pingcap"}), + createColumn({11, 1, 3, 10, 8, 2, 0, 9, {}, 7}), + createColumn({4, 4, 7, 4, 5, 0, 3, 6, 1, {}}))); + } +}; + +template +class TestUnsigned +{ +public: + static void run(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"p.co", "ww.p", "pingcap", "com", ".com", "", "", {}}), + sub_string.executeFunction( + "substringUTF8", + createColumn>( + {"www.pingcap.com", + "ww.pingcap.com", + "w.pingcap.com", + ".pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + {}}), + createColumn({11, 1, 3, 10, 8, 2, 0, 2}), + createColumn({4, 4, 7, 4, 5, 0, 3, 1}))); + } +}; + +template +class TestConstPos +{ +public: + static void run(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"w", "ww", "w.p", ".pin"}), + sub_string.executeFunction( + "substringUTF8", + createColumn>({"www.pingcap.com", "ww.pingcap.com", "w.pingcap.com", ".pingcap.com"}), + createConstColumn(4, 1), + createColumn({1, 2, 3, 4}))); + } +}; + +template +class TestConstLength +{ +public: + static void run(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"www.", "w.pi", "ping", "ngca"}), + sub_string.executeFunction( + "substringUTF8", + createColumn>({"www.pingcap.com", "ww.pingcap.com", "w.pingcap.com", ".pingcap.com"}), + createColumn({1, 2, 3, 4}), + createConstColumn(4, 4))); + } +}; + TEST_F(SubString, subStringUTF8Test) try { + TestTypePair::run(*this); + TestTypePair::run(*this); + + TestTypePair::run(*this); + TestTypePair::run(*this); + TestTypePair::run(*this); + + TestTypePair::run(*this); + TestTypePair::run(*this); + TestTypePair::run(*this); + + TestTypePair::run(*this); + TestTypePair::run(*this); + + TestTypePair::run(*this); + TestTypePair::run(*this); + // column, const, const ASSERT_COLUMN_EQ( createColumn>({"www.", "ww.p", "w.pi", ".pin"}), @@ -38,6 +189,7 @@ try createColumn>({"www.pingcap.com", "ww.pingcap.com", "w.pingcap.com", ".pingcap.com"}), createConstColumn>(4, 1), createConstColumn>(4, 4))); + // const, const, const ASSERT_COLUMN_EQ( createConstColumn(1, "www."), @@ -46,17 +198,8 @@ try createConstColumn>(1, "www.pingcap.com"), createConstColumn>(1, 1), createConstColumn>(1, 4))); - // Test Null - ASSERT_COLUMN_EQ( - createColumn>({{}, "www."}), - executeFunction( - "substringUTF8", - createColumn>( - {{}, "www.pingcap.com"}), - createConstColumn>(2, 1), - createConstColumn>(2, 4))); } CATCH } // namespace tests -} // namespace DB \ No newline at end of file +} // namespace DB diff --git a/dbms/src/TestUtils/FunctionTestUtils.h b/dbms/src/TestUtils/FunctionTestUtils.h index 69bd9c15944..c55305ac76e 100644 --- a/dbms/src/TestUtils/FunctionTestUtils.h +++ b/dbms/src/TestUtils/FunctionTestUtils.h @@ -827,6 +827,85 @@ class FunctionTest : public ::testing::Test std::unique_ptr dag_context_ptr; }; +template +struct TestTypeList +{ +}; + +using TestNullableIntTypes = TestTypeList, Nullable, Nullable, Nullable>; + +using TestNullableUIntTypes = TestTypeList, Nullable, Nullable, Nullable>; + +using TestIntTypes = TestTypeList; + +using TestUIntTypes = TestTypeList; + +using TestAllIntTypes + = TestTypeList, Nullable, Nullable, Nullable, Int8, Int16, Int32, Int64>; + +using TestAllUIntTypes = TestTypeList< + Nullable, + Nullable, + Nullable, + Nullable, + UInt8, + UInt16, + UInt32, + UInt64>; + +template class Func, typename FuncParam> +struct TestTypeSingle; + +template class Func, typename FuncParam> +struct TestTypeSingle, Func, FuncParam> +{ + static void run(FuncParam & p) + { + Func::run(p); + // Recursively handle the rest of T2List + TestTypeSingle, Func, FuncParam>::run(p); + } +}; + +template class Func, typename FuncParam> +struct TestTypeSingle, Func, FuncParam> +{ + static void run(FuncParam &) + { + // Do nothing when T2List is empty + } +}; + +template class Func, typename FuncParam> +struct TestTypePair; + +template < + typename T1, + typename... T1Rest, + typename T2List, + template + class Func, + typename FuncParam> +struct TestTypePair, T2List, Func, FuncParam> +{ + static void run(FuncParam & p) + { + // For the current T1, traverse all types in T2List + TestTypeSingle::run(p); + // Recursively handle the rest of T1List + TestTypePair, T2List, Func, FuncParam>::run(p); + } +}; + +template class Func, typename FuncParam> +struct TestTypePair, T2List, Func, FuncParam> +{ + static void run(FuncParam &) + { + // Do nothing when T1List is empty + } +}; + #define ASSERT_COLUMN_EQ(expected, actual) ASSERT_TRUE(DB::tests::columnEqual((expected), (actual))) #define ASSERT_BLOCK_EQ(expected, actual) ASSERT_TRUE(DB::tests::blockEqual((expected), (actual))) diff --git a/tests/fullstack-test/expr/substring_utf8.test b/tests/fullstack-test/expr/substring_utf8.test index b5bc0d837a2..c4ce7f34eb5 100644 --- a/tests/fullstack-test/expr/substring_utf8.test +++ b/tests/fullstack-test/expr/substring_utf8.test @@ -13,21 +13,25 @@ # limitations under the License. mysql> drop table if exists test.t -mysql> create table test.t(a char(10)) -mysql> insert into test.t values(''), ('abc') +mysql> create table test.t(a char(10), b int, c tinyint unsigned) +mysql> insert into test.t values('', -3, 2), ('abc', -3, 2) mysql> alter table test.t set tiflash replica 1 func> wait_table test t -mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; select * from test.t where substring(a, -3, 4) = 'abc' +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, -3, 4) = 'abc' a abc -mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; select * from test.t where substring(a, -3, 2) = 'ab' +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, -3, 2) = 'ab' a abc -mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; select * from test.t where substring(a, -4, 3) = 'abc' +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, b, c) = 'ab' +a +abc + +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, -4, 3) = 'abc' # Empty mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; select count(*) from test.t where substring(a, 0, 3) = '' order by a From 029747d9a07eb1576abb090059311070a563b01b Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 5 Nov 2024 16:28:10 +0800 Subject: [PATCH 2/4] fix format Signed-off-by: gengliqi --- dbms/src/Functions/FunctionsString.cpp | 64 +++++++++++++------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/dbms/src/Functions/FunctionsString.cpp b/dbms/src/Functions/FunctionsString.cpp index 2773dad7343..3a3a61982ff 100644 --- a/dbms/src/Functions/FunctionsString.cpp +++ b/dbms/src/Functions/FunctionsString.cpp @@ -1714,16 +1714,16 @@ class FunctionSubstringUTF8 : public IFunction return true; }); - if (!is_length_type_valid) - throw Exception(fmt::format("3nd argument of function {} must have UInt/Int type.", getName())); - } + if (!is_length_type_valid) + throw Exception(fmt::format("3nd argument of function {} must have UInt/Int type.", getName())); + } - // for const zero start or const zero length, return const blank string. - if (start_abs == 0 || (!implicit_length && length == 0)) - { - block.getByPosition(result).column = DataTypeString().createColumnConst(column_string->size(), toField(String(""))); - return true; - } + // for const zero start or const zero length, return const blank string. + if (start_abs == 0 || (!implicit_length && length == 0)) + { + block.getByPosition(result).column = DataTypeString().createColumnConst(column_string->size(), toField(String(""))); + return true; + } const auto * col = checkAndGetColumn(column_string.get()); assert(col); @@ -1798,24 +1798,24 @@ class FunctionSubstringUTF8 : public IFunction fmt::format("3nd argument of function {} must have UInt/Int type.", getName())); } - // convert to vector if string is const. - ColumnPtr full_column_string = column_string->isColumnConst() ? column_string->convertToFullColumnIfConst() : column_string; - const auto * col = checkAndGetColumn(full_column_string.get()); - assert(col); - auto col_res = ColumnString::create(); - if (implicit_length) - { - SubstringUTF8Impl::vectorVectorVector(col->getChars(), col->getOffsets(), get_start_func, get_length_func, col_res->getChars(), col_res->getOffsets()); - } - else - { - SubstringUTF8Impl::vectorVectorVector(col->getChars(), col->getOffsets(), get_start_func, get_length_func, col_res->getChars(), col_res->getOffsets()); - } - block.getByPosition(result).column = std::move(col_res); - } + // convert to vector if string is const. + ColumnPtr full_column_string = column_string->isColumnConst() ? column_string->convertToFullColumnIfConst() : column_string; + const auto * col = checkAndGetColumn(full_column_string.get()); + assert(col); + auto col_res = ColumnString::create(); + if (implicit_length) + { + SubstringUTF8Impl::vectorVectorVector(col->getChars(), col->getOffsets(), get_start_func, get_length_func, col_res->getChars(), col_res->getOffsets()); + } + else + { + SubstringUTF8Impl::vectorVectorVector(col->getChars(), col->getOffsets(), get_start_func, get_length_func, col_res->getChars(), col_res->getOffsets()); + } + block.getByPosition(result).column = std::move(col_res); + } - return true; - }); + return true; + }); if unlikely (!is_start_type_valid) throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName())); @@ -1974,12 +1974,12 @@ class FunctionRightUTF8 : public IFunction *column_vector_length, 0); - // for const 0, return const blank string. - if (0 == length) - { - block.getByPosition(result).column = DataTypeString().createColumnConst(column_string->size(), toField(String(""))); - return true; - } + // for const 0, return const blank string. + if (0 == length) + { + block.getByPosition(result).column = DataTypeString().createColumnConst(column_string->size(), toField(String(""))); + return true; + } RightUTF8Impl::vectorConst( col_string->getChars(), From 70c3a2ae6d7207911383bc15dfa72e0641cff336 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 5 Nov 2024 17:30:35 +0800 Subject: [PATCH 3/4] update Signed-off-by: gengliqi --- dbms/src/Functions/FunctionsString.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/dbms/src/Functions/FunctionsString.cpp b/dbms/src/Functions/FunctionsString.cpp index 3a3a61982ff..dd37d46b65d 100644 --- a/dbms/src/Functions/FunctionsString.cpp +++ b/dbms/src/Functions/FunctionsString.cpp @@ -1835,16 +1835,14 @@ class FunctionSubstringUTF8 : public IFunction { Integer val = column.getElement(index); if constexpr ( - std::is_same_v || std::is_same_v || std::is_same_v - || std::is_same_v) + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) { return val < 0 ? 0 : val; } else { static_assert( - std::is_same_v || std::is_same_v || std::is_same_v - || std::is_same_v); + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v); return val; } } @@ -1876,8 +1874,7 @@ class FunctionSubstringUTF8 : public IFunction { Integer val = column.getElement(index); if constexpr ( - std::is_same_v || std::is_same_v || std::is_same_v - || std::is_same_v) + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) { if (val < 0) return {false, static_cast(-val)}; @@ -1886,8 +1883,7 @@ class FunctionSubstringUTF8 : public IFunction else { static_assert( - std::is_same_v || std::is_same_v || std::is_same_v - || std::is_same_v); + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v); return {true, val}; } } From 52d1432b24833b72700b680375c369c227c4bf43 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 5 Nov 2024 18:11:24 +0800 Subject: [PATCH 4/4] update Signed-off-by: gengliqi --- tests/fullstack-test/expr/substring_utf8.test | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/fullstack-test/expr/substring_utf8.test b/tests/fullstack-test/expr/substring_utf8.test index c4ce7f34eb5..13534c00230 100644 --- a/tests/fullstack-test/expr/substring_utf8.test +++ b/tests/fullstack-test/expr/substring_utf8.test @@ -19,19 +19,19 @@ mysql> alter table test.t set tiflash replica 1 func> wait_table test t -mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, -3, 4) = 'abc' +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; select a from test.t where substring(a, -3, 4) = 'abc' a abc -mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, -3, 2) = 'ab' +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; select a from test.t where substring(a, -3, 2) = 'ab' a abc -mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, b, c) = 'ab' +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; select a from test.t where substring(a, b, c) = 'ab' a abc -mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, -4, 3) = 'abc' +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; select a from test.t where substring(a, -4, 3) = 'abc' # Empty mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; select count(*) from test.t where substring(a, 0, 3) = '' order by a