Skip to content

Commit

Permalink
Hash-based implementation for innerJoin-kernel (#926)
Browse files Browse the repository at this point in the history
- The existing innerJoin-kernel was very inefficient as it was a nested-loop-join and incurred a significant function call overhead per value. Furthermore, it suffered from several memory leaks related to the use of Frame::getColumn().
- This commit replaces the innerJoin-kernel implementation by a hash-based one.
- Adapted a unit test case of the innerJoin-kernel to cover the case where the join keys on the build side (rhs) of the join are not unique.
  • Loading branch information
saminbassiri authored Dec 12, 2024
1 parent a374814 commit bd1f04b
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 74 deletions.
166 changes: 101 additions & 65 deletions src/runtime/local/kernels/InnerJoin.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
/*
* Copyright 2024 The DAPHNE Consortium
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef SRC_RUNTIME_LOCAL_KERNELS_INNERJOIN_H
#define SRC_RUNTIME_LOCAL_KERNELS_INNERJOIN_H

Expand All @@ -10,7 +26,8 @@

#include <stdexcept>
#include <tuple>
#include <unordered_set>
#include <unordered_map>
#include <vector>

#include <cstddef>
#include <cstdint>
Expand All @@ -30,43 +47,85 @@ template <typename VTCol>
// Transfers a single cell from column fromCol / row fromRow of arg to column
// toCol / row toRow of res, but only if the run-time type code vtcType equals
// the type code of the template parameter VTCol; otherwise this is a no-op.
// The actual transfer is delegated to innerJoinSetValue<VTCol>() (defined
// outside this view - presumably a plain cell copy; TODO confirm).
//
// NOTE(review): this page renders the diff without +/- markers. The one-line
// innerJoinSetValue() call that passes the getColumn() results inline appears
// to be the pre-change form; the four lines after it appear to be its
// replacement, which keeps both column handles in locals so that they can be
// handed to DataObjectFactory::destroy() after use.
void innerJoinSet(ValueTypeCode vtcType, Frame *&res, const Frame *arg, const int64_t toRow, const int64_t toCol,
const int64_t fromRow, const int64_t fromCol, DCTX(ctx)) {
// Act only when the column's run-time value type matches VTCol.
if (vtcType == ValueTypeUtils::codeFor<VTCol>) {
innerJoinSetValue<VTCol>(res->getColumn<VTCol>(toCol), arg->getColumn<VTCol>(fromCol), toRow, fromRow, ctx);
// Source and target column handles, kept so they can be released below.
const DenseMatrix<VTCol> *colArg = arg->getColumn<VTCol>(fromCol);
DenseMatrix<VTCol> *colRes = res->getColumn<VTCol>(toCol);
innerJoinSetValue<VTCol>(colRes, colArg, toRow, fromRow, ctx);
// Release the handles returned by getColumn().
DataObjectFactory::destroy(colArg, colRes);
}
}

// NOTE(review): this span interleaves two definitions because the page shows
// the diff without +/- markers; they share the final closing brace.
//
// innerJoinEqual (appears to be the pre-change helper): reads the value at
// row targetLhs of argLhs and at row targetRhs of argRhs (column 0 of each)
// and returns whether the two compare equal with ==. res and ctx are unused
// in the lines shown.
template <typename VTLhs, typename VTRhs>
bool innerJoinEqual(
// results
Frame *&res,
// arguments
const DenseMatrix<VTLhs> *argLhs, const DenseMatrix<VTRhs> *argRhs, const int64_t targetLhs,
const int64_t targetRhs,
// context
DCTX(ctx)) {
const VTLhs l = argLhs->get(targetLhs, 0);
const VTRhs r = argRhs->get(targetRhs, 0);
return l == r;
// BuildHashRhs (appears to be the added function): creates a hash table for
// rhs. It maps every key value found in column rhsOn of rhs to the list of
// row indexes (in ascending order, 0 .. numRowRhs - 1) at which that value
// occurs, so duplicate keys keep all of their rows.
template <typename VTRhs>
std::unordered_map<VTRhs, std::vector<size_t>> BuildHashRhs(const Frame *rhs, const char *rhsOn,
const size_t numRowRhs) {
std::unordered_map<VTRhs, std::vector<size_t>> res;
// Handle to the join column of rhs; released again before returning.
const DenseMatrix<VTRhs> *col = rhs->getColumn<VTRhs>(rhsOn);
for (size_t row_idx_r = 0; row_idx_r < numRowRhs; row_idx_r++) {
VTRhs key = col->get(row_idx_r, 0);
// operator[] creates an empty row list the first time a key is seen.
res[key].push_back(row_idx_r);
}
DataObjectFactory::destroy(col);
return res;
}

// NOTE(review): this span interleaves two definitions because the page shows
// the diff without +/- markers; their parameter lists and bodies are mixed
// line by line and they share the final closing brace.
//
// innerJoinProbeIf (appears to be the pre-change helper): if the run-time
// type codes vtcLhs / vtcRhs match VTLhs / VTRhs, it returns the result of
// innerJoinEqual() on rows targetL / targetR of the lhsOn / rhsOn columns;
// otherwise it returns false.
//
// ProbeHashLhs (appears to be the added function): probe phase of the hash
// join. For each of the first numRowLhs rows of lhs, the key in column lhsOn
// is looked up in hashRhsIndex (taken by value). For every rhs row index
// stored under that key, one row is written to res: first the numColLhs
// columns of the lhs row, then the numColRhs columns of the rhs row. Each
// cell is offered to innerJoinSet<std::string>, <int64_t> and <double>; the
// type check inside innerJoinSet makes only the call matching schema[column]
// take effect. Returns the number of rows written to res.
template <typename VTLhs, typename VTRhs>
bool innerJoinProbeIf(
// value type known only at run-time
ValueTypeCode vtcLhs, ValueTypeCode vtcRhs,
// results
Frame *&res,
template <typename VT>
int64_t ProbeHashLhs(
// results and results schema
Frame *&res, ValueTypeCode *schema,
// input frames
const Frame *lhs, const Frame *rhs,
// input column names
const char *lhsOn, const char *rhsOn,
// input rows
const int64_t targetL, const int64_t targetR,
const char *lhsOn,
// num columns
const size_t numColRhs, const size_t numColLhs,
// context
DCTX(ctx)) {
if (vtcLhs == ValueTypeUtils::codeFor<VTLhs> && vtcRhs == ValueTypeUtils::codeFor<VTRhs>) {
return innerJoinEqual<VTLhs, VTRhs>(res, lhs->getColumn<VTLhs>(lhsOn), rhs->getColumn<VTRhs>(rhsOn), targetL,
targetR, ctx);
DCTX(ctx),
// hash table over the rhs join column: key value -> rhs row indexes
std::unordered_map<VT, std::vector<size_t>> hashRhsIndex,
// number of Lhs rows
const size_t numRowLhs) {
// Next row of res to write; doubles as the count of rows written so far.
int64_t row_idx_res = 0;
int64_t col_idx_res = 0;
// Handle to the join column of lhs; released again before returning.
auto lhsFKCol = lhs->getColumn<VT>(lhsOn);
for (size_t row_idx_l = 0; row_idx_l < numRowLhs; row_idx_l++) {
auto key = lhsFKCol->get(row_idx_l, 0);
auto it = hashRhsIndex.find(key);

// Keys without a match on the rhs produce no result row.
if (it != hashRhsIndex.end()) {
// One result row per matching rhs row (the rhs keys need not be unique).
for (size_t row_idx_r : it->second) {
col_idx_res = 0;

// Populate result row from lhs columns
for (size_t idx_c = 0; idx_c < numColLhs; idx_c++) {
innerJoinSet<std::string>(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c,
ctx);
innerJoinSet<int64_t>(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c,
ctx);
innerJoinSet<double>(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c,
ctx);

col_idx_res++;
}

// Populate result row from rhs columns
for (size_t idx_c = 0; idx_c < numColRhs; idx_c++) {
innerJoinSet<std::string>(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c,
ctx);
innerJoinSet<int64_t>(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c,
ctx);

innerJoinSet<double>(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c,
ctx);

col_idx_res++;
}

row_idx_res++;
}
}
}
return false;
DataObjectFactory::destroy(lhsFKCol);
return row_idx_res;
}

// ****************************************************************************
Expand All @@ -84,10 +143,8 @@ inline void innerJoin(
int64_t numRowRes,
// context
DCTX(ctx)) {

// Find out the value types of the columns to process.
ValueTypeCode vtcLhsOn = lhs->getColumnType(lhsOn);
ValueTypeCode vtcRhsOn = rhs->getColumnType(rhsOn);

// Perhaps check if res already allocated.
const size_t numRowRhs = rhs->getNumRows();
Expand All @@ -100,55 +157,34 @@ inline void innerJoin(
const std::string *oldlabels_r = rhs->getLabels();

int64_t col_idx_res = 0;
int64_t row_idx_res = 0;
int64_t row_idx_res;

// Set up schema and labels
ValueTypeCode schema[totalCols];
std::string newlabels[totalCols];

// Setting Schema and Labels
for (size_t col_idx_l = 0; col_idx_l < numColLhs; col_idx_l++) {
schema[col_idx_res] = lhs->getColumnType(col_idx_l);
newlabels[col_idx_res] = oldlabels_l[col_idx_l];
col_idx_res++;
newlabels[col_idx_res++] = oldlabels_l[col_idx_l];
}
for (size_t col_idx_r = 0; col_idx_r < numColRhs; col_idx_r++) {
schema[col_idx_res] = rhs->getColumnType(col_idx_r);
newlabels[col_idx_res] = oldlabels_r[col_idx_r];
col_idx_res++;
newlabels[col_idx_res++] = oldlabels_r[col_idx_r];
}

// Creating Result Frame
// Initialize result frame with an estimate
res = DataObjectFactory::create<Frame>(totalRows, totalCols, schema, newlabels, false);

for (size_t row_idx_l = 0; row_idx_l < numRowLhs; row_idx_l++) {
for (size_t row_idx_r = 0; row_idx_r < numRowRhs; row_idx_r++) {
col_idx_res = 0;
// PROBE ROWS
bool hit = false;
hit = hit || innerJoinProbeIf<int64_t, int64_t>(vtcLhsOn, vtcRhsOn, res, lhs, rhs, lhsOn, rhsOn, row_idx_l,
row_idx_r, ctx);
hit = hit || innerJoinProbeIf<double, double>(vtcLhsOn, vtcRhsOn, res, lhs, rhs, lhsOn, rhsOn, row_idx_l,
row_idx_r, ctx);
if (hit) {
for (size_t idx_c = 0; idx_c < numColLhs; idx_c++) {
innerJoinSet<int64_t>(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c,
ctx);
innerJoinSet<double>(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c,
ctx);
col_idx_res++;
}
for (size_t idx_c = 0; idx_c < numColRhs; idx_c++) {
innerJoinSet<int64_t>(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c,
ctx);

innerJoinSet<double>(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c,
ctx);
col_idx_res++;
}
row_idx_res++;
}
}
// Build the hash table on the rhs join column and probe it with the lhs rows
if (vtcLhsOn == ValueTypeCode::STR) {
row_idx_res = ProbeHashLhs<std::string>(res, schema, lhs, rhs, lhsOn, numColRhs, numColLhs, ctx,
BuildHashRhs<std::string>(rhs, rhsOn, numRowRhs), numRowLhs);
} else {
row_idx_res = ProbeHashLhs<int64_t>(res, schema, lhs, rhs, lhsOn, numColRhs, numColLhs, ctx,
BuildHashRhs<int64_t>(rhs, rhsOn, numRowRhs), numRowLhs);
}
// Shrink result frame to actual size
res->shrinkNumRows(row_idx_res);
}

#endif // SRC_RUNTIME_LOCAL_KERNELS_INNERJOIN_H
18 changes: 9 additions & 9 deletions test/runtime/local/kernels/InnerJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ TEST_CASE("InnerJoin", TAG_KERNELS) {
std::string lhsLabels[] = {"a", "b"};
auto lhs = DataObjectFactory::create<Frame>(lhsCols, lhsLabels);

auto rhsC0 = genGivenVals<DenseMatrix<int64_t>>(3, {1, 4, 5});
auto rhsC1 = genGivenVals<DenseMatrix<int64_t>>(3, {-1, -4, -5});
auto rhsC2 = genGivenVals<DenseMatrix<double>>(3, {0.1, 0.2, 0.3});
auto rhsC0 = genGivenVals<DenseMatrix<int64_t>>(4, {1, 4, 5, 4});
auto rhsC1 = genGivenVals<DenseMatrix<int64_t>>(4, {-1, -4, -5, -6});
auto rhsC2 = genGivenVals<DenseMatrix<double>>(4, {0.1, 0.2, 0.3, 0.4});
std::vector<Structure *> rhsCols = {rhsC0, rhsC1, rhsC2};
std::string rhsLabels[] = {"c", "d", "e"};
auto rhs = DataObjectFactory::create<Frame>(rhsCols, rhsLabels);
Expand All @@ -49,7 +49,7 @@ TEST_CASE("InnerJoin", TAG_KERNELS) {
innerJoin(res, lhs, rhs, "a", "c", -1, nullptr);

// Check the meta data.
CHECK(res->getNumRows() == 2);
CHECK(res->getNumRows() == 3);
CHECK(res->getNumCols() == 5);

CHECK(res->getColumnType(0) == ValueTypeCode::SI64);
Expand All @@ -64,11 +64,11 @@ TEST_CASE("InnerJoin", TAG_KERNELS) {
CHECK(res->getLabels()[3] == "d");
CHECK(res->getLabels()[4] == "e");

auto resC0Exp = genGivenVals<DenseMatrix<int64_t>>(2, {1, 4});
auto resC1Exp = genGivenVals<DenseMatrix<double>>(2, {11.0, 44.0});
auto resC2Exp = genGivenVals<DenseMatrix<int64_t>>(2, {1, 4});
auto resC3Exp = genGivenVals<DenseMatrix<int64_t>>(2, {-1, -4});
auto resC4Exp = genGivenVals<DenseMatrix<double>>(2, {0.1, 0.2});
auto resC0Exp = genGivenVals<DenseMatrix<int64_t>>(3, {1, 4, 4});
auto resC1Exp = genGivenVals<DenseMatrix<double>>(3, {11.0, 44.0, 44.0});
auto resC2Exp = genGivenVals<DenseMatrix<int64_t>>(3, {1, 4, 4});
auto resC3Exp = genGivenVals<DenseMatrix<int64_t>>(3, {-1, -4, -6});
auto resC4Exp = genGivenVals<DenseMatrix<double>>(3, {0.1, 0.2, 0.4});

CHECK(*(res->getColumn<int64_t>(0)) == *resC0Exp);
CHECK(*(res->getColumn<double>(1)) == *resC1Exp);
Expand Down

0 comments on commit bd1f04b

Please sign in to comment.