diff --git a/velox/docs/functions/spark/map.rst b/velox/docs/functions/spark/map.rst index 9234c377a7fd..1a995eca5c03 100644 --- a/velox/docs/functions/spark/map.rst +++ b/velox/docs/functions/spark/map.rst @@ -6,18 +6,13 @@ Map Functions Returns value for given ``key``, or ``NULL`` if the key is not contained in the map. -.. spark:function:: map() -> map(unknown, unknown) +.. spark:function:: map(K, V, K, V, ...) -> map(K,V) - Returns an empty map. :: + Returns a map created using the given key/value pairs. Keys are not allowed to be null. :: - SELECT map(); -- {} + SELECT map(1, 2, 3, 4); -- {1 -> 2, 3 -> 4} -.. spark:function:: map(array(K), array(V)) -> map(K,V) - :noindex: - - Returns a map created using the given key/value arrays. Duplicate map key will cause exception. :: - - SELECT map(ARRAY[1,3], ARRAY[2,4]); -- {1 -> 2, 3 -> 4} + SELECT map(array(1, 2), array(3, 4)); -- {[1, 2] -> [3, 4]} .. spark:function:: map_filter(map(K,V), func) -> map(K,V) diff --git a/velox/functions/sparksql/Map.cpp b/velox/functions/sparksql/Map.cpp index 16656473453a..7318e6b1aa50 100644 --- a/velox/functions/sparksql/Map.cpp +++ b/velox/functions/sparksql/Map.cpp @@ -32,50 +32,27 @@ namespace facebook::velox::functions::sparksql { namespace { -template -void setKeysResultTyped( +void setKeysAndValuesResult( vector_size_t mapSize, std::vector& args, const VectorPtr& keysResult, - exec::EvalCtx& context, - const SelectivityVector& rows) { - using T = typename KindToFlatVector::WrapperType; - auto flatVector = keysResult->asFlatVector(); - flatVector->resize(rows.end() * mapSize); - - exec::LocalDecodedVector decoded(context); - for (vector_size_t i = 0; i < mapSize; i++) { - decoded.get()->decode(*args[i * 2], rows); - // For efficiency traverse one arg at the time - rows.applyToSelected([&](vector_size_t row) { - VELOX_CHECK(!decoded->isNullAt(row), "Cannot use null as map key!"); - flatVector->set(row * mapSize + i, decoded->valueAt(row)); - }); - } -} - -template -void setValuesResultTyped( - vector_size_t mapSize, - std::vector& args, const VectorPtr& valuesResult, exec::EvalCtx& context, const SelectivityVector& rows) { - using T = typename KindToFlatVector::WrapperType; - auto flatVector = valuesResult->asFlatVector(); - flatVector->resize(rows.end() * mapSize); - exec::LocalDecodedVector decoded(context); + SelectivityVector targetRows(keysResult->size(), false); + std::vector toSourceRow(keysResult->size()); for (vector_size_t i = 0; i < mapSize; i++) { - decoded.get()->decode(*args[i * 2 + 1], rows); - - rows.applyToSelected([&](vector_size_t row) { - if (decoded->isNullAt(row)) { - flatVector->setNull(row * mapSize + i, true); - } else { - flatVector->set(row * mapSize + i, decoded->valueAt(row)); - } + decoded.get()->decode(*args[i * 2], rows); + context.applyToSelectedNoThrow(rows, [&](vector_size_t row) { + VELOX_USER_CHECK(!decoded->isNullAt(row), "Cannot use null as map key!"); + targetRows.setValid(row * mapSize + i, true); + toSourceRow[row * mapSize + i] = row; }); + targetRows.updateBounds(); + keysResult->copy(args[i * 2].get(), targetRows, toSourceRow.data()); + valuesResult->copy(args[i * 2 + 1].get(), targetRows, toSourceRow.data()); + targetRows.clearAll(); } } @@ -91,7 +68,7 @@ class MapFunction : public exec::VectorFunction { const TypePtr& /*outputType*/, exec::EvalCtx& context, VectorPtr& result) const override { - VELOX_CHECK( + VELOX_USER_CHECK( args.size() >= 2 && args.size() % 2 == 0, "Map function must take an even number of arguments"); auto mapSize = args.size() / 2; @@ -101,14 +78,12 @@ class MapFunction : public exec::VectorFunction { // Check key and value types for (auto i = 0; i < mapSize; i++) { - VELOX_CHECK_EQ( - args[i * 2]->type(), - keyType, - "All the key arguments in Map function must be the same!"); - VELOX_CHECK_EQ( - args[i * 2 + 1]->type(), - valueType, + VELOX_USER_CHECK( + args[i * 2]->type()->equivalent(*keyType), "All the key arguments in Map function must be the same!"); + VELOX_USER_CHECK( + args[i * 2 + 1]->type()->equivalent(*valueType), + "All the value arguments in Map function must be the same!"); } // Initializing input @@ -130,26 +105,12 @@ class MapFunction : public exec::VectorFunction { // Setting keys and value elements auto keysResult = mapResult->mapKeys(); auto valuesResult = mapResult->mapValues(); - keysResult->resize(rows.end() * mapSize); - valuesResult->resize(rows.end() * mapSize); - - VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( - setKeysResultTyped, - keyType->kind(), - mapSize, - args, - keysResult, - context, - rows); + const auto resultSize = rows.end() * mapSize; + keysResult->resize(resultSize); + valuesResult->resize(resultSize); - VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( - setValuesResultTyped, - valueType->kind(), - mapSize, - args, - valuesResult, - context, - rows); + setKeysAndValuesResult( + mapSize, args, keysResult, valuesResult, context, rows); } static std::vector> signatures() { diff --git a/velox/functions/sparksql/tests/MapTest.cpp b/velox/functions/sparksql/tests/MapTest.cpp index b8446fff1739..b005837af1c9 100644 --- a/velox/functions/sparksql/tests/MapTest.cpp +++ b/velox/functions/sparksql/tests/MapTest.cpp @@ -14,8 +14,11 @@ * limitations under the License. */ +#include "velox/common/base/VeloxException.h" +#include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" #include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" +#include "velox/type/Variant.h" #include @@ -25,19 +28,20 @@ namespace { class MapTest : public SparkFunctionBaseTest { protected: template - void mapSimple( + void testMap( const std::string& expression, const std::vector& parameters, - bool expectException = false, - const VectorPtr& expected = nullptr) { - if (expectException) { - ASSERT_THROW( - evaluate(expression, makeRowVector(parameters)), - std::exception); - } else { - auto result = evaluate(expression, makeRowVector(parameters)); - ::facebook::velox::test::assertEqualVectors(result, expected); - } + const VectorPtr& expected) { + auto result = evaluate(expression, makeRowVector(parameters)); + ::facebook::velox::test::assertEqualVectors(expected, result); + } + + void testMapFails( + const std::string& expression, + const std::vector& parameters, + const std::string errorMsg) { + VELOX_ASSERT_USER_THROW( + evaluate(expression, makeRowVector(parameters)), errorMsg); } }; @@ -46,7 +50,7 @@ TEST_F(MapTest, Basics) { auto inputVector2 = makeNullableFlatVector({4, 5, 6}); auto mapVector = makeMapVector({{{1, 4}}, {{2, 5}}, {{3, 6}}}); - mapSimple("map(c0, c1)", {inputVector1, inputVector2}, false, mapVector); + testMap("map(c0, c1)", {inputVector1, inputVector2}, mapVector); } TEST_F(MapTest, Nulls) { @@ -55,55 +59,157 @@ TEST_F(MapTest, Nulls) { makeNullableFlatVector({std::nullopt, 5, std::nullopt}); auto mapVector = makeMapVector( {{{1, std::nullopt}}, {{2, 5}}, {{3, std::nullopt}}}); - mapSimple("map(c0, c1)", {inputVector1, inputVector2}, false, mapVector); + testMap("map(c0, c1)", {inputVector1, inputVector2}, mapVector); } -TEST_F(MapTest, DifferentTypes) { +TEST_F(MapTest, differentTypes) { auto inputVector1 = makeNullableFlatVector({1, 2, 3}); auto inputVector2 = makeNullableFlatVector({4.0, 5.0, 6.0}); auto mapVector = makeMapVector({{{1, 4.0}}, {{2, 5.0}}, {{3, 6.0}}}); - mapSimple("map(c0, c1)", {inputVector1, inputVector2}, false, mapVector); + testMap("map(c0, c1)", {inputVector1, inputVector2}, mapVector); } TEST_F(MapTest, boolType) { auto inputVector1 = makeNullableFlatVector({1, 1, 0}); auto inputVector2 = makeNullableFlatVector({0, 0, 1}); auto mapVector = makeMapVector({{{1, 0}}, {{1, 0}}, {{0, 1}}}); - mapSimple("map(c0, c1)", {inputVector1, inputVector2}, false, mapVector); + testMap("map(c0, c1)", {inputVector1, inputVector2}, mapVector); } -TEST_F(MapTest, Wide) { +TEST_F(MapTest, wide) { auto inputVector1 = makeNullableFlatVector({1, 2, 3}); auto inputVector2 = makeNullableFlatVector({4.0, 5.0, 6.0}); auto inputVector11 = makeNullableFlatVector({10, 20, 30}); auto inputVector22 = makeNullableFlatVector({4.1, 5.1, 6.1}); auto mapVector = makeMapVector( {{{1, 4.0}, {10, 4.1}}, {{2, 5.0}, {20, 5.1}}, {{3, 6.0}, {30, 6.1}}}); - mapSimple( + testMap( "map(c0, c1, c2, c3)", {inputVector1, inputVector2, inputVector11, inputVector22}, - false, mapVector); } -TEST_F(MapTest, ErrorCases) { +TEST_F(MapTest, errorCases) { + auto inputVectorInt64 = makeNullableFlatVector({1, 2, 3}); + auto inputVectorDouble = makeNullableFlatVector({4.0, 5.0, 6.0}); + auto nullInputVector = makeNullableFlatVector({1, std::nullopt, 3}); + // Number of args - auto inputVector1 = makeNullableFlatVector({1, 2, 3}); - auto inputVector2 = makeNullableFlatVector({4.0, 5.0, 6.0}); - auto mapVector = - makeMapVector({{{1, 4.0}}, {{2, 5.0}}, {{3, 6.0}}}); - mapSimple("map(c0)", {inputVector1}, true); - mapSimple( - "map(c0, c1, c2)", {inputVector1, inputVector2, inputVector1}, true); + testMapFails( + "map(c0)", + {inputVectorInt64}, + "Scalar function signature is not supported: map(BIGINT)"); + testMapFails( + "map(c0, c1, c2)", + {inputVectorInt64, inputVectorDouble, inputVectorInt64}, + "Scalar function signature is not supported: map(BIGINT, DOUBLE, BIGINT)"); + + testMapFails( + "map(c0, c1, c2, c3, c4, c5, c6, c7)", + {inputVectorDouble, + inputVectorDouble, + inputVectorDouble, + inputVectorDouble, + inputVectorDouble, + inputVectorDouble, + inputVectorDouble, + inputVectorDouble}, + "Scalar function signature is not supported: map(DOUBLE, DOUBLE, DOUBLE, DOUBLE, DOUBLE, DOUBLE, DOUBLE, DOUBLE)"); // Types of args - auto inputVector11 = makeNullableFlatVector({10.0, 20.0, 30.0}); - auto inputVector22 = makeNullableFlatVector({4.1, 5.1, 6.1}); - mapSimple( + testMapFails( "map(c0, c1, c2, c3)", - {inputVector1, inputVector2, inputVector11, inputVector22}, - true); + {inputVectorInt64, + inputVectorDouble, + inputVectorDouble, + inputVectorDouble}, + "Scalar function signature is not supported: map(BIGINT, DOUBLE, DOUBLE, DOUBLE)"); + testMapFails( + "map(c0, c1, c2, c3)", + {inputVectorDouble, + inputVectorInt64, + inputVectorDouble, + inputVectorDouble}, + "Scalar function signature is not supported: map(DOUBLE, BIGINT, DOUBLE, DOUBLE)"); + + testMapFails( + "map(c0, c1)", + {nullInputVector, inputVectorDouble}, + "Cannot use null as map key"); +} + +TEST_F(MapTest, complexTypes) { + auto makeSingleMapVector = [&](const VectorPtr& keyVector, + const VectorPtr& valueVector) { + return makeMapVector( + { + 0, + }, + keyVector, + valueVector); + }; + + auto makeSingleRowVector = [&](vector_size_t size = 1, + vector_size_t base = 0) { + return makeRowVector({ + makeFlatVector(size, [&](auto row) { return row + base; }), + }); + }; + + auto testSingleMap = [&](const VectorPtr& keyVector, + const VectorPtr& valueVector) { + testMap( + "map(c0, c1)", + {keyVector, valueVector}, + makeSingleMapVector(keyVector, valueVector)); + }; + + auto arrayKey = makeArrayVectorFromJson({"[1, 2, 3]"}); + auto arrayValue = makeArrayVectorFromJson({"[1, 3, 5]"}); + auto nullArrayValue = makeArrayVectorFromJson({"null"}); + + testSingleMap(makeSingleRowVector(), makeSingleRowVector(1, 2)); + + testSingleMap(arrayKey, arrayValue); + + testSingleMap( + makeSingleMapVector(makeSingleRowVector(), makeSingleRowVector(1, 3)), + makeSingleMapVector(makeSingleRowVector(), makeSingleRowVector(1, 2))); + + testSingleMap( + makeSingleMapVector( + makeSingleMapVector(makeSingleRowVector(), makeSingleRowVector()), + makeSingleRowVector()), + makeSingleMapVector( + arrayKey, + makeSingleMapVector(makeSingleRowVector(), makeSingleRowVector()))); + + testSingleMap(arrayKey, nullArrayValue); + + auto mixedArrayKey1 = makeArrayVector({{1, 2, 3}}); + auto mixedRowValue1 = makeSingleRowVector(); + auto mixedArrayKey2 = makeArrayVector({{4, 5}}); + auto mixedRowValue2 = makeSingleRowVector(1, 1); + auto mixedMapResult = makeSingleMapVector( + makeArrayVector({{1, 2, 3}, {4, 5}}), makeSingleRowVector(2, 0)); + testMap( + "map(c0, c1, c2, c3)", + {mixedArrayKey1, mixedRowValue1, mixedArrayKey2, mixedRowValue2}, + mixedMapResult); + + auto arrayMapResult1 = makeMapVector( + { + 0, + 1, + }, + makeArrayVector({{1, 2, 3}, {7, 9}}), + makeArrayVector({{1, 2}, {4, 6}})); + testMap( + "map(c0, c1)", + {makeArrayVector({{1, 2, 3}, {7, 9}}), + makeArrayVector({{1, 2}, {4, 6}})}, + arrayMapResult1); } } // namespace } // namespace facebook::velox::functions::sparksql::test