From 80fae98cea90f28e20694e2fcd6235f0e88c67d4 Mon Sep 17 00:00:00 2001 From: Krishna Pai Date: Wed, 16 Oct 2024 16:32:03 -0700 Subject: [PATCH] Add support for canonicalization of JSON. --- velox/functions/prestosql/JsonFunctions.cpp | 206 ++++++++++++++++-- .../prestosql/json/JsonStringUtil.cpp | 21 +- .../functions/prestosql/json/JsonStringUtil.h | 14 +- .../prestosql/tests/JsonFunctionsTest.cpp | 103 ++++++++- 4 files changed, 319 insertions(+), 25 deletions(-) diff --git a/velox/functions/prestosql/JsonFunctions.cpp b/velox/functions/prestosql/JsonFunctions.cpp index 8628a669d8285..e148a81975587 100644 --- a/velox/functions/prestosql/JsonFunctions.cpp +++ b/velox/functions/prestosql/JsonFunctions.cpp @@ -14,11 +14,85 @@ * limitations under the License. */ #include "velox/expression/VectorFunction.h" +#include "velox/functions/prestosql/json/JsonStringUtil.h" #include "velox/functions/prestosql/json/SIMDJsonUtil.h" #include "velox/functions/prestosql/types/JsonType.h" namespace facebook::velox::functions { +namespace { +const auto kArrayStart = "["_sv; +const auto kArrayEnd = "]"_sv; +const auto kSeparator = ","_sv; +const auto kObjectStart = "{"_sv; +const auto kObjectEnd = "}"_sv; +const auto kObjectKeySeparator = ":"_sv; + +/// Class to keep track of json strings being written +/// in to a buffer. The size of the backing buffer must be known during +/// construction time. +class BufferTracker { + public: + BufferTracker(size_t bufferSize, memory::MemoryPool* pool) + : curPos_(0), currentViewStart_(0) { + buffer_ = AlignedBuffer::allocate(bufferSize, pool); + bufPtr_ = buffer_->asMutable(); + } + + /// Write out all the views to the buffer. + auto getCanonicalString(std::vector& jsonViews) { + for (auto view : jsonViews) { + trimEscapeWriteToBuffer(view); + } + return getStringView(); + } + + /// Sets current view to the end of the previous string. + /// Should be called only after getCanonicalString , + /// as after this call the previous view is lost. + void startNewString() { + currentViewStart_ += curPos_; + curPos_ = 0; + } + + /// Returns the underlying buffer where the json strings are saved. + BufferPtr getUnderlyingBuffer() { + return buffer_; + } + + private: + /// Trims whitespace and escapes utf characters before writing to buffer. + void trimEscapeWriteToBuffer(StringView input) { + auto trimmed = velox::util::trimWhiteSpace(input.data(), input.size()); + auto curBufPtr = getCurrentBufferPtr(); + auto bytesWritten = + escapeString(trimmed.data(), trimmed.size(), curBufPtr, true); + incrementCounter(bytesWritten); + } + + /// Returns current string view against the buffer. + StringView getStringView() { + return StringView(bufPtr_ + currentViewStart_, curPos_); + } + + inline char* getCurrentBufferPtr() { + return bufPtr_ + currentViewStart_ + curPos_; + } + + void incrementCounter(size_t increment) { + VELOX_CHECK_LE( + curPos_ + currentViewStart_ + increment, buffer_->capacity()); + curPos_ += increment; + } + + BufferPtr buffer_; + size_t curPos_; + size_t currentViewStart_; + char* bufPtr_; +}; + +} // namespace + namespace { class JsonFormatFunction : public exec::VectorFunction { public: @@ -84,38 +158,71 @@ class JsonParseFunction : public exec::VectorFunction { auto value = arg->as>()->valueAt(0); paddedInput_.resize(value.size() + simdjson::SIMDJSON_PADDING); memcpy(paddedInput_.data(), value.data(), value.size()); - if (auto error = parse(value.size())) { + auto escapeSize = escapedStringSize(value.data(), value.size(), true); + BufferTracker bufferTracker{escapeSize, context.pool()}; + + std::vector jsonViews; + + if (auto error = parse(value.size(), jsonViews)) { context.setErrors(rows, errors_[error]); return; } - localResult = std::make_shared>( - context.pool(), rows.end(), false, JSON(), std::move(value)); + + BufferPtr stringViews = + AlignedBuffer::allocate(1, context.pool(), StringView()); + auto rawStringViews = stringViews->asMutable(); + rawStringViews[0] = bufferTracker.getCanonicalString(jsonViews); + + auto constantBase = std::make_shared>( + context.pool(), + JSON(), + nullptr, + 1, + stringViews, + std::vector{bufferTracker.getUnderlyingBuffer()}); + + localResult = BaseVector::wrapInConstant(rows.end(), 0, constantBase); + } else { auto flatInput = arg->asFlatVector(); + BufferPtr stringViews = AlignedBuffer::allocate( + rows.end(), context.pool(), StringView()); + auto rawStringViews = stringViews->asMutable(); - auto stringBuffers = flatInput->stringBuffers(); VELOX_CHECK_LE(rows.end(), flatInput->size()); size_t maxSize = 0; + size_t totalOutputSize = 0; rows.applyToSelected([&](auto row) { auto value = flatInput->valueAt(row); maxSize = std::max(maxSize, value.size()); + totalOutputSize += escapedStringSize(value.data(), value.size(), true); }); + paddedInput_.resize(maxSize + simdjson::SIMDJSON_PADDING); + BufferTracker bufferTracker{totalOutputSize, context.pool()}; + rows.applyToSelected([&](auto row) { + std::vector jsonViews; auto value = flatInput->valueAt(row); memcpy(paddedInput_.data(), value.data(), value.size()); - if (auto error = parse(value.size())) { + if (auto error = parse(value.size(), jsonViews)) { context.setVeloxExceptionError(row, errors_[error]); + } else { + auto canonicalString = bufferTracker.getCanonicalString(jsonViews); + + rawStringViews[row] = canonicalString; + bufferTracker.startNewString(); } }); + localResult = std::make_shared>( context.pool(), JSON(), nullptr, rows.end(), - flatInput->values(), - std::move(stringBuffers)); + stringViews, + std::vector{bufferTracker.getUnderlyingBuffer()}); } context.moveOrCopyResult(localResult, rows, result); @@ -130,11 +237,12 @@ class JsonParseFunction : public exec::VectorFunction { } private: - simdjson::error_code parse(size_t size) const { + simdjson::error_code parse(size_t size, std::vector& jsonViews) + const { simdjson::padded_string_view paddedInput( paddedInput_.data(), size, paddedInput_.size()); SIMDJSON_ASSIGN_OR_RAISE(auto doc, simdjsonParse(paddedInput)); - SIMDJSON_TRY(validate(doc)); + SIMDJSON_TRY(validate(doc, jsonViews)); if (!doc.at_end()) { return simdjson::TRAILING_CONTENT; } @@ -142,33 +250,101 @@ class JsonParseFunction : public exec::VectorFunction { } template - static simdjson::error_code validate(T value) { + static simdjson::error_code validate( + T value, + std::vector& jsonViews) { SIMDJSON_ASSIGN_OR_RAISE(auto type, value.type()); switch (type) { case simdjson::ondemand::json_type::array: { SIMDJSON_ASSIGN_OR_RAISE(auto array, value.get_array()); + + jsonViews.push_back(kArrayStart); for (auto elementOrError : array) { SIMDJSON_ASSIGN_OR_RAISE(auto element, elementOrError); - SIMDJSON_TRY(validate(element)); + std::vector arrayElement; + SIMDJSON_TRY(validate(element, arrayElement)); + jsonViews.insert( + jsonViews.end(), + std::make_move_iterator(arrayElement.begin()), + std::make_move_iterator(arrayElement.end())); + jsonViews.push_back(kSeparator); } + + // Remove last separator. + jsonViews.pop_back(); + jsonViews.push_back(kArrayEnd); + return simdjson::SUCCESS; } + case simdjson::ondemand::json_type::object: { SIMDJSON_ASSIGN_OR_RAISE(auto object, value.get_object()); + + std::vector>> objFields; for (auto fieldOrError : object) { SIMDJSON_ASSIGN_OR_RAISE(auto field, fieldOrError); - SIMDJSON_TRY(validate(field.value())); + auto key = StringView(field.key_raw_json_token()); + std::vector elementArray; + SIMDJSON_TRY(validate(field.value(), elementArray)); + objFields.push_back({key, elementArray}); } + + std::sort(objFields.begin(), objFields.end(), [](auto& a, auto& b) { + return a.first < b.first; + }); + + jsonViews.push_back(kObjectStart); + + for (auto i = 0; i < objFields.size(); i++) { + auto field = objFields[i]; + jsonViews.push_back(field.first); + jsonViews.push_back(kObjectKeySeparator); + + jsonViews.insert( + jsonViews.end(), + std::make_move_iterator(field.second.begin()), + std::make_move_iterator(field.second.end())); + + if (i < objFields.size() - 1) { + jsonViews.push_back(kSeparator); + } + } + + jsonViews.push_back(kObjectEnd); return simdjson::SUCCESS; } - case simdjson::ondemand::json_type::number: + + case simdjson::ondemand::json_type::number: { + SIMDJSON_ASSIGN_OR_RAISE(auto rawJson, value.raw_json()); + + auto rawJsonv = StringView(rawJson); + jsonViews.push_back(rawJsonv); + return value.get_double().error(); - case simdjson::ondemand::json_type::string: + } + case simdjson::ondemand::json_type::string: { + SIMDJSON_ASSIGN_OR_RAISE(auto rawJson, value.raw_json()); + + auto rawJsonv = StringView(rawJson); + jsonViews.push_back(rawJsonv); + return value.get_string().error(); - case simdjson::ondemand::json_type::boolean: + } + + case simdjson::ondemand::json_type::boolean: { + SIMDJSON_ASSIGN_OR_RAISE(auto rawJson, value.raw_json()); + + auto rawJsonv = StringView(rawJson); + jsonViews.push_back(rawJsonv); + return value.get_bool().error(); + } + case simdjson::ondemand::json_type::null: { SIMDJSON_ASSIGN_OR_RAISE(auto isNull, value.is_null()); + auto rawJsonv = StringView(value.raw_json_token()); + + jsonViews.push_back(rawJsonv); return isNull ? simdjson::SUCCESS : simdjson::N_ATOM_ERROR; } } diff --git a/velox/functions/prestosql/json/JsonStringUtil.cpp b/velox/functions/prestosql/json/JsonStringUtil.cpp index 43be101ec40cc..fcf641339c77a 100644 --- a/velox/functions/prestosql/json/JsonStringUtil.cpp +++ b/velox/functions/prestosql/json/JsonStringUtil.cpp @@ -108,7 +108,8 @@ void testingEncodeUtf16Hex(char32_t codePoint, char*& out) { encodeUtf16Hex(codePoint, out); } -void escapeString(const char* input, size_t length, char* output) { +size_t +escapeString(const char* input, size_t length, char* output, bool skipAscii) { char* pos = output; auto* start = reinterpret_cast(input); @@ -117,7 +118,12 @@ void escapeString(const char* input, size_t length, char* output) { int count = validateAndGetNextUtf8Length(start, end); switch (count) { case 1: { - encodeAscii(int8_t(*start), pos); + if (!skipAscii) { + encodeAscii(int8_t(*start), pos); + } else { + *pos++ = *start; + } + start++; continue; } @@ -148,9 +154,11 @@ void escapeString(const char* input, size_t length, char* output) { } } } + + return (pos - output); } -size_t escapedStringSize(const char* input, size_t length) { +size_t escapedStringSize(const char* input, size_t length, bool skipAscii) { // 6 chars that is returned by `writeHex`. constexpr size_t kEncodedHexSize = 6; @@ -162,7 +170,12 @@ size_t escapedStringSize(const char* input, size_t length) { int count = validateAndGetNextUtf8Length(start, end); switch (count) { case 1: - outSize += encodedAsciiSizes[int8_t(*start)]; + if (!skipAscii) { + outSize += encodedAsciiSizes[int8_t(*start)]; + } else { + outSize++; + } + break; case 2: case 3: diff --git a/velox/functions/prestosql/json/JsonStringUtil.h b/velox/functions/prestosql/json/JsonStringUtil.h index 65cadd86bf683..1e01d5f027a11 100644 --- a/velox/functions/prestosql/json/JsonStringUtil.h +++ b/velox/functions/prestosql/json/JsonStringUtil.h @@ -38,14 +38,24 @@ namespace facebook::velox { /// @param length: Length of the input string. /// @param output: Output string to write the escaped input to. The caller is /// responsible to allocate enough space for output. -void escapeString(const char* input, size_t length, char* output); +/// @param skipAscii: Do not consider ascii characters for encoding (used in +/// json_parse for example). +/// @return The number of bytes written to the output. +size_t escapeString( + const char* input, + size_t length, + char* output, + bool skipAscii = false); /// Return the size of string after the unicode characters of `input` are /// escaped using the method as in`escapeString`. The function will iterate /// over `input` once. /// @param input: Input string to escape that is UTF-8 encoded. /// @param length: Length of the input string. -size_t escapedStringSize(const char* input, size_t length); +/// @param skipAscii: Do not consider ascii characters for encoding (used in +/// json_parse for example). +size_t +escapedStringSize(const char* input, size_t length, bool skipAscii = false); /// For test only. Encode `codePoint` value by UTF-16 and write the one or two /// prefixed hexadecimals to `out`. Move `out` forward by 6 or 12 chars diff --git a/velox/functions/prestosql/tests/JsonFunctionsTest.cpp b/velox/functions/prestosql/tests/JsonFunctionsTest.cpp index 067d374411f30..dbbda5e14a7ca 100644 --- a/velox/functions/prestosql/tests/JsonFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/JsonFunctionsTest.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include "folly/Unicode.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" #include "velox/functions/prestosql/types/JsonType.h" @@ -189,13 +190,18 @@ TEST_F(JsonFunctionsTest, jsonParse) { }; EXPECT_EQ(jsonParse(std::nullopt), std::nullopt); + // Spaces before and after. + EXPECT_EQ(jsonParse(R"( "abc" )"), R"("abc")"); EXPECT_EQ(jsonParse(R"(true)"), "true"); EXPECT_EQ(jsonParse(R"(null)"), "null"); EXPECT_EQ(jsonParse(R"(42)"), "42"); EXPECT_EQ(jsonParse(R"("abc")"), R"("abc")"); - EXPECT_EQ(jsonParse(R"([1, 2, 3])"), "[1, 2, 3]"); - EXPECT_EQ(jsonParse(R"({"k1":"v1"})"), R"({"k1":"v1"})"); - EXPECT_EQ(jsonParse(R"(["k1", "v1"])"), R"(["k1", "v1"])"); + EXPECT_EQ(jsonParse("\"abc\u4FE1\""), "\"abc\u4FE1\""); + auto utf32cp = folly::codePointToUtf8(U'😀'); + EXPECT_EQ(jsonParse(fmt::format("\"{}\"", utf32cp)), R"("\uD83D\uDE00")"); + EXPECT_EQ(jsonParse(R"([1, 2, 3])"), "[1,2,3]"); + EXPECT_EQ(jsonParse(R"({"k1": "v1" })"), R"({"k1":"v1"})"); + EXPECT_EQ(jsonParse(R"(["k1", "v1"])"), R"(["k1","v1"])"); VELOX_ASSERT_THROW( jsonParse(R"({"k1":})"), "The JSON document has an improper structure"); @@ -228,7 +234,7 @@ TEST_F(JsonFunctionsTest, jsonParse) { VELOX_ASSERT_THROW( evaluate("json_parse(c0)", data), - "Unexpected trailing content in the JSON input"); + "TAPE_ERROR: The JSON document has an improper structure: missing or superfluous commas, braces, missing keys, etc."); data = makeRowVector({makeFlatVector( {R"("This is a long sentence")", R"("This is some other sentence")"})}); @@ -276,6 +282,95 @@ TEST_F(JsonFunctionsTest, jsonParse) { } } +TEST_F(JsonFunctionsTest, canonicalization) { + const auto jsonParse = [&](std::optional value) { + return evaluateOnce("json_parse(c0)", value); + }; + + auto json = R"({ + "menu": { + "id": "file", + "value": "File", + "popup": { + "menuitem": [ + { + "value": "New", + "onclick": "CreateNewDoc() " + }, + { + "value": "Open", + "onclick": "OpenDoc() " + }, + { + "value": "Close", + "onclick": "CloseDoc() " + } + ] + } + } + })"; + + StringView expectedJson = + R"({"menu":{"id":"file","popup":{"menuitem":[{"onclick":"CreateNewDoc() ","value":"New"},{"onclick":"OpenDoc() ","value":"Open"},{"onclick":"CloseDoc() ","value":"Close"}]},"value":"File"}})"; + EXPECT_EQ(jsonParse(json), expectedJson); + + json = + "{\n" + " \"name\": \"John Doe\",\n" + " \"address\": {\n" + " \"street\": \"123 Main St\",\n" + " \"city\": \"Anytown\",\n" + " \"state\": \"CA\",\n" + " \"zip\": \"12345\"\n" + " },\n" + " \"phoneNumbers\": [\n" + " {\n" + " \"type\": \"home\",\n" + " \"number\": \"555-1234\"\n" + " },\n" + " {\n" + " \"type\": \"work\",\n" + " \"number\": \"555-5678\"\n" + " }\n" + " ],\n" + " \"familyMembers\": [\n" + " {\n" + " \"name\": \"Jane Doe\",\n" + " \"relationship\": \"wife\"\n" + " },\n" + " {\n" + " \"name\": \"Jimmy Doe\",\n" + " \"relationship\": \"son\"\n" + " }\n" + " ],\n" + " \"hobbies\": [\"golf\", \"reading\", \"traveling\"]\n" + "}"; + expectedJson = + R"({"address":{"city":"Anytown","state":"CA","street":"123 Main St","zip":"12345"},"familyMembers":[{"name":"Jane Doe","relationship":"wife"},{"name":"Jimmy Doe","relationship":"son"}],"hobbies":["golf","reading","traveling"],"name":"John Doe","phoneNumbers":[{"number":"555-1234","type":"home"},{"number":"555-5678","type":"work"}]})"; + EXPECT_EQ(jsonParse(json), expectedJson); + + // Json with spaces in keys + json = R"({ + "menu": { + "id": "file", + "value": "File", + "popup": { + "menuitem": [ + { + "value ": "New ", + "onclick": "CreateNewDoc() ", + " value ": " Space " + } + ] + } + } + })"; + + expectedJson = + R"({"menu":{"id":"file","popup":{"menuitem":[{" value ":" Space ","onclick":"CreateNewDoc() ","value ":"New "}]},"value":"File"}})"; + EXPECT_EQ(jsonParse(json), expectedJson); +} + TEST_F(JsonFunctionsTest, isJsonScalarSignatures) { auto signatures = getSignatureStrings("is_json_scalar"); ASSERT_EQ(2, signatures.size());