Skip to content

Commit

Permalink
Add complex type support for map Spark function (#7560)
Browse files Browse the repository at this point in the history
Summary:
1. Add complex type support for map Spark function.
2. Remove map() function from spark function doc.

Fixes: #7559

Pull Request resolved: #7560

Reviewed By: xiaoxmeng, pedroerp

Differential Revision: D51527569

Pulled By: mbasmanova

fbshipit-source-id: d37c5e5fef9fcc9223dd7e340752cefdf05cd4fe
  • Loading branch information
zhli1142015 authored and facebook-github-bot committed Nov 23, 2023
1 parent 46798f9 commit 2b97fed
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 103 deletions.
13 changes: 4 additions & 9 deletions velox/docs/functions/spark/map.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,13 @@ Map Functions
Returns value for given ``key``, or ``NULL`` if the key is not contained in the map.

.. spark:function:: map() -> map(unknown, unknown)
.. spark:function:: map(K, V, K, V, ...) -> map(K,V)
Returns an empty map. ::
Returns a map created using the given key/value pairs. Keys are not allowed to be null. ::

SELECT map(); -- {}
SELECT map(1, 2, 3, 4); -- {1 -> 2, 3 -> 4}

.. spark:function:: map(array(K), array(V)) -> map(K,V)
:noindex:

Returns a map created using the given key/value arrays. Duplicate map keys will cause an exception. ::

SELECT map(ARRAY[1,3], ARRAY[2,4]); -- {1 -> 2, 3 -> 4}
SELECT map(array(1, 2), array(3, 4)); -- {[1, 2] -> [3, 4]}

.. spark:function:: map_filter(map(K,V), func) -> map(K,V)
Expand Down
85 changes: 23 additions & 62 deletions velox/functions/sparksql/Map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,50 +32,27 @@
namespace facebook::velox::functions::sparksql {
namespace {

template <TypeKind kind>
void setKeysResultTyped(
void setKeysAndValuesResult(
vector_size_t mapSize,
std::vector<VectorPtr>& args,
const VectorPtr& keysResult,
exec::EvalCtx& context,
const SelectivityVector& rows) {
using T = typename KindToFlatVector<kind>::WrapperType;
auto flatVector = keysResult->asFlatVector<T>();
flatVector->resize(rows.end() * mapSize);

exec::LocalDecodedVector decoded(context);
for (vector_size_t i = 0; i < mapSize; i++) {
decoded.get()->decode(*args[i * 2], rows);
// For efficiency traverse one arg at the time
rows.applyToSelected([&](vector_size_t row) {
VELOX_CHECK(!decoded->isNullAt(row), "Cannot use null as map key!");
flatVector->set(row * mapSize + i, decoded->valueAt<T>(row));
});
}
}

template <TypeKind kind>
void setValuesResultTyped(
vector_size_t mapSize,
std::vector<VectorPtr>& args,
const VectorPtr& valuesResult,
exec::EvalCtx& context,
const SelectivityVector& rows) {
using T = typename KindToFlatVector<kind>::WrapperType;
auto flatVector = valuesResult->asFlatVector<T>();
flatVector->resize(rows.end() * mapSize);

exec::LocalDecodedVector decoded(context);
SelectivityVector targetRows(keysResult->size(), false);
std::vector<vector_size_t> toSourceRow(keysResult->size());
for (vector_size_t i = 0; i < mapSize; i++) {
decoded.get()->decode(*args[i * 2 + 1], rows);

rows.applyToSelected([&](vector_size_t row) {
if (decoded->isNullAt(row)) {
flatVector->setNull(row * mapSize + i, true);
} else {
flatVector->set(row * mapSize + i, decoded->valueAt<T>(row));
}
decoded.get()->decode(*args[i * 2], rows);
context.applyToSelectedNoThrow(rows, [&](vector_size_t row) {
VELOX_USER_CHECK(!decoded->isNullAt(row), "Cannot use null as map key!");
targetRows.setValid(row * mapSize + i, true);
toSourceRow[row * mapSize + i] = row;
});
targetRows.updateBounds();
keysResult->copy(args[i * 2].get(), targetRows, toSourceRow.data());
valuesResult->copy(args[i * 2 + 1].get(), targetRows, toSourceRow.data());
targetRows.clearAll();
}
}

Expand All @@ -91,7 +68,7 @@ class MapFunction : public exec::VectorFunction {
const TypePtr& /*outputType*/,
exec::EvalCtx& context,
VectorPtr& result) const override {
VELOX_CHECK(
VELOX_USER_CHECK(
args.size() >= 2 && args.size() % 2 == 0,
"Map function must take an even number of arguments");
auto mapSize = args.size() / 2;
Expand All @@ -101,14 +78,12 @@ class MapFunction : public exec::VectorFunction {

// Check key and value types
for (auto i = 0; i < mapSize; i++) {
VELOX_CHECK_EQ(
args[i * 2]->type(),
keyType,
"All the key arguments in Map function must be the same!");
VELOX_CHECK_EQ(
args[i * 2 + 1]->type(),
valueType,
VELOX_USER_CHECK(
args[i * 2]->type()->equivalent(*keyType),
"All the key arguments in Map function must be the same!");
VELOX_USER_CHECK(
args[i * 2 + 1]->type()->equivalent(*valueType),
"All the value arguments in Map function must be the same!");
}

// Initializing input
Expand All @@ -130,26 +105,12 @@ class MapFunction : public exec::VectorFunction {
// Setting keys and value elements
auto keysResult = mapResult->mapKeys();
auto valuesResult = mapResult->mapValues();
keysResult->resize(rows.end() * mapSize);
valuesResult->resize(rows.end() * mapSize);

VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
setKeysResultTyped,
keyType->kind(),
mapSize,
args,
keysResult,
context,
rows);
const auto resultSize = rows.end() * mapSize;
keysResult->resize(resultSize);
valuesResult->resize(resultSize);

VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
setValuesResultTyped,
valueType->kind(),
mapSize,
args,
valuesResult,
context,
rows);
setKeysAndValuesResult(
mapSize, args, keysResult, valuesResult, context, rows);
}

static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
Expand Down
170 changes: 138 additions & 32 deletions velox/functions/sparksql/tests/MapTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
* limitations under the License.
*/

#include "velox/common/base/VeloxException.h"
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"
#include "velox/type/Variant.h"

#include <stdint.h>

Expand All @@ -25,19 +28,20 @@ namespace {
class MapTest : public SparkFunctionBaseTest {
protected:
template <typename K = int64_t, typename V = std::string>
void mapSimple(
void testMap(
const std::string& expression,
const std::vector<VectorPtr>& parameters,
bool expectException = false,
const VectorPtr& expected = nullptr) {
if (expectException) {
ASSERT_THROW(
evaluate<MapVector>(expression, makeRowVector(parameters)),
std::exception);
} else {
auto result = evaluate<MapVector>(expression, makeRowVector(parameters));
::facebook::velox::test::assertEqualVectors(result, expected);
}
const VectorPtr& expected) {
auto result = evaluate<MapVector>(expression, makeRowVector(parameters));
::facebook::velox::test::assertEqualVectors(expected, result);
}

/// Evaluates 'expression' over a row vector built from 'parameters' and
/// asserts, via VELOX_ASSERT_USER_THROW, that the evaluation throws with
/// 'errorMsg'. The exact exception type and message-matching rules are those
/// of VELOX_ASSERT_USER_THROW.
///
/// @param expression Expression text handed to evaluate<MapVector>().
/// @param parameters Input vectors wrapped into a single row vector; the
/// tests in this file refer to them as c0, c1, ... in 'expression'.
/// @param errorMsg Expected error text forwarded to VELOX_ASSERT_USER_THROW.
/// Taken by const reference to avoid copying the string on every call.
void testMapFails(
    const std::string& expression,
    const std::vector<VectorPtr>& parameters,
    const std::string& errorMsg) {
  VELOX_ASSERT_USER_THROW(
      evaluate<MapVector>(expression, makeRowVector(parameters)), errorMsg);
}
};

Expand All @@ -46,7 +50,7 @@ TEST_F(MapTest, Basics) {
auto inputVector2 = makeNullableFlatVector<int64_t>({4, 5, 6});
auto mapVector =
makeMapVector<int64_t, int64_t>({{{1, 4}}, {{2, 5}}, {{3, 6}}});
mapSimple("map(c0, c1)", {inputVector1, inputVector2}, false, mapVector);
testMap("map(c0, c1)", {inputVector1, inputVector2}, mapVector);
}

TEST_F(MapTest, Nulls) {
Expand All @@ -55,55 +59,157 @@ TEST_F(MapTest, Nulls) {
makeNullableFlatVector<int64_t>({std::nullopt, 5, std::nullopt});
auto mapVector = makeMapVector<int64_t, int64_t>(
{{{1, std::nullopt}}, {{2, 5}}, {{3, std::nullopt}}});
mapSimple("map(c0, c1)", {inputVector1, inputVector2}, false, mapVector);
testMap("map(c0, c1)", {inputVector1, inputVector2}, mapVector);
}

TEST_F(MapTest, DifferentTypes) {
TEST_F(MapTest, differentTypes) {
auto inputVector1 = makeNullableFlatVector<int64_t>({1, 2, 3});
auto inputVector2 = makeNullableFlatVector<double>({4.0, 5.0, 6.0});
auto mapVector =
makeMapVector<int64_t, double>({{{1, 4.0}}, {{2, 5.0}}, {{3, 6.0}}});
mapSimple("map(c0, c1)", {inputVector1, inputVector2}, false, mapVector);
testMap("map(c0, c1)", {inputVector1, inputVector2}, mapVector);
}

TEST_F(MapTest, boolType) {
auto inputVector1 = makeNullableFlatVector<bool>({1, 1, 0});
auto inputVector2 = makeNullableFlatVector<bool>({0, 0, 1});
auto mapVector = makeMapVector<bool, bool>({{{1, 0}}, {{1, 0}}, {{0, 1}}});
mapSimple("map(c0, c1)", {inputVector1, inputVector2}, false, mapVector);
testMap("map(c0, c1)", {inputVector1, inputVector2}, mapVector);
}

TEST_F(MapTest, Wide) {
TEST_F(MapTest, wide) {
auto inputVector1 = makeNullableFlatVector<int64_t>({1, 2, 3});
auto inputVector2 = makeNullableFlatVector<double>({4.0, 5.0, 6.0});
auto inputVector11 = makeNullableFlatVector<int64_t>({10, 20, 30});
auto inputVector22 = makeNullableFlatVector<double>({4.1, 5.1, 6.1});
auto mapVector = makeMapVector<int64_t, double>(
{{{1, 4.0}, {10, 4.1}}, {{2, 5.0}, {20, 5.1}}, {{3, 6.0}, {30, 6.1}}});
mapSimple(
testMap(
"map(c0, c1, c2, c3)",
{inputVector1, inputVector2, inputVector11, inputVector22},
false,
mapVector);
}

TEST_F(MapTest, ErrorCases) {
TEST_F(MapTest, errorCases) {
auto inputVectorInt64 = makeNullableFlatVector<int64_t>({1, 2, 3});
auto inputVectorDouble = makeNullableFlatVector<double>({4.0, 5.0, 6.0});
auto nullInputVector = makeNullableFlatVector<int64_t>({1, std::nullopt, 3});

// Number of args
auto inputVector1 = makeNullableFlatVector<int64_t>({1, 2, 3});
auto inputVector2 = makeNullableFlatVector<double>({4.0, 5.0, 6.0});
auto mapVector =
makeMapVector<int64_t, double>({{{1, 4.0}}, {{2, 5.0}}, {{3, 6.0}}});
mapSimple("map(c0)", {inputVector1}, true);
mapSimple(
"map(c0, c1, c2)", {inputVector1, inputVector2, inputVector1}, true);
testMapFails(
"map(c0)",
{inputVectorInt64},
"Scalar function signature is not supported: map(BIGINT)");
testMapFails(
"map(c0, c1, c2)",
{inputVectorInt64, inputVectorDouble, inputVectorInt64},
"Scalar function signature is not supported: map(BIGINT, DOUBLE, BIGINT)");

testMapFails(
"map(c0, c1, c2, c3, c4, c5, c6, c7)",
{inputVectorDouble,
inputVectorDouble,
inputVectorDouble,
inputVectorDouble,
inputVectorDouble,
inputVectorDouble,
inputVectorDouble,
inputVectorDouble},
"Scalar function signature is not supported: map(DOUBLE, DOUBLE, DOUBLE, DOUBLE, DOUBLE, DOUBLE, DOUBLE, DOUBLE)");

// Types of args
auto inputVector11 = makeNullableFlatVector<double>({10.0, 20.0, 30.0});
auto inputVector22 = makeNullableFlatVector<double>({4.1, 5.1, 6.1});
mapSimple(
testMapFails(
"map(c0, c1, c2, c3)",
{inputVector1, inputVector2, inputVector11, inputVector22},
true);
{inputVectorInt64,
inputVectorDouble,
inputVectorDouble,
inputVectorDouble},
"Scalar function signature is not supported: map(BIGINT, DOUBLE, DOUBLE, DOUBLE)");
testMapFails(
"map(c0, c1, c2, c3)",
{inputVectorDouble,
inputVectorInt64,
inputVectorDouble,
inputVectorDouble},
"Scalar function signature is not supported: map(DOUBLE, BIGINT, DOUBLE, DOUBLE)");

testMapFails(
"map(c0, c1)",
{nullInputVector, inputVectorDouble},
"Cannot use null as map key");
}

TEST_F(MapTest, complexTypes) {
  // Builds a map vector from 'keys' and 'values' using a single offset, 0.
  const auto wrapAsOneMap = [&](const VectorPtr& keys,
                                const VectorPtr& values) {
    return makeMapVector({0}, keys, values);
  };

  // Builds a row vector with one BIGINT child holding 'count' values:
  // start, start + 1, ...
  const auto makeRows = [&](vector_size_t count = 1, vector_size_t start = 0) {
    auto child =
        makeFlatVector<int64_t>(count, [&](auto row) { return row + start; });
    return makeRowVector({child});
  };

  // Runs map(c0, c1) on the two inputs and expects a map vector built
  // directly from those same inputs.
  const auto expectSingleMap = [&](const VectorPtr& keys,
                                   const VectorPtr& values) {
    const auto expected = wrapAsOneMap(keys, values);
    testMap("map(c0, c1)", {keys, values}, expected);
  };

  const auto arrayKey = makeArrayVectorFromJson<int64_t>({"[1, 2, 3]"});

  // Row key, row value.
  {
    const auto rowKey = makeRows();
    const auto rowValue = makeRows(1, 2);
    expectSingleMap(rowKey, rowValue);
  }

  // Array key, array value.
  {
    const auto arrayValue = makeArrayVectorFromJson<int64_t>({"[1, 3, 5]"});
    expectSingleMap(arrayKey, arrayValue);
  }

  // Map key, map value.
  {
    const auto mapKey = wrapAsOneMap(makeRows(), makeRows(1, 3));
    const auto mapValue = wrapAsOneMap(makeRows(), makeRows(1, 2));
    expectSingleMap(mapKey, mapValue);
  }

  // Nested maps on both sides: the key is a map keyed by a map, the value is
  // a map whose values are maps.
  {
    const auto nestedKey =
        wrapAsOneMap(wrapAsOneMap(makeRows(), makeRows()), makeRows());
    const auto nestedValue =
        wrapAsOneMap(arrayKey, wrapAsOneMap(makeRows(), makeRows()));
    expectSingleMap(nestedKey, nestedValue);
  }

  // Array key with a null array as the value.
  {
    const auto nullArrayValue = makeArrayVectorFromJson<int64_t>({"null"});
    expectSingleMap(arrayKey, nullArrayValue);
  }

  // Two key/value pairs per map: array keys with row values.
  {
    const auto expected = wrapAsOneMap(
        makeArrayVector<int64_t>({{1, 2, 3}, {4, 5}}), makeRows(2, 0));
    testMap(
        "map(c0, c1, c2, c3)",
        {makeArrayVector<int64_t>({{1, 2, 3}}),
         makeRows(),
         makeArrayVector<int64_t>({{4, 5}}),
         makeRows(1, 1)},
        expected);
  }

  // Two input rows, each producing its own map (offsets 0 and 1).
  {
    const auto expected = makeMapVector(
        {0, 1},
        makeArrayVector<int64_t>({{1, 2, 3}, {7, 9}}),
        makeArrayVector<int64_t>({{1, 2}, {4, 6}}));
    const auto keyArrays = makeArrayVector<int64_t>({{1, 2, 3}, {7, 9}});
    const auto valueArrays = makeArrayVector<int64_t>({{1, 2}, {4, 6}});
    testMap("map(c0, c1)", {keyArrays, valueArrays}, expected);
  }
}
} // namespace
} // namespace facebook::velox::functions::sparksql::test

0 comments on commit 2b97fed

Please sign in to comment.