Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support other integer types for SubstringUTF8 & RightUTF8 functions #9507

Merged
merged 7 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 124 additions & 49 deletions dbms/src/Functions/FunctionsString.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1681,26 +1681,41 @@ class FunctionSubstringUTF8 : public IFunction
bool is_start_type_valid
= getNumberType(block.getByPosition(arguments[1]).type, [&](const auto & start_type, bool) {
using StartType = std::decay_t<decltype(start_type)>;
// Int64 / UInt64
using StartFieldType = typename StartType::FieldType;
const ColumnVector<StartFieldType> * column_vector_start
= getInnerColumnVector<StartFieldType>(column_start);
if unlikely (!column_vector_start)
throw Exception(
fmt::format(
"Illegal type {} of argument 2 of function {}",
block.getByPosition(arguments[1]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

// vector const const
if (!column_string->isColumnConst() && column_start->isColumnConst()
&& (implicit_length || block.getByPosition(arguments[2]).column->isColumnConst()))
{
auto [is_positive, start_abs]
= getValueFromStartField<StartFieldType>((*block.getByPosition(arguments[1]).column)[0]);
auto [is_positive, start_abs] = getValueFromStartColumn<StartFieldType>(*column_vector_start, 0);
UInt64 length = 0;
if (!implicit_length)
{
bool is_length_type_valid = getNumberType(
block.getByPosition(arguments[2]).type,
[&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
// Int64 / UInt64
using LengthFieldType = typename LengthType::FieldType;
length = getValueFromLengthField<LengthFieldType>(
(*block.getByPosition(arguments[2]).column)[0]);
const ColumnVector<LengthFieldType> * column_vector_length
= getInnerColumnVector<LengthFieldType>(block.getByPosition(arguments[2]).column);
if unlikely (!column_vector_length)
throw Exception(
fmt::format(
"Illegal type {} of argument 3 of function {}",
block.getByPosition(arguments[2]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

length = getValueFromLengthColumn<LengthFieldType>(*column_vector_length, 0);
return true;
});

Expand Down Expand Up @@ -1735,15 +1750,15 @@ class FunctionSubstringUTF8 : public IFunction
if (column_start->isColumnConst())
{
// func always return const value
auto start_const = getValueFromStartField<StartFieldType>((*column_start)[0]);
auto start_const = getValueFromStartColumn<StartFieldType>(*column_vector_start, 0);
get_start_func = [start_const](size_t) {
return start_const;
};
}
else
{
get_start_func = [&column_start](size_t i) {
return getValueFromStartField<StartFieldType>((*column_start)[i]);
get_start_func = [column_vector_start](size_t i) {
return getValueFromStartColumn<StartFieldType>(*column_vector_start, i);
};
}

Expand All @@ -1756,26 +1771,36 @@ class FunctionSubstringUTF8 : public IFunction
block.getByPosition(arguments[2]).type,
[&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
// Int64 / UInt64
using LengthFieldType = typename LengthType::FieldType;
const ColumnVector<LengthFieldType> * column_vector_length
= getInnerColumnVector<LengthFieldType>(column_length);
if unlikely (!column_vector_length)
throw Exception(
fmt::format(
"Illegal type {} of argument 3 of function {}",
block.getByPosition(arguments[2]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

if (column_length->isColumnConst())
{
// func always return const value
auto length_const = getValueFromLengthField<LengthFieldType>((*column_length)[0]);
auto length_const
= getValueFromLengthColumn<LengthFieldType>(*column_vector_length, 0);
get_length_func = [length_const](size_t) {
return length_const;
};
}
else
{
get_length_func = [column_length](size_t i) {
return getValueFromLengthField<LengthFieldType>((*column_length)[i]);
get_length_func = [column_vector_length](size_t i) {
return getValueFromLengthColumn<LengthFieldType>(*column_vector_length, i);
};
}
return true;
});

if (!is_length_type_valid)
if unlikely (!is_length_type_valid)
throw Exception(
fmt::format("3nd argument of function {} must have UInt/Int type.", getName()));
}
Expand Down Expand Up @@ -1813,7 +1838,7 @@ class FunctionSubstringUTF8 : public IFunction
return true;
});

if (!is_start_type_valid)
if unlikely (!is_start_type_valid)
throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName()));
}

Expand Down Expand Up @@ -1841,48 +1866,67 @@ class FunctionSubstringUTF8 : public IFunction
}

template <typename Integer>
static size_t getValueFromLengthField(const Field & length_field)
static const ColumnVector<Integer> * getInnerColumnVector(const ColumnPtr & column)
{
if (column->isColumnConst())
return checkAndGetColumn<ColumnVector<Integer>>(
checkAndGetColumn<ColumnConst>(column.get())->getDataColumnPtr().get());
return checkAndGetColumn<ColumnVector<Integer>>(column.get());
}

template <typename Integer>
static size_t getValueFromLengthColumn(const ColumnVector<Integer> & column, size_t index)
{
if constexpr (std::is_same_v<Integer, Int64>)
Integer val = column.getElement(index);
if constexpr (
std::is_same_v<Integer, Int8> || std::is_same_v<Integer, Int16> || std::is_same_v<Integer, Int32>
|| std::is_same_v<Integer, Int64>)
{
Int64 signed_length = length_field.get<Int64>();
return signed_length < 0 ? 0 : signed_length;
return val < 0 ? 0 : val;
}
else
{
static_assert(std::is_same_v<Integer, UInt64>);
return length_field.get<UInt64>();
static_assert(
std::is_same_v<Integer, UInt8> || std::is_same_v<Integer, UInt16> || std::is_same_v<Integer, UInt32>
|| std::is_same_v<Integer, UInt64>);
return val;
}
}

// return {is_positive, abs}
template <typename Integer>
static std::pair<bool, size_t> getValueFromStartField(const Field & start_field)
static std::pair<bool, size_t> getValueFromStartColumn(const ColumnVector<Integer> & column, size_t index)
{
if constexpr (std::is_same_v<Integer, Int64>)
Integer val = column.getElement(index);
if constexpr (
std::is_same_v<Integer, Int8> || std::is_same_v<Integer, Int16> || std::is_same_v<Integer, Int32>
|| std::is_same_v<Integer, Int64>)
{
Int64 signed_length = start_field.get<Int64>();

if (signed_length < 0)
{
return {false, static_cast<size_t>(-signed_length)};
}
else
{
return {true, static_cast<size_t>(signed_length)};
}
if (val < 0)
return {false, static_cast<size_t>(-val)};
return {true, static_cast<size_t>(val)};
}
else
{
static_assert(std::is_same_v<Integer, UInt64>);
return {true, start_field.get<UInt64>()};
static_assert(
std::is_same_v<Integer, UInt8> || std::is_same_v<Integer, UInt16> || std::is_same_v<Integer, UInt32>
|| std::is_same_v<Integer, UInt64>);
return {true, val};
}
}

template <typename F>
static bool getNumberType(DataTypePtr type, F && f)
yibin87 marked this conversation as resolved.
Show resolved Hide resolved
{
return castTypeToEither<DataTypeInt64, DataTypeUInt64>(type.get(), std::forward<F>(f));
return castTypeToEither<
DataTypeUInt8,
DataTypeUInt16,
DataTypeUInt32,
DataTypeUInt64,
DataTypeInt8,
DataTypeInt16,
DataTypeInt32,
DataTypeInt64>(type.get(), std::forward<F>(f));
}
};

Expand Down Expand Up @@ -1921,16 +1965,26 @@ class FunctionRightUTF8 : public IFunction
bool is_length_type_valid
= getLengthType(block.getByPosition(arguments[1]).type, [&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
// Int64 / UInt64
using LengthFieldType = typename LengthType::FieldType;

const ColumnVector<LengthFieldType> * column_vector_length
= getInnerColumnVector<LengthFieldType>(column_length);
if unlikely (!column_vector_length)
throw Exception(
fmt::format(
"Illegal type {} of argument 2 of function {}",
block.getByPosition(arguments[1]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);


auto col_res = ColumnString::create();
if (const auto * col_string = checkAndGetColumn<ColumnString>(column_string.get()))
{
if (column_length->isColumnConst())
{
// vector const
size_t length = getValueFromLengthField<LengthFieldType>((*column_length)[0]);
size_t length = getValueFromLengthColumn<LengthFieldType>(*column_vector_length, 0);

// for const 0, return const blank string.
if (0 == length)
Expand All @@ -1950,8 +2004,8 @@ class FunctionRightUTF8 : public IFunction
else
{
// vector vector
auto get_length_func = [&column_length](size_t i) {
return getValueFromLengthField<LengthFieldType>((*column_length)[i]);
auto get_length_func = [column_vector_length](size_t i) {
return getValueFromLengthColumn<LengthFieldType>(*column_vector_length, i);
};
RightUTF8Impl::vectorVector(
col_string->getChars(),
Expand All @@ -1970,8 +2024,8 @@ class FunctionRightUTF8 : public IFunction
assert(col_string_from_const);
// When useDefaultImplementationForConstants is true, string and length are not both constants
assert(!column_length->isColumnConst());
auto get_length_func = [&column_length](size_t i) {
return getValueFromLengthField<LengthFieldType>((*column_length)[i]);
auto get_length_func = [column_vector_length](size_t i) {
return getValueFromLengthColumn<LengthFieldType>(*column_vector_length, i);
};
RightUTF8Impl::constVector(
column_length->size(),
Expand All @@ -1998,21 +2052,42 @@ class FunctionRightUTF8 : public IFunction
template <typename F>
static bool getLengthType(DataTypePtr type, F && f)
{
return castTypeToEither<DataTypeInt64, DataTypeUInt64>(type.get(), std::forward<F>(f));
return castTypeToEither<
DataTypeUInt8,
DataTypeUInt16,
DataTypeUInt32,
DataTypeUInt64,
DataTypeInt8,
DataTypeInt16,
DataTypeInt32,
DataTypeInt64>(type.get(), std::forward<F>(f));
}

template <typename Integer>
static const ColumnVector<Integer> * getInnerColumnVector(const ColumnPtr & column)
{
if (column->isColumnConst())
return checkAndGetColumn<ColumnVector<Integer>>(
checkAndGetColumn<ColumnConst>(column.get())->getDataColumnPtr().get());
return checkAndGetColumn<ColumnVector<Integer>>(column.get());
}

template <typename Integer>
static size_t getValueFromLengthField(const Field & length_field)
static size_t getValueFromLengthColumn(const ColumnVector<Integer> & column, size_t index)
{
if constexpr (std::is_same_v<Integer, Int64>)
Integer val = column.getElement(index);
if constexpr (
std::is_same_v<Integer, Int8> || std::is_same_v<Integer, Int16> || std::is_same_v<Integer, Int32>
|| std::is_same_v<Integer, Int64>)
{
Int64 signed_length = length_field.get<Int64>();
return signed_length < 0 ? 0 : signed_length;
return val < 0 ? 0 : val;
}
else
{
static_assert(std::is_same_v<Integer, UInt64>);
return length_field.get<UInt64>();
static_assert(
std::is_same_v<Integer, UInt8> || std::is_same_v<Integer, UInt16> || std::is_same_v<Integer, UInt32>
|| std::is_same_v<Integer, UInt64>);
return val;
}
}
};
Expand Down
Loading