Skip to content

Commit

Permalink
Add tests for SliceCol and SliceRow kernel functions for Dense Matric…
Browse files Browse the repository at this point in the history
…es of Strings.
  • Loading branch information
saminbassiri committed Oct 11, 2024
1 parent 49a2cc0 commit f3b7504
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
17 changes: 17 additions & 0 deletions test/runtime/local/kernels/SliceColTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,23 @@ TEMPLATE_PRODUCT_TEST_CASE("SliceCol", TAG_KERNELS, (DenseMatrix, Matrix), (doub
DataObjectFactory::destroy(arg, exp, res);
}

TEMPLATE_PRODUCT_TEST_CASE("SliceCol", TAG_KERNELS, (DenseMatrix), (ALL_STRING_VALUE_TYPES)) {
using DT = TestType;
using VT = typename DT::VT;

std::vector<VT> vals = {VT("a"), VT(""), VT("1"), VT("abc"), VT("e"), VT("j"), VT("abc"), VT("abcd"),
VT("ab"), VT("a"), VT("f"), VT("k"), VT("ABC"), VT("34ab"), VT("ac"), VT("b"),
VT("g"), VT("l"), VT("cd"), VT(" "), VT("ad"), VT("c"), VT("h"), VT(" ")};
std::vector<VT> valsExp = {VT(""), VT("1"), VT("abcd"), VT("ab"), VT("34ab"), VT("ac"), VT(" "), VT("ad")};
auto arg = genGivenVals<DT>(4, vals);
auto exp = genGivenVals<DT>(4, valsExp);
DT *res = nullptr;
sliceCol(res, arg, 1, 3, nullptr);
CHECK(*res == *exp);

DataObjectFactory::destroy(arg, exp, res);
}

TEMPLATE_PRODUCT_TEST_CASE("SliceCol - check throws", TAG_KERNELS, (DenseMatrix, Matrix), (double, int64_t)) {
using DT = TestType;
using VT = typename DT::VT;
Expand Down
18 changes: 18 additions & 0 deletions test/runtime/local/kernels/SliceRowTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,24 @@ TEMPLATE_PRODUCT_TEST_CASE("SliceRow", TAG_KERNELS, (DenseMatrix, Matrix), (doub
DataObjectFactory::destroy(arg, exp, res);
}

TEMPLATE_PRODUCT_TEST_CASE("SliceRow", TAG_KERNELS, (DenseMatrix), (ALL_STRING_VALUE_TYPES)) {
using DT = TestType;
using VT = typename DT::VT;

std::vector<VT> vals = {VT("a"), VT(""), VT("1"), VT("abc"), VT("e"), VT("j"), VT("abc"), VT("abcd"),
VT(" "), VT("a"), VT("f"), VT("k"), VT("ABC"), VT("34ab"), VT("ac"), VT("b"),
VT("g"), VT("l"), VT("cd"), VT(" "), VT("ad"), VT("c"), VT("h"), VT(" ")};
std::vector<VT> valsExp = {VT("abc"), VT("abcd"), VT(" "), VT("a"), VT("f"), VT("k"),
VT("ABC"), VT("34ab"), VT("ac"), VT("b"), VT("g"), VT("l")};
auto arg = genGivenVals<DT>(4, vals);
auto exp = genGivenVals<DT>(2, valsExp);
DT *res = nullptr;
sliceRow(res, arg, 1, 3, nullptr);
CHECK(*res == *exp);

DataObjectFactory::destroy(arg, exp, res);
}

TEMPLATE_PRODUCT_TEST_CASE("SliceRow - check throws", TAG_KERNELS, (DenseMatrix, Matrix), (double, int64_t)) {
using DT = TestType;
using VT = typename DT::VT;
Expand Down

0 comments on commit f3b7504

Please sign in to comment.