diff --git a/cases/plan/error_unsupport_sql.yaml b/cases/plan/error_unsupport_sql.yaml index f308c05599d..3071108a41d 100644 --- a/cases/plan/error_unsupport_sql.yaml +++ b/cases/plan/error_unsupport_sql.yaml @@ -56,11 +56,6 @@ cases: desc: 表路径层级超过边界2 sql: | select db.t1.level3.* from t; - - id: 9 - desc: Insert 非常量 - mode: request-unsupport - sql: | - insert into t1 values(1, 2, aaa); - id: in_predicate_subquery desc: test_expr in subquery sql: | diff --git a/cases/query/udf_query.yaml b/cases/query/udf_query.yaml index 48ed1705038..320454bc431 100644 --- a/cases/query/udf_query.yaml +++ b/cases/query/udf_query.yaml @@ -574,7 +574,7 @@ cases: map('1', 2, '1', 4, '1', 6, '7', 8, '9', 10, '11', 12)['1'] as e9, # map("c", 99, "d", NULL)["d"] as e10, expect: - # FIXME + # FIXME(someone): add e10 result core dump occasionally on centOS columns: ["e1 string", "e2 int", "e3 string", "e4 int", "e5 string", "e6 timestamp", "e7 int", "e8 int", "e9 int"] data: | 2, 100, NULL, 101, f, 2000, 10, NULL, 2 diff --git a/hybridse/include/codec/fe_schema_codec.h b/hybridse/include/codec/fe_schema_codec.h index df8642de8fa..02c03c886ee 100644 --- a/hybridse/include/codec/fe_schema_codec.h +++ b/hybridse/include/codec/fe_schema_codec.h @@ -18,10 +18,7 @@ #define HYBRIDSE_INCLUDE_CODEC_FE_SCHEMA_CODEC_H_ #include -#include -#include #include -#include #include "vm/catalog.h" namespace hybridse { @@ -56,7 +53,7 @@ class SchemaCodec { if (it->name().size() >= 128) { return false; } - uint8_t name_size = (uint8_t)(it->name().size()); + uint8_t name_size = static_cast(it->name().size()); memcpy(cbuffer, static_cast(&name_size), 1); cbuffer += 1; memcpy(cbuffer, static_cast(it->name().c_str()), @@ -66,7 +63,7 @@ class SchemaCodec { return true; } - static bool Decode(const std::string& buf, vm::Schema* schema) { + static bool Decode(const std::string& buf, codec::Schema* schema) { if (schema == NULL) return false; if (buf.size() <= 0) return true; const char* buffer = buf.c_str(); diff --git a/hybridse/include/node/sql_node.h b/hybridse/include/node/sql_node.h index d127fccc71a..ee81c222714 100644 --- a/hybridse/include/node/sql_node.h +++ b/hybridse/include/node/sql_node.h @@ -1900,6 +1900,9 @@ class ColumnDefNode : public SqlNode { std::string GetColumnName() const { return column_name_; } + const ColumnSchemaNode *schema() const { return schema_; } + + // deprecated, use ColumnDefNode::schema instead DataType GetColumnType() const { return schema_->type(); } const ExprNode* GetDefaultValue() const { return schema_->default_value(); } diff --git a/hybridse/include/sdk/base_impl.h b/hybridse/include/sdk/base_impl.h index 5d1fd8bc842..524c41f5f0c 100644 --- a/hybridse/include/sdk/base_impl.h +++ b/hybridse/include/sdk/base_impl.h @@ -30,13 +30,13 @@ typedef ::google::protobuf::RepeatedPtrField< ::hybridse::type::TableDef> class SchemaImpl : public Schema { public: - explicit SchemaImpl(const vm::Schema& schema); + explicit SchemaImpl(const codec::Schema& schema); SchemaImpl() {} ~SchemaImpl(); - const vm::Schema& GetSchema() const { return schema_; } - inline void SetSchema(const vm::Schema& schema) { schema_ = schema; } + const codec::Schema& GetSchema() const { return schema_; } + inline void SetSchema(const codec::Schema& schema) { schema_ = schema; } int32_t GetColumnCnt() const; const std::string& GetColumnName(uint32_t index) const; @@ -46,7 +46,7 @@ class SchemaImpl : public Schema { const bool IsConstant(uint32_t index) const; private: - vm::Schema schema_; + codec::Schema schema_; }; class TableImpl : public Table { diff --git a/hybridse/src/codegen/block_ir_builder.cc b/hybridse/src/codegen/block_ir_builder.cc index 818229553ca..200a8f9f732 100644 --- a/hybridse/src/codegen/block_ir_builder.cc +++ b/hybridse/src/codegen/block_ir_builder.cc @@ -290,16 +290,18 @@ bool BlockIRBuilder::BuildReturnStmt(const ::hybridse::node::FnReturnStmt *node, } ::llvm::Value *value = value_wrapper.GetValue(&builder); if (TypeIRBuilder::IsStructPtr(value->getType())) { - StructTypeIRBuilder *struct_builder = - StructTypeIRBuilder::CreateStructTypeIRBuilder(block->getModule(), - value->getType()); + auto struct_builder = StructTypeIRBuilder::CreateStructTypeIRBuilder(block->getModule(), value->getType()); + if (!struct_builder.ok()) { + status.code = kCodegenError; + status.msg = struct_builder.status().ToString(); + return false; + } NativeValue ret_value; if (!var_ir_builder.LoadRetStruct(&ret_value, status)) { LOG(WARNING) << "fail to load ret struct address"; return false; } - if (!struct_builder->CopyFrom(block, value, - ret_value.GetValue(&builder))) { + if (!struct_builder.value()->CopyFrom(block, value, ret_value.GetValue(&builder))) { return false; } value = builder.getInt1(true); diff --git a/hybridse/src/codegen/buf_ir_builder.cc b/hybridse/src/codegen/buf_ir_builder.cc index 79d2c4aef96..db5ca6a2c5b 100644 --- a/hybridse/src/codegen/buf_ir_builder.cc +++ b/hybridse/src/codegen/buf_ir_builder.cc @@ -276,7 +276,7 @@ bool BufNativeIRBuilder::BuildGetStringField(uint32_t col_idx, uint32_t offset, BufNativeEncoderIRBuilder::BufNativeEncoderIRBuilder(CodeGenContextBase* ctx, const std::map* outputs, - const vm::Schema* schema) + const codec::Schema* schema) : ctx_(ctx), outputs_(outputs), schema_(schema), @@ -530,7 +530,7 @@ absl::StatusOr BufNativeEncoderIRBuilder::GetOrBuildAppendMapFn auto bs = ctx_->CreateBranchNot(is_null, [&]() -> base::Status { auto row_ptr = BuildGetPtrOffset(sub_builder, i8_ptr, str_body_offset); CHECK_TRUE(row_ptr.ok(), common::kCodegenError, row_ptr.status().ToString()); - auto sz = map_builder.Encode(ctx_, map_ptr, row_ptr.value()); + auto sz = map_builder.Encode(ctx_, row_ptr.value(), map_ptr); CHECK_TRUE(sz.ok(), common::kCodegenError, sz.status().ToString()); sub_builder->CreateStore(sz.value(), encode_sz_alloca); return {}; diff --git a/hybridse/src/codegen/buf_ir_builder.h b/hybridse/src/codegen/buf_ir_builder.h index 52d5d83385c..0ec9e664baf 100644 --- a/hybridse/src/codegen/buf_ir_builder.h +++ b/hybridse/src/codegen/buf_ir_builder.h @@ -25,7 +25,6 @@ #include "codegen/row_ir_builder.h" #include "codegen/scope_var.h" #include "codegen/variable_ir_builder.h" -#include "vm/catalog.h" namespace hybridse { namespace codegen { @@ -33,7 +32,7 @@ namespace codegen { class BufNativeEncoderIRBuilder : public RowEncodeIRBuilder { public: BufNativeEncoderIRBuilder(CodeGenContextBase* ctx, const std::map* outputs, - const vm::Schema* schema); + const codec::Schema* schema); ~BufNativeEncoderIRBuilder() override; @@ -55,10 +54,6 @@ class BufNativeEncoderIRBuilder : public RowEncodeIRBuilder { ::llvm::Value* str_addr_space, ::llvm::Value* str_body_offset, uint32_t str_field_idx, ::llvm::Value** output); - // encode SQL map data type into row - base::Status AppendMapVal(const type::ColumnSchema& sc, llvm::Value* i8_ptr, uint32_t field_idx, - const NativeValue& val, llvm::Value* str_addr_space, llvm::Value* str_body_offset, - uint32_t str_field_idx, llvm::Value** next_str_body_offset); absl::StatusOr GetOrBuildAppendMapFn(const type::ColumnSchema& sc) const; base::Status AppendHeader(::llvm::Value* i8_ptr, ::llvm::Value* size, @@ -74,7 +69,7 @@ class BufNativeEncoderIRBuilder : public RowEncodeIRBuilder { private: CodeGenContextBase* ctx_; const std::map* outputs_; - const vm::Schema* schema_; + const codec::Schema* schema_; uint32_t str_field_start_offset_; // n = offset_vec_[i] is // schema_[i] is base type (except string): col encode offset in row diff --git a/hybridse/src/codegen/insert_row_builder.cc b/hybridse/src/codegen/insert_row_builder.cc new file mode 100644 index 00000000000..c52eec6a1d8 --- /dev/null +++ b/hybridse/src/codegen/insert_row_builder.cc @@ -0,0 +1,149 @@ +/** + * Copyright (c) 2024 OpenMLDB authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "codegen/insert_row_builder.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "base/fe_status.h" +#include "codegen/buf_ir_builder.h" +#include "codegen/context.h" +#include "codegen/expr_ir_builder.h" +#include "node/node_manager.h" +#include "passes/resolve_fn_and_attrs.h" +#include "udf/default_udf_library.h" +#include "vm/engine.h" +#include "vm/jit_wrapper.h" + +namespace hybridse { +namespace codegen { + +InsertRowBuilder::InsertRowBuilder(const codec::Schema* schema) : schema_(schema) {} + +absl::Status InsertRowBuilder::Init() { + ::hybridse::vm::Engine::InitializeGlobalLLVM(); + + jit_ = std::unique_ptr(vm::HybridSeJitWrapper::Create()); + if (!jit_->Init()) { + jit_ = nullptr; + return absl::InternalError("fail to init jit"); + } + if (!vm::HybridSeJitWrapper::InitJitSymbols(jit_.get())) { + jit_ = nullptr; + return absl::InternalError("fail to init jit symbols"); + } + return absl::OkStatus(); +} + +absl::StatusOr> InsertRowBuilder::ComputeRow(const node::ExprListNode* values) { + EnsureInitialized(); + return ComputeRow(values->children_); +} + +absl::StatusOr> InsertRowBuilder::ComputeRow(absl::Span values) { + EnsureInitialized(); + + std::unique_ptr llvm_ctx = llvm::make_unique(); + std::unique_ptr llvm_module = llvm::make_unique("insert_row_builder", *llvm_ctx); + vm::SchemasContext empty_sc; + node::NodeManager nm; + codec::Schema empty_param_types; + CodeGenContext dump_ctx(llvm_module.get(), &empty_sc, &empty_param_types, &nm); + + auto library = udf::DefaultUdfLibrary::get(); + node::ExprAnalysisContext expr_ctx(&nm, library, &empty_sc, &empty_param_types); + passes::ResolveFnAndAttrs resolver(&expr_ctx); + + std::vector transformed; + for (auto& expr : values) { + node::ExprNode* out = nullptr; + CHECK_STATUS_TO_ABSL(resolver.VisitExpr(expr, &out)); + transformed.push_back(out); + } + + std::string fn_name = absl::StrCat("gen_insert_row_", fn_counter_++); + auto fs = BuildFn(&dump_ctx, fn_name, transformed); + CHECK_ABSL_STATUSOR(fs); + + llvm::Function* fn = fs.value(); + + if (!jit_->OptModule(llvm_module.get())) { + return absl::InternalError("fail to optimize module"); + } + + if (!jit_->AddModule(std::move(llvm_module), std::move(llvm_ctx))) { + return absl::InternalError("add llvm module failed"); + } + + auto c_fn = jit_->FindFunction(fn->getName()); + void (*encode)(int8_t**) = reinterpret_cast(const_cast(c_fn)); + + int8_t* insert_row = nullptr; + encode(&insert_row); + + auto managed_row = std::shared_ptr(insert_row, std::free); + + return managed_row; +} + +absl::StatusOr InsertRowBuilder::BuildFn(CodeGenContext* ctx, llvm::StringRef fn_name, + absl::Span values) { + llvm::Function* fn = ctx->GetModule()->getFunction(fn_name); + if (fn == nullptr) { + auto builder = ctx->GetBuilder(); + llvm::FunctionType* fnt = llvm::FunctionType::get(builder->getVoidTy(), + { + builder->getInt8PtrTy()->getPointerTo(), + }, + false); + + fn = llvm::Function::Create(fnt, llvm::GlobalValue::ExternalLinkage, fn_name, ctx->GetModule()); + FunctionScopeGuard fg(fn, ctx); + + llvm::Value* row_ptr_ptr = fn->arg_begin(); + + ExprIRBuilder expr_builder(ctx); + + std::map columns; + for (uint32_t i = 0; i < values.size(); ++i) { + auto expr = values[i]; + + NativeValue out; + auto s = expr_builder.Build(expr, &out); + CHECK_STATUS_TO_ABSL(s); + + columns[i] = out; + } + + BufNativeEncoderIRBuilder encode_builder(ctx, &columns, schema_); + CHECK_STATUS_TO_ABSL(encode_builder.Init()); + + encode_builder.BuildEncode(row_ptr_ptr); + + builder->CreateRetVoid(); + } + + return fn; +} + +// build the function that transform a single insert row values into encoded row +absl::StatusOr InsertRowBuilder::BuildEncodeFn() { return absl::OkStatus(); } +} // namespace codegen +} // namespace hybridse diff --git a/hybridse/src/codegen/insert_row_builder.h b/hybridse/src/codegen/insert_row_builder.h new file mode 100644 index 00000000000..83e8c1c2126 --- /dev/null +++ b/hybridse/src/codegen/insert_row_builder.h @@ -0,0 +1,67 @@ +/** + * Copyright (c) 2024 OpenMLDB authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef HYBRIDSE_SRC_CODEGEN_INSERT_ROW_BUILDER_H_ +#define HYBRIDSE_SRC_CODEGEN_INSERT_ROW_BUILDER_H_ + +#include + +#include "absl/status/statusor.h" +#include "codec/fe_row_codec.h" +#include "codegen/context.h" +#include "llvm/IR/Function.h" +#include "node/sql_node.h" +#include "vm/jit_wrapper.h" + +namespace hybridse { +namespace codegen { + +class InsertRowBuilder { + public: + explicit InsertRowBuilder(const codec::Schema* schema); + + absl::Status Init(); + + // compute the encoded row result for insert statement's single values expression list + // + // currently, expressions in insert values do not expect external source, so unsupported expressions + // will simply fail on resolving. + absl::StatusOr> ComputeRow(absl::Span values); + + absl::StatusOr> ComputeRow(const node::ExprListNode* values); + + private: + void EnsureInitialized() { assert(jit_ && "InsertRowBuilder not initialized"); } + + // build the function the will output the row from single insert values + // + // the function is just equivalent to C: `void fn(int8_t**)`. + // BuildFn returns different function with different name on every invocation + absl::StatusOr BuildFn(CodeGenContext* ctx, llvm::StringRef fn_name, + absl::Span); + + // build the function that transform a single insert row values into encoded row + absl::StatusOr BuildEncodeFn(); + + // CodeGenContextBase* ctx_; + const codec::Schema* schema_; + std::atomic fn_counter_ = 0; + + std::unique_ptr jit_; +}; +} // namespace codegen +} // namespace hybridse +#endif // HYBRIDSE_SRC_CODEGEN_INSERT_ROW_BUILDER_H_ diff --git a/hybridse/src/codegen/insert_row_builder_test.cc b/hybridse/src/codegen/insert_row_builder_test.cc new file mode 100644 index 00000000000..4924c175957 --- /dev/null +++ b/hybridse/src/codegen/insert_row_builder_test.cc @@ -0,0 +1,71 @@ +/** + * Copyright (c) 2024 OpenMLDB authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "codegen/insert_row_builder.h" + +#include + +#include "gtest/gtest.h" +#include "node/sql_node.h" +#include "plan/plan_api.h" +#include "vm/sql_ctx.h" + +namespace hybridse { +namespace codegen { + +class InsertRowBuilderTest : public ::testing::Test {}; + +TEST_F(InsertRowBuilderTest, encode) { + std::string sql = "insert into t1 values (1, map (1, '12'))"; + vm::SqlContext ctx; + ctx.sql = sql; + auto s = plan::PlanAPI::CreatePlanTreeFromScript(&ctx); + ASSERT_TRUE(s.isOK()) << s; + + auto* exprlist = dynamic_cast(ctx.logical_plan.front())->GetInsertNode()->values_[0]; + + codec::Schema sc; + { + auto col1 = sc.Add(); + col1->mutable_schema()->set_base_type(type::kInt32); + col1->set_type(type::kInt32); + } + + { + auto col = sc.Add(); + auto map_ty = col->mutable_schema()->mutable_map_type(); + map_ty->mutable_key_type()->set_base_type(type::kInt32); + map_ty->mutable_value_type()->set_base_type(type::kVarchar); + } + + InsertRowBuilder builder(&sc); + { + auto s = builder.Init(); + ASSERT_TRUE(s.ok()) << s; + } + + auto as = builder.ComputeRow(dynamic_cast(exprlist)); + ASSERT_TRUE(as.ok()) << as.status(); + + ASSERT_TRUE(as.value() != nullptr); +} +} // namespace codegen +} // namespace hybridse +// +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/hybridse/src/codegen/ir_base_builder_test.h b/hybridse/src/codegen/ir_base_builder_test.h index af29e4fd56c..494cfdb0818 100644 --- a/hybridse/src/codegen/ir_base_builder_test.h +++ b/hybridse/src/codegen/ir_base_builder_test.h @@ -360,7 +360,11 @@ void ModuleFunctionBuilderWithFullInfo::ExpandApplyArg( if (TypeIRBuilder::IsStructPtr(expect_ty)) { auto struct_builder = StructTypeIRBuilder::CreateStructTypeIRBuilder(function->getEntryBlock().getModule(), expect_ty); - struct_builder->CreateDefault(&function->getEntryBlock(), + if (!struct_builder.ok()) { + LOG(WARNING) << struct_builder.status(); + return; + } + struct_builder.value()->CreateDefault(&function->getEntryBlock(), &alloca); arg = builder.CreateSelect( is_null, alloca, builder.CreatePointerCast(arg, expect_ty)); diff --git a/hybridse/src/codegen/map_ir_builder.cc b/hybridse/src/codegen/map_ir_builder.cc index e54543040fd..27e6944c102 100644 --- a/hybridse/src/codegen/map_ir_builder.cc +++ b/hybridse/src/codegen/map_ir_builder.cc @@ -197,7 +197,7 @@ absl::StatusOr MapIRBuilder::ExtractElement(CodeGenContextBase* ctx ctx->GetBuilder()->getInt1Ty()->getPointerTo() // output is null ptr }, false); - fn = llvm::Function::Create(fnt, llvm::Function::ExternalLinkage, fn_name, ctx->GetModule()); + fn = llvm::Function::Create(fnt, llvm::GlobalValue::ExternalLinkage, fn_name, ctx->GetModule()); FunctionScopeGuard fg(fn, ctx); @@ -362,6 +362,53 @@ absl::StatusOr MapIRBuilder::MapKeys(CodeGenContextBase* ctx, const return out; } +absl::StatusOr MapIRBuilder::BuildEncodeByteSizeFn(CodeGenContextBase* ctx) const { + std::string fn_name = absl::StrCat("calc_encode_map_sz_", GetIRTypeName(struct_type_)); + llvm::Function* fn = ctx->GetModule()->getFunction(fn_name); + auto builder = ctx->GetBuilder(); + if (fn == nullptr) { + llvm::FunctionType* fnt = llvm::FunctionType::get(builder->getInt32Ty(), // return size + { + struct_type_->getPointerTo(), + }, + false); + fn = llvm::Function::Create(fnt, llvm::GlobalValue::ExternalLinkage, fn_name, ctx->GetModule()); + FunctionScopeGuard fg(fn, ctx); + + llvm::Value* raw = fn->arg_begin(); + // map_size + [key_ele_sz * map_size] + [val_ele_sz * map_sz] + [sizeof(bool) * map_size] + llvm::Value* final_size = CodecSizeForPrimitive(builder, builder->getInt32Ty()); + + auto elements = Load(ctx, raw); + if (!elements.ok()) { + return elements.status(); + } + + if (elements->size() != FIELDS_CNT) { + return absl::FailedPreconditionError( + absl::Substitute("element count error, expect $0, got $1", FIELDS_CNT, elements->size())); + } + auto& elements_vec = elements.value(); + auto& map_size = elements_vec[0]; + auto& key_vec = elements_vec[1]; + auto& value_vec = elements_vec[2]; + auto& value_null_vec = elements_vec[3]; + + auto keys_sz = CalEncodeSizeForArray(ctx, key_vec, map_size); + CHECK_ABSL_STATUSOR(keys_sz); + auto values_sz = CalEncodeSizeForArray(ctx, value_vec, map_size); + CHECK_ABSL_STATUSOR(values_sz); + auto values_null_sz = CalEncodeSizeForArray(ctx, value_null_vec, map_size); + CHECK_ABSL_STATUSOR(values_null_sz); + + builder->CreateRet(builder->CreateAdd( + final_size, + builder->CreateAdd(keys_sz.value(), builder->CreateAdd(values_sz.value(), values_null_sz.value())))); + } + + return fn; +} + absl::StatusOr MapIRBuilder::CalEncodeByteSize(CodeGenContextBase* ctx, llvm::Value* raw) const { auto builder = ctx->GetBuilder(); if (!raw->getType()->isPointerTy() || raw->getType()->getPointerElementType() != struct_type_) { @@ -370,33 +417,11 @@ absl::StatusOr MapIRBuilder::CalEncodeByteSize(CodeGenContextBase* GetLlvmObjectString(raw->getType()))); } - // map_size + [key_ele_sz * map_size] + [val_ele_sz * map_sz] + [sizeof(bool) * map_size] - llvm::Value* final_size = CodecSizeForPrimitive(builder, builder->getInt32Ty()); + auto fns = BuildEncodeByteSizeFn(ctx); - auto elements = Load(ctx, raw); - if (!elements.ok()) { - return elements.status(); - } + CHECK_ABSL_STATUSOR(fns); - if (elements->size() != FIELDS_CNT) { - return absl::FailedPreconditionError( - absl::Substitute("element count error, expect $0, got $1", FIELDS_CNT, elements->size())); - } - auto& elements_vec = elements.value(); - auto& map_size = elements_vec[0]; - auto& key_vec = elements_vec[1]; - auto& value_vec = elements_vec[2]; - auto& value_null_vec = elements_vec[3]; - - auto keys_sz = CalEncodeSizeForArray(ctx, key_vec, map_size); - CHECK_ABSL_STATUSOR(keys_sz); - auto values_sz = CalEncodeSizeForArray(ctx, value_vec, map_size); - CHECK_ABSL_STATUSOR(values_sz); - auto values_null_sz = CalEncodeSizeForArray(ctx, value_null_vec, map_size); - CHECK_ABSL_STATUSOR(values_null_sz); - - return builder->CreateAdd( - final_size, builder->CreateAdd(keys_sz.value(), builder->CreateAdd(values_sz.value(), values_null_sz.value()))); + return builder->CreateCall(fns.value(), {raw}); } absl::StatusOr MapIRBuilder::CalEncodeSizeForArray(CodeGenContextBase* ctx, llvm::Value* arr_ptr, @@ -429,7 +454,7 @@ absl::StatusOr MapIRBuilder::CalEncodeSizeForArray(CodeGenContextB builder->getInt32Ty() // arr size }, false); - fn = llvm::Function::Create(fnt, llvm::Function::ExternalLinkage, fn_name, ctx->GetModule()); + fn = llvm::Function::Create(fnt, llvm::GlobalValue::ExternalLinkage, fn_name, ctx->GetModule()); FunctionScopeGuard fg(fn, ctx); auto sub_builder = ctx->GetBuilder(); @@ -508,10 +533,82 @@ absl::StatusOr MapIRBuilder::TypeEncodeByteSize(CodeGenContextBase return absl::UnimplementedError(absl::StrCat("encode type ", GetLlvmObjectString(ele_type))); } -absl::StatusOr MapIRBuilder::Encode(CodeGenContextBase* ctx, llvm::Value* map_ptr, - llvm::Value* row_ptr) const { +absl::StatusOr MapIRBuilder::BuildEncodeFn(CodeGenContextBase* ctx) const { + std::string fn_name = absl::StrCat("encode_map_", GetIRTypeName(struct_type_)); + llvm::Function* fn = ctx->GetModule()->getFunction(fn_name); + + auto builder = ctx->GetBuilder(); + if (fn == nullptr) { + llvm::FunctionType* fnt = llvm::FunctionType::get(builder->getInt32Ty(), // encoded byte size + { + builder->getInt8PtrTy(), // row ptr + struct_type_->getPointerTo(), // map ptr + }, + false); + fn = llvm::Function::Create(fnt, llvm::GlobalValue::ExternalLinkage, fn_name, ctx->GetModule()); + + FunctionScopeGuard fg(fn, ctx); + + llvm::Value* row_ptr = fn->arg_begin(); + llvm::Value* map_ptr = fn->arg_begin() + 1; + llvm::Value* written = builder->getInt32(0); + + auto elements = Load(ctx, map_ptr); + if (!elements.ok()) { + return elements.status(); + } + + if (elements->size() != FIELDS_CNT) { + return absl::FailedPreconditionError( + absl::Substitute("element count error, expect $0, got $1", FIELDS_CNT, elements->size())); + } + + auto& elements_vec = elements.value(); + auto& map_size = elements_vec[0]; + auto& key_vec = elements_vec[1]; + auto& value_vec = elements_vec[2]; + auto& value_null_vec = elements_vec[3]; + + // *(int32*) row_ptr = map_size + { + CHECK_ABSL_STATUS(BuildStoreOffset(builder, row_ptr, builder->getInt32(0), map_size)); + + written = builder->CreateAdd(written, builder->getInt32(4)); + } + { + // *(key_type[map_size]) (row_ptr + 4) = key_vec + auto row_ptr_with_offset = BuildGetPtrOffset(builder, row_ptr, written); + CHECK_ABSL_STATUSOR(row_ptr_with_offset); + auto s = EncodeArray(ctx, row_ptr_with_offset.value(), key_vec, map_size); + CHECK_ABSL_STATUSOR(s); + written = builder->CreateAdd(written, s.value()); + } + { + // *(value_type[map_size]) (row_ptr + ?) = value_vec + auto row_ptr_with_offset = BuildGetPtrOffset(builder, row_ptr, written); + CHECK_ABSL_STATUSOR(row_ptr_with_offset); + auto s = EncodeArray(ctx, row_ptr_with_offset.value(), value_vec, map_size); + CHECK_ABSL_STATUSOR(s); + written = builder->CreateAdd(written, s.value()); + } + { + // *(bool[map_size]) (row_ptr + ?) = value_null_vec + auto row_ptr_with_offset = BuildGetPtrOffset(builder, row_ptr, written); + CHECK_ABSL_STATUSOR(row_ptr_with_offset); + // TODO(someone): alignment issue, bitwise operation for better performance ? + auto s = EncodeArray(ctx, row_ptr_with_offset.value(), value_null_vec, map_size); + CHECK_ABSL_STATUSOR(s); + written = builder->CreateAdd(written, s.value()); + } + + builder->CreateRet(written); + } + return fn; +} + +absl::StatusOr MapIRBuilder::Encode(CodeGenContextBase* ctx, llvm::Value* row_ptr, + llvm::Value* map_ptr) const { auto builder = ctx->GetBuilder(); - llvm::Value* written = builder->getInt32(0); if (row_ptr->getType() != builder->getInt8Ty()->getPointerTo()) { return absl::FailedPreconditionError( @@ -525,55 +622,10 @@ absl::StatusOr MapIRBuilder::Encode(CodeGenContextBase* ctx, llvm: GetLlvmObjectString(map_ptr->getType()->getPointerElementType()))); } - auto elements = Load(ctx, map_ptr); - if (!elements.ok()) { - return elements.status(); - } - - if (elements->size() != FIELDS_CNT) { - return absl::FailedPreconditionError( - absl::Substitute("element count error, expect $0, got $1", FIELDS_CNT, elements->size())); - } - - auto& elements_vec = elements.value(); - auto& map_size = elements_vec[0]; - auto& key_vec = elements_vec[1]; - auto& value_vec = elements_vec[2]; - auto& value_null_vec = elements_vec[3]; - - // *(int32*) row_ptr = map_size - { - CHECK_ABSL_STATUS(BuildStoreOffset(builder, row_ptr, builder->getInt32(0), map_size)); - - written = builder->CreateAdd(written, builder->getInt32(4)); - } - { - // *(key_type[map_size]) (row_ptr + 4) = key_vec - auto row_ptr_with_offset = BuildGetPtrOffset(builder, row_ptr, written); - CHECK_ABSL_STATUSOR(row_ptr_with_offset); - auto s = EncodeArray(ctx, row_ptr_with_offset.value(), key_vec, map_size); - CHECK_ABSL_STATUSOR(s); - written = builder->CreateAdd(written, s.value()); - } - { - // *(value_type[map_size]) (row_ptr + ?) = value_vec - auto row_ptr_with_offset = BuildGetPtrOffset(builder, row_ptr, written); - CHECK_ABSL_STATUSOR(row_ptr_with_offset); - auto s = EncodeArray(ctx, row_ptr_with_offset.value(), value_vec, map_size); - CHECK_ABSL_STATUSOR(s); - written = builder->CreateAdd(written, s.value()); - } - { - // *(bool[map_size]) (row_ptr + ?) = value_null_vec - auto row_ptr_with_offset = BuildGetPtrOffset(builder, row_ptr, written); - CHECK_ABSL_STATUSOR(row_ptr_with_offset); - // TODO(someone): alignment issue, bitwise operation for better performance ? - auto s = EncodeArray(ctx, row_ptr_with_offset.value(), value_null_vec, map_size); - CHECK_ABSL_STATUSOR(s); - written = builder->CreateAdd(written, s.value()); - } + auto fns = BuildEncodeFn(ctx); + CHECK_ABSL_STATUSOR(fns); - return written; + return builder->CreateCall(fns.value(), {row_ptr, map_ptr}); } absl::StatusOr MapIRBuilder::EncodeArray(CodeGenContextBase* ctx_, llvm::Value* row_ptr, @@ -737,7 +789,7 @@ absl::StatusOr MapIRBuilder::DecodeArrayValue(CodeGenContextBase* }, false); - fn = llvm::Function::Create(fnt, llvm::Function::ExternalLinkage, fn_name, ctx->GetModule()); + fn = llvm::Function::Create(fnt, llvm::GlobalValue::ExternalLinkage, fn_name, ctx->GetModule()); FunctionScopeGuard fg(fn, ctx); auto* sub_builder = ctx->GetBuilder(); @@ -855,7 +907,7 @@ absl::StatusOr MapIRBuilder::GetOrBuildEncodeArrFunction(CodeGe builder->getInt32Ty(), {builder->getInt8Ty()->getPointerTo(), ele_type->getPointerTo(), builder->getInt32Ty()}, false); - fn = llvm::Function::Create(fnt, llvm::Function::ExternalLinkage, fn_name, ctx->GetModule()); + fn = llvm::Function::Create(fnt, llvm::GlobalValue::ExternalLinkage, fn_name, ctx->GetModule()); // enter function FunctionScopeGuard fg(fn, ctx); diff --git a/hybridse/src/codegen/map_ir_builder.h b/hybridse/src/codegen/map_ir_builder.h index 1a7413f3ad5..f7063bcde45 100644 --- a/hybridse/src/codegen/map_ir_builder.h +++ b/hybridse/src/codegen/map_ir_builder.h @@ -17,8 +17,6 @@ #ifndef HYBRIDSE_SRC_CODEGEN_MAP_IR_BUILDER_H_ #define HYBRIDSE_SRC_CODEGEN_MAP_IR_BUILDER_H_ -#include - #include "codegen/struct_ir_builder.h" namespace hybridse { @@ -43,9 +41,13 @@ class MapIRBuilder final : public StructTypeIRBuilder { absl::StatusOr CalEncodeByteSize(CodeGenContextBase* ctx, llvm::Value*) const; + absl::StatusOr BuildEncodeByteSizeFn(CodeGenContextBase* ctx) const; + // Encode the `map_ptr` into `row_ptr`, returns byte size written on success // `row_ptr` is ensured to have enough space - absl::StatusOr Encode(CodeGenContextBase*, llvm::Value* map_ptr, llvm::Value* row_ptr) const; + absl::StatusOr Encode(CodeGenContextBase*, llvm::Value* row_ptr, llvm::Value* map_ptr) const; + + absl::StatusOr BuildEncodeFn(CodeGenContextBase*) const; // Decode the stored map value at address row_ptr absl::StatusOr Decode(CodeGenContextBase*, llvm::Value* row_ptr) const; diff --git a/hybridse/src/codegen/struct_ir_builder.cc b/hybridse/src/codegen/struct_ir_builder.cc index d616522931a..0d08e89aefb 100644 --- a/hybridse/src/codegen/struct_ir_builder.cc +++ b/hybridse/src/codegen/struct_ir_builder.cc @@ -31,31 +31,34 @@ StructTypeIRBuilder::StructTypeIRBuilder(::llvm::Module* m) StructTypeIRBuilder::~StructTypeIRBuilder() {} bool StructTypeIRBuilder::StructCopyFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value* dist) { - StructTypeIRBuilder* struct_builder = CreateStructTypeIRBuilder(block->getModule(), src->getType()); - bool ok = struct_builder->CopyFrom(block, src, dist); - delete struct_builder; - return ok; + auto struct_builder = CreateStructTypeIRBuilder(block->getModule(), src->getType()); + if (struct_builder.ok()) { + return struct_builder.value()->CopyFrom(block, src, dist); + } + return false; } -StructTypeIRBuilder* StructTypeIRBuilder::CreateStructTypeIRBuilder(::llvm::Module* m, ::llvm::Type* type) { +absl::StatusOr> StructTypeIRBuilder::CreateStructTypeIRBuilder( + ::llvm::Module* m, ::llvm::Type* type) { node::DataType base_type; if (!GetBaseType(type, &base_type)) { - return nullptr; + return absl::UnimplementedError( + absl::StrCat("fail to create struct type ir builder for ", GetLlvmObjectString(type))); } switch (base_type) { case node::kTimestamp: - return new TimestampIRBuilder(m); + return std::make_unique(m); case node::kDate: - return new DateIRBuilder(m); + return std::make_unique(m); case node::kVarchar: - return new StringIRBuilder(m); + return std::make_unique(m); default: { - LOG(WARNING) << "fail to create struct type ir builder for " << DataTypeName(base_type); - return nullptr; + break; } } - return nullptr; + return absl::UnimplementedError( + absl::StrCat("fail to create struct type ir builder for ", GetLlvmObjectString(type))); } absl::StatusOr StructTypeIRBuilder::CreateNull(::llvm::BasicBlock* block) { diff --git a/hybridse/src/codegen/struct_ir_builder.h b/hybridse/src/codegen/struct_ir_builder.h index 9e5437f5158..4c09e488ce9 100644 --- a/hybridse/src/codegen/struct_ir_builder.h +++ b/hybridse/src/codegen/struct_ir_builder.h @@ -19,6 +19,7 @@ #include #include +#include #include "absl/status/statusor.h" #include "base/fe_status.h" @@ -33,7 +34,8 @@ class StructTypeIRBuilder : public TypeIRBuilder { explicit StructTypeIRBuilder(::llvm::Module*); ~StructTypeIRBuilder(); - static StructTypeIRBuilder* CreateStructTypeIRBuilder(::llvm::Module*, ::llvm::Type*); + static absl::StatusOr> CreateStructTypeIRBuilder(::llvm::Module*, + ::llvm::Type*); static bool StructCopyFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value* dist); virtual bool CopyFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value* dist) = 0; diff --git a/hybridse/src/plan/planner.cc b/hybridse/src/plan/planner.cc index f181781d095..a6656dfc0f3 100644 --- a/hybridse/src/plan/planner.cc +++ b/hybridse/src/plan/planner.cc @@ -710,7 +710,7 @@ base::Status SimplePlanner::CreatePlanTree(const NodePointVector &parser_trees, break; } case node::kInsertStmt: { - CHECK_TRUE(is_batch_mode_, common::kPlanError, "Non-support INSERT Op in online serving"); + // CHECK_TRUE(is_batch_mode_, common::kPlanError, "Non-support INSERT Op in online serving"); node::PlanNode *insert_plan = nullptr; CHECK_STATUS(CreateInsertPlan(parser_tree, &insert_plan)) plan_trees.push_back(insert_plan); diff --git a/hybridse/src/planv2/ast_node_converter.cc b/hybridse/src/planv2/ast_node_converter.cc index 0261c673423..0427b5c6ba6 100644 --- a/hybridse/src/planv2/ast_node_converter.cc +++ b/hybridse/src/planv2/ast_node_converter.cc @@ -1996,11 +1996,6 @@ base::Status ConvertInsertStatement(const zetasql::ASTInsertStatement* root, nod CHECK_TRUE(nullptr != row, common::kSqlAstError, "Un-support insert statement with null row") node::ExprListNode* row_values; CHECK_STATUS(ConvertExprNodeList(row->values(), node_manager, &row_values)) - for (auto expr : row_values->children_) { - CHECK_TRUE(nullptr != expr && - (node::kExprPrimary == expr->GetExprType() || node::kExprParameter == expr->GetExprType()), - common::kSqlAstError, "Un-support insert statement with un-const value") - } rows->AddChild(row_values); } diff --git a/hybridse/src/planv2/ast_node_converter_test.cc b/hybridse/src/planv2/ast_node_converter_test.cc index 9798e69bd4f..178117728ad 100644 --- a/hybridse/src/planv2/ast_node_converter_test.cc +++ b/hybridse/src/planv2/ast_node_converter_test.cc @@ -868,12 +868,6 @@ TEST_F(ASTNodeConverterTest, ConvertInsertStmtFailTest) { )sql"; expect_converted(sql, common::kSqlAstError, "Un-support Named Parameter Expression a"); } - { - const std::string sql = R"sql( - INSERT into t1 values (1, 2L, aaa) - )sql"; - expect_converted(sql, common::kSqlAstError, "Un-support insert statement with un-const value"); - } } TEST_F(ASTNodeConverterTest, ConvertStmtFailTest) { node::NodeManager node_manager; diff --git a/hybridse/src/planv2/planner_v2_test.cc b/hybridse/src/planv2/planner_v2_test.cc index b5012bcce11..2b83865b1bb 100644 --- a/hybridse/src/planv2/planner_v2_test.cc +++ b/hybridse/src/planv2/planner_v2_test.cc @@ -16,16 +16,13 @@ #include "planv2/planner_v2.h" -#include #include #include #include "case/sql_case.h" #include "gtest/gtest.h" #include "plan/plan_api.h" -#include "zetasql/parser/parser.h" -#include "zetasql/public/error_helpers.h" -#include "zetasql/public/error_location.pb.h" + namespace hybridse { namespace plan { diff --git a/hybridse/src/sdk/base_impl.cc b/hybridse/src/sdk/base_impl.cc index fe34f9cd1a8..153ceef8eef 100644 --- a/hybridse/src/sdk/base_impl.cc +++ b/hybridse/src/sdk/base_impl.cc @@ -24,7 +24,7 @@ namespace sdk { static const std::string EMPTY_STR; // NOLINT -SchemaImpl::SchemaImpl(const vm::Schema& schema) : schema_(schema) {} +SchemaImpl::SchemaImpl(const codec::Schema& schema) : schema_(schema) {} SchemaImpl::~SchemaImpl() {} diff --git a/hybridse/src/vm/engine.cc b/hybridse/src/vm/engine.cc index 0865655f3c1..ac9ee9dcaaf 100644 --- a/hybridse/src/vm/engine.cc +++ b/hybridse/src/vm/engine.cc @@ -15,19 +15,22 @@ */ #include "vm/engine.h" + #include #include #include + +#include "absl/time/clock.h" #include "boost/none.hpp" #include "codec/fe_row_codec.h" #include "gflags/gflags.h" #include "llvm-c/Target.h" #include "udf/default_udf_library.h" +#include "vm/internal/node_helper.h" #include "vm/local_tablet_handler.h" #include "vm/mem_catalog.h" -#include "vm/sql_compiler.h" -#include "vm/internal/node_helper.h" #include "vm/runner_ctx.h" +#include "vm/sql_compiler.h" DECLARE_bool(enable_spark_unsaferow_format); @@ -52,10 +55,15 @@ Engine::Engine(const std::shared_ptr& catalog) : cl_(catalog), options_ Engine::Engine(const std::shared_ptr& catalog, const EngineOptions& options) : cl_(catalog), options_(options), mu_(), lru_cache_() {} Engine::~Engine() {} + void Engine::InitializeGlobalLLVM() { + // not thread safe, but is generally fine to call multiple times if (LLVM_IS_INITIALIZED) return; + + absl::Time begin = absl::Now(); LLVMInitializeNativeTarget(); LLVMInitializeNativeAsmPrinter(); + LOG(INFO) << "initialize llvm native target and asm printer, takes " << absl::Now() - begin; LLVM_IS_INITIALIZED = true; } diff --git a/hybridse/src/vm/jit.h b/hybridse/src/vm/jit.h index 24bb9a74856..7af5f17ac0d 100644 --- a/hybridse/src/vm/jit.h +++ b/hybridse/src/vm/jit.h @@ -94,13 +94,15 @@ class HybridSeLlvmJitWrapper : public HybridSeJitWrapper { bool OptModule(::llvm::Module* module) override; - bool AddModule(std::unique_ptr module, - std::unique_ptr llvm_ctx) override; + bool AddModule(std::unique_ptr module, std::unique_ptr llvm_ctx) override; bool AddExternalFunction(const std::string& name, void* addr) override; - hybridse::vm::RawPtrHandle FindFunction( - const std::string& funcname) override; + hybridse::vm::RawPtrHandle FindFunction(const std::string& funcname) override; + + // llvm::Module* GetModule() { + // } + // llvm::LLVMContext* GetLlvmContext(); private: std::unique_ptr jit_; diff --git a/hybridse/src/vm/jit_wrapper.h b/hybridse/src/vm/jit_wrapper.h index 458cb28272d..b0bbb70c6ec 100644 --- a/hybridse/src/vm/jit_wrapper.h +++ b/hybridse/src/vm/jit_wrapper.h @@ -45,8 +45,7 @@ class HybridSeJitWrapper { bool AddModuleFromBuffer(const base::RawBuffer&); - virtual hybridse::vm::RawPtrHandle FindFunction( - const std::string& funcname) = 0; + virtual hybridse::vm::RawPtrHandle FindFunction(const std::string& funcname) = 0; static HybridSeJitWrapper* Create(const JitOptions& jit_options); static HybridSeJitWrapper* Create(); diff --git a/src/cmd/sql_cmd_test.cc b/src/cmd/sql_cmd_test.cc index 2cfaea7de15..7741b14a1f4 100644 --- a/src/cmd/sql_cmd_test.cc +++ b/src/cmd/sql_cmd_test.cc @@ -3440,6 +3440,53 @@ TEST_P(DBSDKTest, CreateIfNotExists) { ASSERT_TRUE(cs->GetNsClient()->DropDatabase("test2", msg)) << msg; } +TEST_P(DBSDKTest, MapTypeTable) { + auto cli = GetParam(); + cs = cli->cs; + sr = cli->sr; + absl::BitGen gen; + auto db = absl::StrCat("db_", absl::Uniform(gen, 0, std::numeric_limits::max())); + auto table = absl::StrCat("tb_", absl::Uniform(gen, 0, std::numeric_limits::max())); + + ProcessSQLs(sr, { + "set session execute_mode = 'online'", + absl::StrCat("create database ", db), + absl::StrCat("use ", db), + absl::Substitute("create table $0 (id string, val map)", table), + absl::Substitute("insert into $0 values ('1', map(12, '23')) ", table), + absl::Substitute("insert into $0 values ('4', map(99, '44')) ", table), + }); + absl::Cleanup clean = [&]() { + ProcessSQLs(sr, { + absl::Substitute("drop table $0", table), + absl::Substitute("drop database $0", db), + }); + }; + + // query + hybridse::sdk::Status status; + auto rs = sr->ExecuteSQL(absl::Substitute("select id, val[12] as ele from $0", table), &status); + ASSERT_TRUE(status.IsOK()) << status.ToString(); + ASSERT_EQ(rs->Size(), 2); + + while (rs->Next()) { + // result is unordered + std::string id; + ASSERT_TRUE(rs->GetAsString(0, id)); + std::string ele; + ASSERT_TRUE(rs->GetAsString(1, ele)); + + if (id == "1") { + EXPECT_EQ(ele, "23"); + } else if (id == "4") { + EXPECT_EQ(ele, "NULL"); + EXPECT_TRUE(rs->IsNULL(1)); + } else { + ASSERT_FALSE(true) << "should not reach"; + } + } +} + TEST_P(DBSDKTest, ShowComponents) { auto cli = GetParam(); cs = cli->cs; diff --git a/src/codec/codec.cc b/src/codec/codec.cc index 8d5e24bc8c8..858acfc374a 100644 --- a/src/codec/codec.cc +++ b/src/codec/codec.cc @@ -1152,5 +1152,15 @@ bool RowProject::Project(const int8_t* row_ptr, uint32_t size, int8_t** output_p return true; } +bool ColumnSupportLegacyCodec(const openmldb::common::ColumnDesc& col_desc) { + auto dt = col_desc.data_type(); + if (col_desc.has_schema()) { + dt = col_desc.schema().type(); + } + + return (dt >= openmldb::type::kBool && dt <= openmldb::type::kTimestamp) || dt == openmldb::type::kVarchar || + dt == openmldb::type::kString; +} + } // namespace codec } // namespace openmldb diff --git a/src/codec/codec.h b/src/codec/codec.h index 681c05ae2aa..b84289a2ce8 100644 --- a/src/codec/codec.h +++ b/src/codec/codec.h @@ -192,6 +192,8 @@ class RowView { std::vector offset_vec_; }; +bool ColumnSupportLegacyCodec(const openmldb::common::ColumnDesc&); + namespace v1 { inline int8_t GetAddrSpace(uint32_t size) { diff --git a/src/proto/common.proto b/src/proto/common.proto index 1bc539766c3..8241e646f34 100755 --- a/src/proto/common.proto +++ b/src/proto/common.proto @@ -41,12 +41,21 @@ message DbTableNamePair { required string db_name = 1; required string table_name = 2; } +message TableColumnSchema { + optional string name = 1; + optional openmldb.type.DataType type = 2; + repeated TableColumnSchema type_fields = 3; + optional bool not_null = 4 [default = false]; +} + message ColumnDesc { required string name = 1; optional openmldb.type.DataType data_type = 2; optional bool not_null = 3 [default = false]; optional bool is_constant = 4 [default = false]; optional string default_value = 5; + // replacing ColumnDesc::data_type and ColumnDesc::not_null + optional TableColumnSchema schema = 6; } message TTLSt { diff --git a/src/proto/type.proto b/src/proto/type.proto index 5bb6b7faa67..83b80631ca1 100755 --- a/src/proto/type.proto +++ b/src/proto/type.proto @@ -34,6 +34,8 @@ enum DataType { // reserve 9, 10, 11, 12 kVarchar = 13; kString = 14; + kArray = 15; + kMap = 16; } enum IndexType { @@ -77,4 +79,4 @@ enum ProcedureType { enum NotifyType { kTable = 1; kGlobalVar = 2; -} \ No newline at end of file +} diff --git a/src/schema/schema_adapter.cc b/src/schema/schema_adapter.cc index f93172d85ac..0fdbefcf5c5 100644 --- a/src/schema/schema_adapter.cc +++ b/src/schema/schema_adapter.cc @@ -20,12 +20,15 @@ #include #include #include +#include "absl/status/status.h" #include "glog/logging.h" +#include "proto/fe_type.pb.h" +#include "proto/type.pb.h" namespace openmldb { namespace schema { -bool SchemaAdapter::ConvertSchemaAndIndex(const ::hybridse::vm::Schema& sql_schema, +bool SchemaAdapter::ConvertSchemaAndIndex(const ::hybridse::codec::Schema& sql_schema, const ::hybridse::vm::IndexList& index, PBSchema* schema_output, PBIndex* index_output) { if (nullptr == schema_output || nullptr == index_output) { @@ -56,8 +59,8 @@ bool SchemaAdapter::ConvertSchemaAndIndex(const ::hybridse::vm::Schema& sql_sche return true; } -bool SchemaAdapter::SubSchema(const ::hybridse::vm::Schema* schema, - const ::google::protobuf::RepeatedField& projection, hybridse::vm::Schema* output) { +bool SchemaAdapter::SubSchema(const ::hybridse::codec::Schema* schema, + const ::google::protobuf::RepeatedField& projection, hybridse::codec::Schema* output) { if (output == nullptr) { LOG(WARNING) << "output ptr is null"; return false; @@ -70,12 +73,12 @@ bool SchemaAdapter::SubSchema(const ::hybridse::vm::Schema* schema, return true; } std::shared_ptr<::hybridse::sdk::Schema> SchemaAdapter::ConvertSchema(const PBSchema& schema) { - ::hybridse::vm::Schema vm_schema; + ::hybridse::codec::Schema vm_schema; ConvertSchema(schema, &vm_schema); return std::make_shared<::hybridse::sdk::SchemaImpl>(vm_schema); } -bool SchemaAdapter::ConvertSchema(const PBSchema& schema, ::hybridse::vm::Schema* output) { +bool SchemaAdapter::ConvertSchema(const PBSchema& schema, ::hybridse::codec::Schema* output) { if (output == nullptr) { LOG(WARNING) << "output ptr is null"; return false; @@ -85,23 +88,18 @@ bool SchemaAdapter::ConvertSchema(const PBSchema& schema, ::hybridse::vm::Schema return false; } for (int32_t i = 0; i < schema.size(); i++) { - const common::ColumnDesc& column = schema.Get(i); - ::hybridse::type::ColumnDef* new_column = output->Add(); - new_column->set_name(column.name()); - new_column->set_is_not_null(column.not_null()); - new_column->set_is_constant(column.is_constant()); - ::hybridse::type::Type type; - if (!ConvertType(column.data_type(), &type)) { - LOG(WARNING) << "type " << ::openmldb::type::DataType_Name(column.data_type()) - << " is not supported"; + const common::ColumnDesc& table_column = schema.Get(i); + ::hybridse::type::ColumnDef* sql_column = output->Add(); + auto s = ConvertColumn(table_column, sql_column); + if (!s.ok()) { + LOG(WARNING) << s.ToString(); return false; } - new_column->set_type(type); } return true; } -bool SchemaAdapter::ConvertSchema(const ::hybridse::vm::Schema& hybridse_schema, PBSchema* schema) { +bool SchemaAdapter::ConvertSchema(const ::hybridse::codec::Schema& hybridse_schema, PBSchema* schema) { if (schema == nullptr) { LOG(WARNING) << "schema is null"; return false; @@ -155,6 +153,62 @@ bool SchemaAdapter::ConvertType(hybridse::node::DataType hybridse_type, openmldb return true; } +absl::Status SchemaAdapter::ConvertType(const hybridse::node::ColumnSchemaNode* sc, common::TableColumnSchema* tbs) { + if (sc == nullptr) { + return absl::InvalidArgumentError("paramter null"); + } + switch (sc->type()) { + case hybridse::node::kBool: + tbs->set_type(openmldb::type::kBool); + break; + case hybridse::node::kInt16: + tbs->set_type(openmldb::type::kSmallInt); + break; + case hybridse::node::kInt32: + tbs->set_type(openmldb::type::kInt); + break; + case hybridse::node::kInt64: + tbs->set_type(openmldb::type::kBigInt); + break; + case hybridse::node::kFloat: + tbs->set_type(openmldb::type::kFloat); + break; + case hybridse::node::kDouble: + tbs->set_type(openmldb::type::kDouble); + break; + case hybridse::node::kDate: + tbs->set_type(openmldb::type::kDate); + break; + case hybridse::node::kTimestamp: + tbs->set_type(openmldb::type::kTimestamp); + break; + case hybridse::node::kVarchar: + tbs->set_type(openmldb::type::kVarchar); + break; + case hybridse::node::kArray: { + tbs->set_type(openmldb::type::kArray); + break; + } + case hybridse::node::kMap: { + tbs->set_type(openmldb::type::kMap); + break; + } + default: + return absl::UnimplementedError(absl::StrCat("unsupported type: ", sc->DebugString())); + } + + for (auto& field_type : sc->generics()) { + auto* field = tbs->add_type_fields(); + auto s = ConvertType(field_type, field); + if (!s.ok()) { + return s; + } + } + + tbs->set_not_null(sc->not_null()); + return absl::OkStatus(); +} + bool SchemaAdapter::ConvertType(openmldb::type::DataType type, hybridse::node::DataType* hybridse_type) { if (hybridse_type == nullptr) { return false; @@ -358,6 +412,99 @@ bool SchemaAdapter::ConvertColumn(const hybridse::type::ColumnDef& sql_column, o return true; } +absl::Status SchemaAdapter::ConvertColumn(const openmldb::common::ColumnDesc& column, + hybridse::type::ColumnDef* sql_column) { + if (column.has_schema()) { + // new schema field + auto s = ConvertSchema(column.schema(), sql_column->mutable_schema()); + if (!s.ok()) { + return s; + } + } else { + // fallback use data_type and not_null + ::hybridse::type::Type ty; + if (!ConvertType(column.data_type(), &ty)) { + return absl::InternalError(absl::StrCat("failed to convert type: ", column.DebugString())); + } + auto sc = sql_column->mutable_schema(); + sc->set_base_type(ty); + sc->set_is_not_null(column.not_null()); + } + + if (sql_column->schema().has_base_type()) { + sql_column->set_type(sql_column->schema().base_type()); + } + sql_column->set_is_not_null(sql_column->schema().is_not_null()); + + sql_column->set_name(column.name()); + sql_column->set_is_constant(column.is_constant()); + return absl::OkStatus(); +} + +absl::Status SchemaAdapter::ConvertSchema(const openmldb::common::TableColumnSchema& ts, + hybridse::type::ColumnSchema* sc) { + switch (ts.type()) { + case openmldb::type::kBool: + sc->set_base_type(::hybridse::type::kBool); + break; + case openmldb::type::kSmallInt: + sc->set_base_type(::hybridse::type::kInt16); + break; + case openmldb::type::kInt: + sc->set_base_type(::hybridse::type::kInt32); + break; + case openmldb::type::kBigInt: + sc->set_base_type(::hybridse::type::kInt64); + break; + case openmldb::type::kFloat: + sc->set_base_type(::hybridse::type::kFloat); + break; + case openmldb::type::kDouble: + sc->set_base_type(::hybridse::type::kDouble); + break; + case openmldb::type::kDate: + sc->set_base_type(::hybridse::type::kDate); + break; + case openmldb::type::kTimestamp: + sc->set_base_type(::hybridse::type::kTimestamp); + break; + case openmldb::type::kVarchar: + case openmldb::type::kString: + sc->set_base_type(::hybridse::type::kVarchar); + break; + + case openmldb::type::kArray: { + auto arr_ty = sc->mutable_array_type(); + if (ts.type_fields_size() != 1) { + return absl::FailedPreconditionError( + absl::StrCat("array type requires type_fields size=1, got size=", ts.type_fields_size())); + } + auto s = ConvertSchema(ts.type_fields().Get(0), arr_ty->mutable_ele_type()); + if (!s.ok()) { + return s; + } + break; + } + case openmldb::type::kMap: { + auto map_ty = sc->mutable_map_type(); + if (ts.type_fields_size() != 2) { + return absl::FailedPreconditionError( + absl::StrCat("map type requires type_fields size=2, got size=", ts.type_fields_size())); + } + auto s = ConvertSchema(ts.type_fields().Get(0), map_ty->mutable_key_type()); + s.Update(ConvertSchema(ts.type_fields().Get(1), map_ty->mutable_value_type())); + if (!s.ok()) { + return s; + } + break; + } + } + + sc->set_is_not_null(ts.not_null()); + + return absl::OkStatus(); +} + std::map SchemaAdapter::GetColMap(const nameserver::TableInfo& table_info) { std::map col_map; for (const auto& col : table_info.column_desc()) { diff --git a/src/schema/schema_adapter.h b/src/schema/schema_adapter.h index e209b380d9e..d6433800699 100644 --- a/src/schema/schema_adapter.h +++ b/src/schema/schema_adapter.h @@ -28,28 +28,31 @@ #include "proto/tablet.pb.h" #include "schema/index_util.h" #include "vm/catalog.h" +#include "node/sql_node.h" namespace openmldb { namespace schema { class SchemaAdapter { public: - static bool ConvertSchemaAndIndex(const ::hybridse::vm::Schema& sql_schema, + static bool ConvertSchemaAndIndex(const ::hybridse::codec::Schema& sql_schema, const ::hybridse::vm::IndexList& index, PBSchema* schema_output, PBIndex* index_output); - static bool SubSchema(const ::hybridse::vm::Schema* schema, + static bool SubSchema(const ::hybridse::codec::Schema* schema, const ::google::protobuf::RepeatedField& projection, - hybridse::vm::Schema* output); + hybridse::codec::Schema* output); - static bool ConvertSchema(const PBSchema& schema, ::hybridse::vm::Schema* output); + static bool ConvertSchema(const PBSchema& schema, ::hybridse::codec::Schema* output); static std::shared_ptr<::hybridse::sdk::Schema> ConvertSchema(const PBSchema& schema); - static bool ConvertSchema(const ::hybridse::vm::Schema& hybridse_schema, PBSchema* schema); + static bool ConvertSchema(const ::hybridse::codec::Schema& hybridse_schema, PBSchema* schema); static bool ConvertType(hybridse::node::DataType hybridse_type, openmldb::type::DataType* type); + static absl::Status ConvertType(const hybridse::node::ColumnSchemaNode* sc, common::TableColumnSchema* tbs); + static bool ConvertType(openmldb::type::DataType type, hybridse::node::DataType* hybridse_type); static bool ConvertType(hybridse::type::Type hybridse_type, openmldb::type::DataType* openmldb_type); @@ -70,6 +73,16 @@ class SchemaAdapter { private: static bool ConvertColumn(const hybridse::type::ColumnDef& sql_column, openmldb::common::ColumnDesc* column); + + // table column definition to SQL type. + // + // NOTE NOT ALL fields from table column are convertable to SQL type, be aware the difference between + // 'table_column_definition' and 'type' from parser. + // For example common::ColumnDesc::default_value does not have corresponding field in hybridse::type::ColumnDef. + static absl::Status ConvertColumn(const openmldb::common::ColumnDesc& column, hybridse::type::ColumnDef* sql_column) + ABSL_ATTRIBUTE_NONNULL(); + static absl::Status ConvertSchema(const openmldb::common::TableColumnSchema&, hybridse::type::ColumnSchema*) + ABSL_ATTRIBUTE_NONNULL(); }; } // namespace schema diff --git a/src/sdk/node_adapter.cc b/src/sdk/node_adapter.cc index 58a0b534b4e..8d8336fbaa3 100644 --- a/src/sdk/node_adapter.cc +++ b/src/sdk/node_adapter.cc @@ -314,16 +314,17 @@ bool NodeAdapter::TransformToTableDef(::hybridse::node::CreatePlanNode* create_n return false; } add_column_desc->set_name(column_def->GetColumnName()); - add_column_desc->set_not_null(column_def->GetIsNotNull()); column_names.insert(std::make_pair(column_def->GetColumnName(), add_column_desc)); - openmldb::type::DataType data_type; - if (!openmldb::schema::SchemaAdapter::ConvertType(column_def->GetColumnType(), &data_type)) { - status->msg = "column type " + - hybridse::node::DataTypeName(column_def->GetColumnType()) + " is not supported"; + auto s = openmldb::schema::SchemaAdapter::ConvertType(column_def->schema(), + add_column_desc->mutable_schema()); + if (!s.ok()) { + status->msg = s.ToString(); status->code = hybridse::common::kUnsupportSql; return false; } - add_column_desc->set_data_type(data_type); + add_column_desc->set_data_type(add_column_desc->schema().type()); + add_column_desc->set_not_null(add_column_desc->schema().not_null()); + auto default_val = column_def->GetDefaultValue(); if (default_val) { if (default_val->GetExprType() != hybridse::node::kExprPrimary) { diff --git a/src/sdk/sql_cluster_router.cc b/src/sdk/sql_cluster_router.cc index 3dc11369fea..742cc82fe5e 100644 --- a/src/sdk/sql_cluster_router.cc +++ b/src/sdk/sql_cluster_router.cc @@ -62,6 +62,7 @@ #include "sdk/split.h" #include "udf/udf.h" #include "vm/catalog.h" +#include "codegen/insert_row_builder.h" DECLARE_string(bucket_size); DECLARE_uint32(replica_num); @@ -390,7 +391,8 @@ bool SQLClusterRouter::GetMultiRowInsertInfo(const std::string& db, const std::s ::hybridse::sdk::Status* status, std::shared_ptr<::openmldb::nameserver::TableInfo>* table_info, std::vector* default_maps, - std::vector* str_lengths, bool* put_if_absent) { + std::vector* str_lengths, bool* put_if_absent, + std::vector>* codegen_rows) { RET_FALSE_IF_NULL_AND_WARN(status, "output status is nullptr"); // TODO(hw): return status? RET_FALSE_IF_NULL_AND_WARN(table_info, "output table_info is nullptr"); @@ -432,6 +434,17 @@ bool SQLClusterRouter::GetMultiRowInsertInfo(const std::string& db, const std::s return false; } std::map column_map; + + // codegen is the new approach encoding insert rows + // TODO(someone): initiate a session variable to control row encode implementation + bool insert_codegen = false; + for (int i = 0; i < (*table_info)->column_desc_size(); ++i) { + auto& col_desc = (*table_info)->column_desc(i); + if (!codec::ColumnSupportLegacyCodec(col_desc)) { + insert_codegen = true; + break; + } + } for (size_t j = 0; j < insert_stmt->columns_.size(); ++j) { const std::string& col_name = insert_stmt->columns_[j]; bool find_flag = false; @@ -443,6 +456,7 @@ bool SQLClusterRouter::GetMultiRowInsertInfo(const std::string& db, const std::s } column_map.insert(std::make_pair(i, j)); find_flag = true; + break; } } @@ -452,6 +466,24 @@ bool SQLClusterRouter::GetMultiRowInsertInfo(const std::string& db, const std::s return false; } } + + ::hybridse::codec::Schema sc; + if (!schema::SchemaAdapter::ConvertSchema((*table_info)->column_desc(), &sc)) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "failed to convert table schema"); + return false; + } + // TODO(someone): + // 1. default value from table definition + // 2. parameters + ::hybridse::codegen::InsertRowBuilder insert_builder(&sc); + { + auto s = insert_builder.Init(); + if (!s.ok()) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, s.ToString()); + return false; + } + } + size_t total_rows_size = insert_stmt->values_.size(); for (size_t i = 0; i < total_rows_size; i++) { hybridse::node::ExprNode* value = insert_stmt->values_[i]; @@ -462,23 +494,39 @@ bool SQLClusterRouter::GetMultiRowInsertInfo(const std::string& db, const std::s hybridse::node::ExprTypeName(value->GetExprType())); return false; } - uint32_t str_length = 0; - default_maps->push_back( - GetDefaultMap(*table_info, column_map, dynamic_cast<::hybridse::node::ExprListNode*>(value), &str_length)); - if (!default_maps->back()) { - SET_STATUS_AND_WARN(status, StatusCode::kCmdError, - "fail to parse row[" + std::to_string(i) + "]: " + value->GetExprString()); - return false; + if (insert_codegen) { + auto s = insert_builder.ComputeRow(dynamic_cast<::hybridse::node::ExprListNode*>(value)); + if (!s.ok()) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, + absl::Substitute("$2. Fail to encode row[$0]: $1.", i, s.status().ToString(), + value->GetExprString())); + return false; + } + + codegen_rows->push_back(s.value()); + continue; + } else { + uint32_t str_length = 0; + default_maps->push_back(GetDefaultMap(*table_info, column_map, + dynamic_cast<::hybridse::node::ExprListNode*>(value), &str_length)); + if (!default_maps->back()) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, + "fail to parse row[" + std::to_string(i) + "]: " + value->GetExprString()); + return false; + } + str_lengths->push_back(str_length); } - str_lengths->push_back(str_length); } - if (default_maps->empty() || str_lengths->empty()) { - SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "default_maps or str_lengths are empty"); - return false; - } - if (default_maps->size() != str_lengths->size()) { - SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "default maps isn't match with str_lengths"); - return false; + + if (!insert_codegen) { + if (default_maps->empty() || str_lengths->empty()) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "default_maps or str_lengths are empty"); + return false; + } + if (default_maps->size() != str_lengths->size()) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "default maps isn't match with str_lengths"); + return false; + } } return true; } @@ -1216,7 +1264,9 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& s std::vector default_maps; std::vector str_lengths; bool put_if_absent = false; - if (!GetMultiRowInsertInfo(db, sql, status, &table_info, &default_maps, &str_lengths, &put_if_absent)) { + std::vector> codegen_rows; + if (!GetMultiRowInsertInfo(db, sql, status, &table_info, &default_maps, &str_lengths, &put_if_absent, + &codegen_rows)) { CODE_PREPEND_AND_WARN(status, StatusCode::kCmdError, "Fail to get insert info"); return false; } @@ -1229,19 +1279,34 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& s "Fail to execute insert statement: fail to get " + table_info->name() + " tablets"); return false; } + std::vector fails; - for (size_t i = 0; i < default_maps.size(); i++) { - auto row = std::make_shared(table_info, schema, default_maps[i], str_lengths[i], put_if_absent); - if (!row || !row->Init(0) || !row->IsComplete()) { - // TODO(hw): SQLInsertRow or DefaultValueMap needs print helper function - LOG(WARNING) << "fail to build row[" << i << "]"; - fails.push_back(i); - continue; + if (!codegen_rows.empty()) { + for (size_t i = 0 ; i < codegen_rows.size(); ++i) { + auto r = codegen_rows[i]; + auto row = std::make_shared(table_info, schema, r, put_if_absent); + if (!PutRow(table_info->tid(), row, tablets, status)) { + LOG(WARNING) << "fail to put row[" + << "] due to: " << status->msg; + fails.push_back(i); + continue; + } } - if (!PutRow(table_info->tid(), row, tablets, status)) { - LOG(WARNING) << "fail to put row[" << i << "] due to: " << status->msg; - fails.push_back(i); - continue; + } else { + for (size_t i = 0; i < default_maps.size(); i++) { + auto row = + std::make_shared(table_info, schema, default_maps[i], str_lengths[i], put_if_absent); + if (!row || !row->Init(0) || !row->IsComplete()) { + // TODO(hw): SQLInsertRow or DefaultValueMap needs print helper function + LOG(WARNING) << "fail to build row[" << i << "]"; + fails.push_back(i); + continue; + } + if (!PutRow(table_info->tid(), row, tablets, status)) { + LOG(WARNING) << "fail to put row[" << i << "] due to: " << status->msg; + fails.push_back(i); + continue; + } } } if (!fails.empty()) { diff --git a/src/sdk/sql_cluster_router.h b/src/sdk/sql_cluster_router.h index 1226ee4f987..154a53f17d6 100644 --- a/src/sdk/sql_cluster_router.h +++ b/src/sdk/sql_cluster_router.h @@ -323,7 +323,7 @@ class SQLClusterRouter : public SQLRouter { bool GetMultiRowInsertInfo(const std::string& db, const std::string& sql, ::hybridse::sdk::Status* status, std::shared_ptr<::openmldb::nameserver::TableInfo>* table_info, std::vector* default_maps, std::vector* str_lengths, - bool* put_if_absent); + bool* put_if_absent, std::vector>* codegen_rows); DefaultValueMap GetDefaultMap(const std::shared_ptr<::openmldb::nameserver::TableInfo>& table_info, const std::map& column_map, ::hybridse::node::ExprListNode* row, diff --git a/src/sdk/sql_insert_row.cc b/src/sdk/sql_insert_row.cc index 492bb80e49b..2f74d9fc330 100644 --- a/src/sdk/sql_insert_row.cc +++ b/src/sdk/sql_insert_row.cc @@ -90,6 +90,9 @@ SQLInsertRow::SQLInsertRow(std::shared_ptr<::openmldb::nameserver::TableInfo> ta } bool SQLInsertRow::Init(int str_length) { + if (is_codegen_row_) { + return true; + } str_size_ = str_length + default_string_length_; uint32_t row_size = rb_.CalTotalLength(str_size_); val_.resize(row_size); @@ -301,7 +304,12 @@ bool SQLInsertRow::AppendNULL() { return false; } -bool SQLInsertRow::IsComplete() { return rb_.IsComplete(); } +bool SQLInsertRow::IsComplete() { + if (is_codegen_row_) { + return true; + } + return rb_.IsComplete(); +} bool SQLInsertRow::Build() const { return str_size_ == 0; } diff --git a/src/sdk/sql_insert_row.h b/src/sdk/sql_insert_row.h index af18891587f..b6e40de730c 100644 --- a/src/sdk/sql_insert_row.h +++ b/src/sdk/sql_insert_row.h @@ -110,6 +110,38 @@ class SQLInsertRow { SQLInsertRow(std::shared_ptr<::openmldb::nameserver::TableInfo> table_info, std::shared_ptr schema, DefaultValueMap default_map, uint32_t default_str_length, std::vector hole_idx_arr, bool put_if_absent); + SQLInsertRow(std::shared_ptr<::openmldb::nameserver::TableInfo> table_info, + std::shared_ptr schema, std::shared_ptr codegen_row, bool put_if_absent) + : table_info_(table_info), + schema_(schema), + rb_(table_info->column_desc()), + put_if_absent_(put_if_absent), + is_codegen_row_(true) { + auto size = hybridse::codec::RowView::GetSize(codegen_row.get()); + val_ = std::string(reinterpret_cast(codegen_row.get()), size); + std::map column_name_map; + for (int idx = 0; idx < table_info_->column_desc_size(); idx++) { + column_name_map.emplace(table_info_->column_desc(idx).name(), idx); + } + if (table_info_->column_key_size() > 0) { + index_map_.clear(); + raw_dimensions_.clear(); + for (int idx = 0; idx < table_info_->column_key_size(); ++idx) { + const auto& index = table_info_->column_key(idx); + if (index.flag()) { + continue; + } + for (const auto& column : index.col_name()) { + index_map_[idx].push_back(column_name_map[column]); + raw_dimensions_[column_name_map[column]] = hybridse::codec::NONETOKEN; + } + if (!index.ts_name().empty()) { + ts_set_.insert(column_name_map[index.ts_name()]); + } + } + } + } + ~SQLInsertRow() = default; bool Init(int str_length); bool AppendBool(bool val); @@ -181,6 +213,8 @@ class SQLInsertRow { std::string val_; uint32_t str_size_; bool put_if_absent_; + + bool is_codegen_row_ = false; }; class SQLInsertRows {