diff --git a/src/runtime/local/kernels/InnerJoin.h b/src/runtime/local/kernels/InnerJoin.h index 55451a4ba..e0c8be1a3 100644 --- a/src/runtime/local/kernels/InnerJoin.h +++ b/src/runtime/local/kernels/InnerJoin.h @@ -1,3 +1,19 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #ifndef SRC_RUNTIME_LOCAL_KERNELS_INNERJOIN_H #define SRC_RUNTIME_LOCAL_KERNELS_INNERJOIN_H @@ -10,7 +26,8 @@ #include #include -#include +#include +#include #include #include @@ -30,43 +47,85 @@ template void innerJoinSet(ValueTypeCode vtcType, Frame *&res, const Frame *arg, const int64_t toRow, const int64_t toCol, const int64_t fromRow, const int64_t fromCol, DCTX(ctx)) { if (vtcType == ValueTypeUtils::codeFor) { - innerJoinSetValue(res->getColumn(toCol), arg->getColumn(fromCol), toRow, fromRow, ctx); + const DenseMatrix *colArg = arg->getColumn(fromCol); + DenseMatrix *colRes = res->getColumn(toCol); + innerJoinSetValue(colRes, colArg, toRow, fromRow, ctx); + DataObjectFactory::destroy(colArg, colRes); } } -template -bool innerJoinEqual( - // results - Frame *&res, - // arguments - const DenseMatrix *argLhs, const DenseMatrix *argRhs, const int64_t targetLhs, - const int64_t targetRhs, - // context - DCTX(ctx)) { - const VTLhs l = argLhs->get(targetLhs, 0); - const VTRhs r = argRhs->get(targetRhs, 0); - return l == r; +// Create a hash table for rhs +template +std::unordered_map> BuildHashRhs(const Frame *rhs, const char *rhsOn, 
+ const size_t numRowRhs) { + std::unordered_map> res; + const DenseMatrix *col = rhs->getColumn(rhsOn); + for (size_t row_idx_r = 0; row_idx_r < numRowRhs; row_idx_r++) { + VTRhs key = col->get(row_idx_r, 0); + res[key].push_back(row_idx_r); + } + DataObjectFactory::destroy(col); + return res; } -template -bool innerJoinProbeIf( - // value type known only at run-time - ValueTypeCode vtcLhs, ValueTypeCode vtcRhs, - // results - Frame *&res, +template +int64_t ProbeHashLhs( + // results and results schema + Frame *&res, ValueTypeCode *schema, // input frames const Frame *lhs, const Frame *rhs, // input column names - const char *lhsOn, const char *rhsOn, - // input rows - const int64_t targetL, const int64_t targetR, + const char *lhsOn, + // num columns + const size_t numColRhs, const size_t numColLhs, // context - DCTX(ctx)) { - if (vtcLhs == ValueTypeUtils::codeFor && vtcRhs == ValueTypeUtils::codeFor) { - return innerJoinEqual(res, lhs->getColumn(lhsOn), rhs->getColumn(rhsOn), targetL, - targetR, ctx); + DCTX(ctx), + // hashed map of Rhs + std::unordered_map> hashRhsIndex, + // Lhs rows + const size_t numRowLhs) { + int64_t row_idx_res = 0; + int64_t col_idx_res = 0; + auto lhsFKCol = lhs->getColumn(lhsOn); + for (size_t row_idx_l = 0; row_idx_l < numRowLhs; row_idx_l++) { + auto key = lhsFKCol->get(row_idx_l, 0); + auto it = hashRhsIndex.find(key); + + if (it != hashRhsIndex.end()) { + for (size_t row_idx_r : it->second) { + col_idx_res = 0; + + // Populate result row from lhs columns + for (size_t idx_c = 0; idx_c < numColLhs; idx_c++) { + innerJoinSet(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c, + ctx); + innerJoinSet(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c, + ctx); + innerJoinSet(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c, + ctx); + + col_idx_res++; + } + + // Populate result row from rhs columns + for (size_t idx_c = 0; idx_c < numColRhs; idx_c++) { + 
innerJoinSet(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c, + ctx); + innerJoinSet(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c, + ctx); + + innerJoinSet(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c, + ctx); + + col_idx_res++; + } + + row_idx_res++; + } + } } - return false; + DataObjectFactory::destroy(lhsFKCol); + return row_idx_res; } // **************************************************************************** @@ -84,10 +143,8 @@ inline void innerJoin( int64_t numRowRes, // context DCTX(ctx)) { - // Find out the value types of the columns to process. ValueTypeCode vtcLhsOn = lhs->getColumnType(lhsOn); - ValueTypeCode vtcRhsOn = rhs->getColumnType(rhsOn); // Perhaps check if res already allocated. const size_t numRowRhs = rhs->getNumRows(); @@ -100,55 +157,34 @@ inline void innerJoin( const std::string *oldlabels_r = rhs->getLabels(); int64_t col_idx_res = 0; - int64_t row_idx_res = 0; + int64_t row_idx_res; + // Set up schema and labels ValueTypeCode schema[totalCols]; std::string newlabels[totalCols]; - // Setting Schema and Labels for (size_t col_idx_l = 0; col_idx_l < numColLhs; col_idx_l++) { schema[col_idx_res] = lhs->getColumnType(col_idx_l); - newlabels[col_idx_res] = oldlabels_l[col_idx_l]; - col_idx_res++; + newlabels[col_idx_res++] = oldlabels_l[col_idx_l]; } for (size_t col_idx_r = 0; col_idx_r < numColRhs; col_idx_r++) { schema[col_idx_res] = rhs->getColumnType(col_idx_r); - newlabels[col_idx_res] = oldlabels_r[col_idx_r]; - col_idx_res++; + newlabels[col_idx_res++] = oldlabels_r[col_idx_r]; } - // Creating Result Frame + // Initialize result frame with an estimate res = DataObjectFactory::create(totalRows, totalCols, schema, newlabels, false); - for (size_t row_idx_l = 0; row_idx_l < numRowLhs; row_idx_l++) { - for (size_t row_idx_r = 0; row_idx_r < numRowRhs; row_idx_r++) { - col_idx_res = 0; - // PROBE ROWS - bool hit = false; - hit = hit || 
innerJoinProbeIf(vtcLhsOn, vtcRhsOn, res, lhs, rhs, lhsOn, rhsOn, row_idx_l, - row_idx_r, ctx); - hit = hit || innerJoinProbeIf(vtcLhsOn, vtcRhsOn, res, lhs, rhs, lhsOn, rhsOn, row_idx_l, - row_idx_r, ctx); - if (hit) { - for (size_t idx_c = 0; idx_c < numColLhs; idx_c++) { - innerJoinSet(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c, - ctx); - innerJoinSet(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c, - ctx); - col_idx_res++; - } - for (size_t idx_c = 0; idx_c < numColRhs; idx_c++) { - innerJoinSet(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c, - ctx); - - innerJoinSet(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c, - ctx); - col_idx_res++; - } - row_idx_res++; - } - } + // Build hash table and probe left table + if (vtcLhsOn == ValueTypeCode::STR) { + row_idx_res = ProbeHashLhs(res, schema, lhs, rhs, lhsOn, numColRhs, numColLhs, ctx, + BuildHashRhs(rhs, rhsOn, numRowRhs), numRowLhs); + } else { + row_idx_res = ProbeHashLhs(res, schema, lhs, rhs, lhsOn, numColRhs, numColLhs, ctx, + BuildHashRhs(rhs, rhsOn, numRowRhs), numRowLhs); } + // Shrink result frame to actual size res->shrinkNumRows(row_idx_res); } + #endif // SRC_RUNTIME_LOCAL_KERNELS_INNERJOIN_H diff --git a/test/runtime/local/kernels/InnerJoinTest.cpp b/test/runtime/local/kernels/InnerJoinTest.cpp index 6e4c264c1..c2ee660f0 100644 --- a/test/runtime/local/kernels/InnerJoinTest.cpp +++ b/test/runtime/local/kernels/InnerJoinTest.cpp @@ -38,9 +38,9 @@ TEST_CASE("InnerJoin", TAG_KERNELS) { std::string lhsLabels[] = {"a", "b"}; auto lhs = DataObjectFactory::create(lhsCols, lhsLabels); - auto rhsC0 = genGivenVals>(3, {1, 4, 5}); - auto rhsC1 = genGivenVals>(3, {-1, -4, -5}); - auto rhsC2 = genGivenVals>(3, {0.1, 0.2, 0.3}); + auto rhsC0 = genGivenVals>(4, {1, 4, 5, 4}); + auto rhsC1 = genGivenVals>(4, {-1, -4, -5, -6}); + auto rhsC2 = genGivenVals>(4, {0.1, 0.2, 0.3, 0.4}); std::vector rhsCols = 
{rhsC0, rhsC1, rhsC2}; std::string rhsLabels[] = {"c", "d", "e"}; auto rhs = DataObjectFactory::create(rhsCols, rhsLabels); @@ -49,7 +49,7 @@ TEST_CASE("InnerJoin", TAG_KERNELS) { innerJoin(res, lhs, rhs, "a", "c", -1, nullptr); // Check the meta data. - CHECK(res->getNumRows() == 2); + CHECK(res->getNumRows() == 3); CHECK(res->getNumCols() == 5); CHECK(res->getColumnType(0) == ValueTypeCode::SI64); @@ -64,11 +64,11 @@ TEST_CASE("InnerJoin", TAG_KERNELS) { CHECK(res->getLabels()[3] == "d"); CHECK(res->getLabels()[4] == "e"); - auto resC0Exp = genGivenVals>(2, {1, 4}); - auto resC1Exp = genGivenVals>(2, {11.0, 44.0}); - auto resC2Exp = genGivenVals>(2, {1, 4}); - auto resC3Exp = genGivenVals>(2, {-1, -4}); - auto resC4Exp = genGivenVals>(2, {0.1, 0.2}); + auto resC0Exp = genGivenVals>(3, {1, 4, 4}); + auto resC1Exp = genGivenVals>(3, {11.0, 44.0, 44.0}); + auto resC2Exp = genGivenVals>(3, {1, 4, 4}); + auto resC3Exp = genGivenVals>(3, {-1, -4, -6}); + auto resC4Exp = genGivenVals>(3, {0.1, 0.2, 0.4}); CHECK(*(res->getColumn(0)) == *resC0Exp); CHECK(*(res->getColumn(1)) == *resC1Exp);