Skip to content

Commit

Permalink
String value type support in more kernels.
Browse files Browse the repository at this point in the history
- EwBinaryOpCode: Updated specification which elementwise binary ops are supported on which string value types (std::string, FixedStr16, const char *).
  - Needed to generalize some macros from one common value type for both operands to individual value types, e.g., to support "matrix of std::string op string scalar".
- Fixed a small type bug in EwBinaryObjSca-kernel.
- ExtractRow- and FilterRow-kernels on frames work with frames with std::string columns.
- Added instantiations of EwBinaryObjSca-kernel for comparisons of string matrices with string scalar (integer result).
- Order- and Group-kernels can handle frames with std::string columns.
- These changes were originally included in PRs #918, #921, and #926 by @saminbassiri.
- @pdamme is committing them in the name of @saminbassiri (for correct attribution) in a separate commit, since they don't really fit the topics of those PRs.
  • Loading branch information
saminbassiri authored and pdamme committed Dec 12, 2024
1 parent 3afd47a commit a374814
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 63 deletions.
39 changes: 24 additions & 15 deletions src/runtime/local/kernels/BinaryOpCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,20 @@ static constexpr bool supportsBinaryOp = false;
SUPPORT(BITWISE_AND, VT)

// Generates code specifying that all binary operations of a certain category should be
// supported on the given argument value type `VTArg` (for the left and right-hand-side
// arguments, for simplicity) and the given result value type `VTRes`.
#define SUPPORT_COMPARISONS_RA(VTRes, VTArg) \
// supported on the given argument value types `VTLhs` and `VTRhs` (for the left and right-hand-side
// arguments, respectively) and the given result value type `VTRes`.
#define SUPPORT_COMPARISONS_RLR(VTRes, VTLhs, VTRhs) \
/* string Comparisons operations. */ \
SUPPORT_RLR(LT, VTRes, VTArg, VTArg) \
SUPPORT_RLR(GT, VTRes, VTArg, VTArg)
#define SUPPORT_EQUALITY_RA(VTRes, VTArg) \
SUPPORT_RLR(LT, VTRes, VTLhs, VTRhs) \
SUPPORT_RLR(GT, VTRes, VTLhs, VTRhs)
#define SUPPORT_COMPARISONS_EQUAL_RLR(VTRes, VTLhs, VTRhs) \
/* string Comparisons operations. */ \
SUPPORT_RLR(EQ, VTRes, VTArg, VTArg) \
SUPPORT_RLR(NEQ, VTRes, VTArg, VTArg)
SUPPORT_RLR(LE, VTRes, VTLhs, VTRhs) \
SUPPORT_RLR(GE, VTRes, VTLhs, VTRhs)
#define SUPPORT_EQUALITY_RLR(VTRes, VTLhs, VTRhs) \
/* string Comparisons operations. */ \
SUPPORT_RLR(EQ, VTRes, VTLhs, VTRhs) \
SUPPORT_RLR(NEQ, VTRes, VTLhs, VTRhs)
#define SUPPORT_STRING_RLR(VTRes, VTLhs, VTRhs) \
/* string concatenation operations. */ \
/* Since the result may not fit in FixedStr16,*/ \
Expand Down Expand Up @@ -175,11 +179,15 @@ SUPPORT_NUMERIC_INT(uint64_t)
SUPPORT_NUMERIC_INT(uint32_t)
SUPPORT_NUMERIC_INT(uint8_t)
// Strings binary operations.
SUPPORT_EQUALITY_RA(int64_t, std::string)
SUPPORT_EQUALITY_RA(int64_t, FixedStr16)
SUPPORT_EQUALITY_RA(int64_t, const char *)
SUPPORT_COMPARISONS_RA(int64_t, std::string)
SUPPORT_COMPARISONS_RA(int64_t, FixedStr16)
SUPPORT_EQUALITY_RLR(int64_t, std::string, std::string)
SUPPORT_EQUALITY_RLR(int64_t, FixedStr16, FixedStr16)
SUPPORT_EQUALITY_RLR(int64_t, const char *, const char *)
SUPPORT_EQUALITY_RLR(int64_t, std::string, const char *)
SUPPORT_COMPARISONS_RLR(int64_t, std::string, std::string)
SUPPORT_COMPARISONS_RLR(int64_t, FixedStr16, FixedStr16)
SUPPORT_COMPARISONS_RLR(int64_t, std::string, const char *)
SUPPORT_COMPARISONS_EQUAL_RLR(int64_t, std::string, std::string)
SUPPORT_COMPARISONS_EQUAL_RLR(int64_t, std::string, const char *)
SUPPORT_STRING_RLR(std::string, std::string, std::string)
SUPPORT_STRING_RLR(std::string, FixedStr16, FixedStr16)
SUPPORT_STRING_RLR(const char *, const char *, const char *)
Expand All @@ -195,6 +203,7 @@ SUPPORT_STRING_RLR(std::string, std::string, const char *)
#undef SUPPORT_BITWISE
#undef SUPPORT_NUMERIC_FP
#undef SUPPORT_NUMERIC_INT
#undef SUPPORT_EQUALITY_RA
#undef SUPPORT_COMPARISONS_RA
#undef SUPPORT_EQUALITY_RLR
#undef SUPPORT_COMPARISONS_RLR
#undef SUPPORT_COMPARISONS_EQUAL_RLR
#undef SUPPORT_STRING_RLR
2 changes: 1 addition & 1 deletion src/runtime/local/kernels/EwBinaryObjSca.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ struct EwBinaryObjSca<DenseMatrix<VTRes>, DenseMatrix<VTLhs>, VTRhs> {
if (res == nullptr)
res = DataObjectFactory::create<DenseMatrix<VTRes>>(numRows, numCols, false);

const VTRes *valuesLhs = lhs->getValues();
const VTLhs *valuesLhs = lhs->getValues();
VTRes *valuesRes = res->getValues();

EwBinaryScaFuncPtr<VTRes, VTLhs, VTRhs> func = getEwBinaryScaFuncPtr<VTRes, VTLhs, VTRhs>(opCode);
Expand Down
27 changes: 17 additions & 10 deletions src/runtime/local/kernels/ExtractRow.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,23 @@ template <typename VTSel> struct ExtractRow<Frame, Frame, VTSel> {
throw std::out_of_range(errMsg.str());
}
for (size_t c = 0; c < numCols; c++) {
// We always copy in units of 8 bytes (uint64_t). If the
// actual element size is lower, the superfluous bytes will
// be overwritten by the next match. With this approach, we
// do not need to call memcpy for each element, nor
// interpret the types for a L/S of fitting size.
// TODO Don't multiply by elementSize, but left-shift by
// ld(elementSize).
*reinterpret_cast<uint64_t *>(resCols[c]) =
*reinterpret_cast<const uint64_t *>(argCols[c] + pos * elementSizes[c]);
resCols[c] += elementSizes[c];
if (schema[c] == ValueTypeCode::STR) {
// Handle std::string column
*reinterpret_cast<std::string *>(resCols[c]) =
*reinterpret_cast<const std::string *>(argCols[c] + pos * elementSizes[c]);
resCols[c] += elementSizes[c];
} else {
// We always copy in units of 8 bytes (uint64_t). If the
// actual element size is lower, the superfluous bytes will
// be overwritten by the next match. With this approach, we
// do not need to call memcpy for each element, nor
// interpret the types for a L/S of fitting size.
// TODO Don't multiply by elementSize, but left-shift by
// ld(elementSize).
*reinterpret_cast<uint64_t *>(resCols[c]) =
*reinterpret_cast<const uint64_t *>(argCols[c] + pos * elementSizes[c]);
resCols[c] += elementSizes[c];
}
}
}
res->shrinkNumRows(numRowsSel);
Expand Down
21 changes: 14 additions & 7 deletions src/runtime/local/kernels/FilterRow.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,20 @@ template <typename VTSel> struct FilterRow<Frame, Frame, VTSel> {
for (size_t r = 0; r < numRows; r++) {
if (valuesSel[r]) {
for (size_t c = 0; c < numCols; c++) {
// We always copy in units of 8 bytes (uint64_t). If the
// actual element size is lower, the superfluous bytes will
// be overwritten by the next match. With this approach, we
// do not need to call memcpy for each element, nor
// interpret the types for a L/S of fitting size.
*reinterpret_cast<uint64_t *>(resCols[c]) = *reinterpret_cast<const uint64_t *>(argCols[c]);
resCols[c] += elementSizes[c];
if (schema[c] == ValueTypeCode::STR) {
// Handle std::string column
*reinterpret_cast<std::string *>(resCols[c]) =
*reinterpret_cast<const std::string *>(argCols[c]); // Deep copy the string
resCols[c] += elementSizes[c];
} else {
// We always copy in units of 8 bytes (uint64_t). If the
// actual element size is lower, the superfluous bytes will
// be overwritten by the next match. With this approach, we
// do not need to call memcpy for each element, nor
// interpret the types for a L/S of fitting size.
*reinterpret_cast<uint64_t *>(resCols[c]) = *reinterpret_cast<const uint64_t *>(argCols[c]);
resCols[c] += elementSizes[c];
}
}
}
for (size_t c = 0; c < numCols; c++)
Expand Down
101 changes: 77 additions & 24 deletions src/runtime/local/kernels/Group.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,33 +58,85 @@ void group(DT *&res, const DT *arg, const char **keyCols, size_t numKeyCols, con
// Frame <- Frame
// ----------------------------------------------------------------------------

// TODO If possible, reuse the stringifyGroupEnum() from the DAPHNE compiler.
inline std::string myStringifyGroupEnum(mlir::daphne::GroupEnum val) {
    using mlir::daphne::GroupEnum;
    // Map each aggregation opcode to its textual name (used in error messages).
    if (val == GroupEnum::COUNT)
        return "COUNT";
    if (val == GroupEnum::SUM)
        return "SUM";
    if (val == GroupEnum::MIN)
        return "MIN";
    if (val == GroupEnum::MAX)
        return "MAX";
    if (val == GroupEnum::AVG)
        return "AVG";
    // Unreachable for valid enum values.
    throw std::runtime_error("invalid GroupEnum value");
}

// Returns the result of the aggregation function aggFunc over the (contiguous)
// memory between the begin and end pointer.
// Throws std::invalid_argument when the aggregation is not supported for the
// given value types (string arguments/results are only supported by the
// std::string specialization below).
template <typename VTRes, typename VTArg>
VTRes aggregate(const mlir::daphne::GroupEnum &aggFunc, const VTArg *begin, const VTArg *end) {
    using mlir::daphne::GroupEnum;

    // Whether the result/argument value type is std::string. Checked via
    // `if constexpr` so the numeric code paths are never instantiated for
    // string value types.
    constexpr bool strRes = std::is_same<VTRes, std::string>::value;
    constexpr bool strArg = std::is_same<VTArg, std::string>::value;

    // Single place to build the "unsupported for these value types" error,
    // instead of repeating the message construction in every case.
    auto unsupported = [&aggFunc]() {
        return std::invalid_argument(std::string("aggregate: ") + myStringifyGroupEnum(aggFunc) +
                                     std::string(" aggregation is not supported for these value types."));
    };

    switch (aggFunc) {
    case GroupEnum::COUNT:
        // COUNT only depends on the number of elements, so a string argument
        // is fine, but the count cannot be represented in a string result.
        if constexpr (strRes)
            throw unsupported();
        else
            return end - begin; // TODO: Do we need to check for Null elements here?
    case GroupEnum::SUM:
        if constexpr (strRes || strArg)
            throw unsupported();
        else
            return std::accumulate(begin, end, (VTRes)0);
    case GroupEnum::MIN:
        if constexpr (strRes || strArg)
            throw unsupported();
        else
            return *std::min_element(begin, end);
    case GroupEnum::MAX:
        if constexpr (strRes || strArg)
            throw unsupported();
        else
            return *std::max_element(begin, end);
    case GroupEnum::AVG:
        if constexpr (strRes || strArg)
            throw unsupported();
        else
            return std::accumulate(begin, end, (double)0) / (double)(end - begin);
    default:
        // Key columns are aggregated with an artificial opcode and are simply
        // passed through by taking the first element of the group.
        if constexpr (strRes || strArg)
            throw std::invalid_argument("aggregate: Unsupported aggregation operation for string types.");
        else
            return *begin;
    }
}

// Full specialization of aggregate() for string values: only MIN and MAX are
// meaningful on strings; any other aggFunc (in particular the artificial
// opcode used for key columns) passes the first element of the group through.
// Must be `inline`: an explicit full specialization of a function template is
// not implicitly inline, so defining it in this header without `inline` would
// cause multiple-definition (ODR) errors when the header is included in more
// than one translation unit.
template <>
inline std::string aggregate(const mlir::daphne::GroupEnum &aggFunc, const std::string *begin,
                             const std::string *end) {
    using mlir::daphne::GroupEnum;
    if (aggFunc == GroupEnum::MIN)
        return *std::min_element(begin, end);
    else if (aggFunc == GroupEnum::MAX)
        return *std::max_element(begin, end);
    else
        return *begin;
}

// struct which calls the aggregate() function (specified via aggFunc) on each
// duplicate group in the groups vector and on all implied single groups for a
// specified column (colIdx) of the argument frame (arg) and stores the result
Expand Down Expand Up @@ -117,22 +169,14 @@ template <typename VTRes, typename VTArg> struct ColumnGroupAgg {
}
};

inline std::string myStringifyGroupEnum(mlir::daphne::GroupEnum val) {
using mlir::daphne::GroupEnum;
switch (val) {
case GroupEnum::COUNT:
return "COUNT";
case GroupEnum::SUM:
return "SUM";
case GroupEnum::MIN:
return "MIN";
case GroupEnum::MAX:
return "MAX";
case GroupEnum::AVG:
return "AVG";
// Since DeduceValueTypeAndExecute cannot handle string values,
// we add a special ColumnGroupAgg wrapper for arguments with std::string values.
template <typename VTRes> struct ColumnGroupAggStringVTArg {
    // Forwards to ColumnGroupAgg with the argument value type fixed to
    // std::string, so that only the result value type needs to be deduced.
    static void apply(Frame *res, const Frame *arg, size_t colIdx, std::vector<std::pair<size_t, size_t>> *groups,
                      mlir::daphne::GroupEnum aggFunc, DCTX(ctx)) {
        ColumnGroupAgg<VTRes, std::string>::apply(res, arg, colIdx, groups, aggFunc, ctx);
    }
};

template <> struct Group<Frame> {
static void apply(Frame *&res, const Frame *arg, const char **keyCols, size_t numKeyCols, const char **aggCols,
Expand Down Expand Up @@ -270,9 +314,18 @@ template <> struct Group<Frame> {

// copying key columns and column-wise group aggregation
for (size_t i = 0; i < numColsRes; i++) {
DeduceValueTypeAndExecute<ColumnGroupAgg>::apply(
res->getSchema()[i], ordered->getSchema()[i], res, ordered, i, groups,
(i < numKeyCols) ? (GroupEnum)0 : aggFuncs[i - numKeyCols], ctx);
if (ordered->getSchema()[i] == ValueTypeCode::STR) {
if (res->getSchema()[i] == ValueTypeCode::STR)
ColumnGroupAgg<std::string, std::string>::apply(
res, ordered, i, groups, (i < numKeyCols) ? (GroupEnum)0 : aggFuncs[i - numKeyCols], ctx);
else
DeduceValueTypeAndExecute<ColumnGroupAggStringVTArg>::apply(
res->getSchema()[i], res, ordered, i, groups,
(i < numKeyCols) ? (GroupEnum)0 : aggFuncs[i - numKeyCols], ctx);
} else
DeduceValueTypeAndExecute<ColumnGroupAgg>::apply(
res->getSchema()[i], ordered->getSchema()[i], res, ordered, i, groups,
(i < numKeyCols) ? (GroupEnum)0 : aggFuncs[i - numKeyCols], ctx);
}
delete groups;
DataObjectFactory::destroy(ordered);
Expand Down
21 changes: 15 additions & 6 deletions src/runtime/local/kernels/Order.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,20 +197,29 @@ struct OrderFrame {

if (numColIdxs > 1) {
for (size_t i = 0; i < numColIdxs - 1; i++) {
DeduceValueTypeAndExecute<MultiColumnIDSort>::apply(arg->getSchema()[colIdxs[i]], arg, idx, groups,
ascending[i], colIdxs[i], ctx);
if (arg->getSchema()[colIdxs[i]] == ValueTypeCode::STR)
MultiColumnIDSort<std::string>::apply(arg, idx, groups, ascending[i], colIdxs[i], ctx);
else
DeduceValueTypeAndExecute<MultiColumnIDSort>::apply(arg->getSchema()[colIdxs[i]], arg, idx, groups,
ascending[i], colIdxs[i], ctx);
}
}

// efficient last sort pass OR finalizing the groups vector for further
// use
size_t colIdx = colIdxs[numColIdxs - 1];
if (groupsRes == nullptr) {
DeduceValueTypeAndExecute<ColumnIDSort>::apply(arg->getSchema()[colIdx], arg, idx, groups,
ascending[numColIdxs - 1], colIdx, ctx);
if (arg->getSchema()[colIdx] == ValueTypeCode::STR)
ColumnIDSort<std::string>::apply(arg, idx, groups, ascending[numColIdxs - 1], colIdx, ctx);
else
DeduceValueTypeAndExecute<ColumnIDSort>::apply(arg->getSchema()[colIdx], arg, idx, groups,
ascending[numColIdxs - 1], colIdx, ctx);
} else {
DeduceValueTypeAndExecute<MultiColumnIDSort>::apply(arg->getSchema()[colIdx], arg, idx, groups,
ascending[numColIdxs - 1], colIdx, ctx);
if (arg->getSchema()[colIdx] == ValueTypeCode::STR)
MultiColumnIDSort<std::string>::apply(arg, idx, groups, ascending[numColIdxs - 1], colIdx, ctx);
else
DeduceValueTypeAndExecute<MultiColumnIDSort>::apply(arg->getSchema()[colIdx], arg, idx, groups,
ascending[numColIdxs - 1], colIdx, ctx);
groupsRes->insert(groupsRes->end(), groups.begin(), groups.end());
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/runtime/local/kernels/kernels.json
Original file line number Diff line number Diff line change
Expand Up @@ -2197,6 +2197,11 @@
["DenseMatrix", "std::string"],
["DenseMatrix", "std::string"],
"const char *"
],
[
["DenseMatrix", "int64_t"],
["DenseMatrix", "std::string"],
"const char *"
]
],
"opCodes": [
Expand Down

0 comments on commit a374814

Please sign in to comment.