diff --git a/.gitignore b/.gitignore index 44aaca1..361acd7 100644 --- a/.gitignore +++ b/.gitignore @@ -43,9 +43,8 @@ output/* testdata/ # Files produced by run.sh -kitex_gen/ kitex_gen_slim/ -grpc_gen/ +*_gen/ go.mod go.sum bin diff --git a/CREDITS b/CREDITS index e69de29..56dfc7a 100644 --- a/CREDITS +++ b/CREDITS @@ -0,0 +1,14 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + diff --git a/codegen/.gitignore b/codegen/.gitignore deleted file mode 100644 index f15c7e0..0000000 --- a/codegen/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -errors.txt -kitex_gen/ \ No newline at end of file diff --git a/codegen/codegen_run.sh b/codegen/codegen_run.sh index ca752e0..d843057 100755 --- a/codegen/codegen_run.sh +++ b/codegen/codegen_run.sh @@ -29,6 +29,7 @@ function check_cmd { "kitex -module codegen-test -thrift template=slim $filename" "kitex -module codegen-test -thrift keep_unknown_fields $filename" "kitex -module codegen-test -thrift template=slim -thrift keep_unknown_fields $filename" + "kitex -module codegen-test -thrift with_field_mask -thrift with_reflection -thrift keep_unknown_fields $filename" ) skip_error_info=( @@ -43,7 +44,8 @@ function check_cmd { if ! grep -q -E "$(printf '%s\n' "${skip_error_info[@]}" | paste -sd '|' -)" "$tmp_file"; then echo "$cmd" >> "errors.txt" cat "$tmp_file" >> "errors.txt" # 将错误输出添加到错误文件中 - echo "Error: $cmd" + echo "Kitex Error: $cmd" + exit 1 fi rm "$tmp_file" # 删除临时文件 continue @@ -52,27 +54,49 @@ function check_cmd { rm "$tmp_file" # 删除临时文件 # go mod 不展示输出,会干扰看结果,如果这一步出问题了,下一步 go build 会报错,所以不用担心 + go get github.com/cloudwego/thriftgo@main 2>&1 go mod tidy > /dev/null 2>&1 + # 验证编译 local tmp_file_2=$(mktemp) # 创建一个临时文件来存储输出 if ! eval "go build ./..." > "$tmp_file_2" 2>&1; then echo "$cmd" >> "errors.txt" cat "$tmp_file_2" >> "errors.txt" # 将错误输出添加到错误文件中 - echo "Error: $cmd" + echo "Go Error: $cmd" + exit 2 fi rm "$tmp_file_2" # 删除临时文件 done } +function run_test { + echo "run test..." + + #run fieldmask tests... + cd fieldmask + sh run.sh + if test $? != 0 + then + echo "run fieldmask test failed" + exit 3 + fi + cd .. +} + function main { + mkdir -p testdata + cd testdata + + clean_codegen + if [ -e "errors.txt" ]; then rm errors.txt fi touch errors.txt - basic_file_dir="basic_idls" + basic_file_dir="../basic_idls" basic_files=($(find "$basic_file_dir" -name "*.thrift" -type f -print)) basic_total=${#basic_files[@]} echo "starting test" @@ -92,6 +116,11 @@ function main { cat errors.txt exit 1 fi + + cd .. + + run_test + } main \ No newline at end of file diff --git a/codegen/fieldmask/baseline.thrift b/codegen/fieldmask/baseline.thrift new file mode 100644 index 0000000..7a33de0 --- /dev/null +++ b/codegen/fieldmask/baseline.thrift @@ -0,0 +1,76 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace go baseline + +struct Simple { + 1: byte ByteField + 2: i64 I64Field (api.js_conv = "") + 3: double DoubleField + 4: i32 I32Field + 5: string StringField, + 6: binary BinaryField +} + +struct PartialSimple { + 1: byte ByteField + 3: double DoubleField + 6: binary BinaryField +} + +struct Nesting { + 1: string String (api.header = "String") + 2: list ListSimple + 3: double Double (api.path = "double") + 4: i32 I32 (api.http_code = "", api.body = "I32") + 5: list ListI32 (api.query = "ListI32") + 6: i64 I64 + 7: map MapStringString + 8: Simple SimpleStruct + 9: map MapI32I64 + 10: list ListString + 11: binary Binary + 12: map MapI64String + 13: list ListI64 (api.cookie = "list_i64"), + 14: byte Byte + 15: map MapStringSimple +} + +struct PartialNesting { + 2: list ListSimple + 8: PartialSimple SimpleStruct + 15: map MapStringSimple +} + +struct Nesting2 { + 1: map MapSimpleNesting + 2: Simple SimpleStruct + 3: byte Byte + 4: double Double + 5: list ListNesting + 6: i64 I64 + 7: Nesting NestingStruct + 8: binary Binary + 9: string String + 10: set SetNesting + 11: i32 I32 +} + +service BaselineService { + Simple SimpleMethod(1: Simple req) (api.post = "/simple") + PartialSimple PartialSimpleMethod(1: PartialSimple req) + Nesting NestingMethod(1: Nesting req) (api.post = "/nesting") + PartialNesting PartialNestingMethod(1: PartialNesting req) + Nesting2 Nesting2Method(1: Nesting2 req) (api.post = "/nesting2") +} \ No newline at end of file diff --git a/codegen/fieldmask/baseline_test.go b/codegen/fieldmask/baseline_test.go new file mode 100644 index 0000000..5953c1b --- /dev/null +++ b/codegen/fieldmask/baseline_test.go @@ -0,0 +1,269 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "bytes" + "math" + "strconv" + "strings" + "testing" + + "test/kitex_gen/baseline" + + "github.com/cloudwego/thriftgo/fieldmask" +) + +var ( + bytesCount int = 2 + stringCount int = 2 + listCount int = 16 +) + +func getString() string { + return strings.Repeat("你好,\b\n\r\t世界", stringCount) +} + +func getBytes() []byte { + return bytes.Repeat([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, bytesCount) +} + +func getSimpleValue() *baseline.Simple { + return &baseline.Simple{ + ByteField: math.MaxInt8, + I64Field: math.MaxInt64, + DoubleField: math.MaxFloat64, + I32Field: math.MaxInt32, + StringField: getString(), + BinaryField: getBytes(), + } +} + +func getNestingValue() *baseline.Nesting { + var ret = &baseline.Nesting{ + String_: getString(), + ListSimple: []*baseline.Simple{}, + Double: math.MaxFloat64, + I32: math.MaxInt32, + ListI32: []int32{}, + I64: math.MaxInt64, + MapStringString: map[string]string{}, + SimpleStruct: getSimpleValue(), + MapI32I64: map[int32]int64{}, + ListString: []string{}, + Binary: getBytes(), + MapI64String: map[int64]string{}, + ListI64: []int64{}, + Byte: math.MaxInt8, + MapStringSimple: map[string]*baseline.Simple{}, + } + + for i := 0; i < listCount; i++ { + ret.ListSimple = append(ret.ListSimple, getSimpleValue()) + ret.ListI32 = append(ret.ListI32, math.MinInt32) + ret.ListI64 = append(ret.ListI64, math.MinInt64) + ret.ListString = append(ret.ListString, getString()) + } + + for i := 0; i < listCount; i++ { + ret.MapStringString[strconv.Itoa(i)] = getString() + ret.MapI32I64[int32(i)] = math.MinInt64 + ret.MapI64String[int64(i)] = getString() + ret.MapStringSimple[strconv.Itoa(i)] = getSimpleValue() + } + + return ret +} + +func BenchmarkFastWriteSimple(b *testing.B) { + b.Run("full", func(b *testing.B) { + obj := getSimpleValue() + data := make([]byte, obj.BLength()) + ret := obj.FastWriteNocopy(data, nil) + if ret != len(data) { + b.Fatal(ret) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = obj.BLength() + _ = obj.FastWriteNocopy(data, nil) + } + }) + b.Run("half", func(b *testing.B) { + obj := getSimpleValue() + fm, err := fieldmask.NewFieldMask(obj.GetTypeDescriptor(), "$.ByteField", "$.DoubleField", "$.StringField") + if err != nil { + b.Fatal(err) + } + obj.Set_FieldMask(fm) + data := make([]byte, obj.BLength()) + ret := obj.FastWriteNocopy(data, nil) + if ret != len(data) { + b.Fatal(ret) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + obj.Set_FieldMask(fm) + _ = obj.BLength() + _ = obj.FastWriteNocopy(data, nil) + } + }) +} + +func BenchmarkFastReadSimple(b *testing.B) { + b.Run("full", func(b *testing.B) { + obj := getSimpleValue() + data := make([]byte, obj.BLength()) + ret := obj.FastWriteNocopy(data, nil) + if ret != len(data) { + b.Fatal(ret) + } + obj = baseline.NewSimple() + n, err := obj.FastRead(data) + if n != len(data) { + b.Fatal(err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = obj.FastRead(data) + } + }) + b.Run("half", func(b *testing.B) { + obj := getSimpleValue() + fm, err := fieldmask.NewFieldMask(obj.GetTypeDescriptor(), + "$.ByteField", "$.DoubleField", "$.StringField") + if err != nil { + b.Fatal(err) + } + data := make([]byte, obj.BLength()) + ret := obj.FastWriteNocopy(data, nil) + if ret != len(data) { + b.Fatal(ret) + } + obj = baseline.NewSimple() + obj.Set_FieldMask(fm) + n, err := obj.FastRead(data) + if n != len(data) { + b.Fatal(err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + obj.Set_FieldMask(fm) + _, _ = obj.FastRead(data) + } + }) +} + +func BenchmarkFastWriteNesting(b *testing.B) { + b.Run("full", func(b *testing.B) { + obj := getNestingValue() + data := make([]byte, obj.BLength()) + ret := obj.FastWriteNocopy(data, nil) + if ret != len(data) { + b.Fatal(ret) + } + // println("full data size: ", len(data)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = obj.BLength() + _ = obj.FastWriteNocopy(data, nil) + } + }) + b.Run("half", func(b *testing.B) { + obj := getNestingValue() + // ins := []string{} + // for i := 0; i < listCount/2; i++ { + // ins = append(ins, strconv.Itoa(i)) + // } + // is := strings.Join(ins, ",") + // ss := strings.Join(ins, `","`) + fm, err := fieldmask.NewFieldMask(obj.GetTypeDescriptor(), + // "$.ListSimple["+is+"]", "$.I32", "$.ListI32["+is+"]", `$.MapStringString{"`+ss+`"}`, + // "$.MapI32I64{"+is+"}", "$.Binary", "$.ListI64["+is+"]", `$.MapStringSimple{"`+ss+`"}`, + "$.ListSimple", "$.I32", "$.ListI32", `$.MapStringString`, + "$.MapI32I64", "$.Binary", + ) + if err != nil { + b.Fatal(err) + } + obj.Set_FieldMask(fm) + data := make([]byte, obj.BLength()) + ret := obj.FastWriteNocopy(data, nil) + if ret != len(data) { + b.Fatal(ret) + } + // println("half data size: ", len(data)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + obj.Set_FieldMask(fm) + _ = obj.BLength() + _ = obj.FastWriteNocopy(data, nil) + } + }) +} + +func BenchmarkFastReadNesting(b *testing.B) { + b.Run("full", func(b *testing.B) { + obj := getNestingValue() + data := make([]byte, obj.BLength()) + ret := obj.FastWriteNocopy(data, nil) + if ret != len(data) { + b.Fatal(ret) + } + obj = baseline.NewNesting() + n, err := obj.FastRead(data) + if n != len(data) { + b.Fatal(err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = obj.FastRead(data) + } + }) + b.Run("half", func(b *testing.B) { + obj := getNestingValue() + // ins := []string{} + // for i := 0; i < listCount/2; i++ { + // ins = append(ins, strconv.Itoa(i)) + // } + // is := strings.Join(ins, ",") + // ss := strings.Join(ins, `","`) + fm, err := fieldmask.NewFieldMask(obj.GetTypeDescriptor(), + // "$.ListSimple["+is+"]", "$.I32", "$.ListI32["+is+"]", `$.MapStringString{"`+ss+`"}`, + // "$.MapI32I64{"+is+"}", "$.Binary", "$.ListI64["+is+"]", `$.MapStringSimple{"`+ss+`"}`, + "$.ListSimple", "$.I32", "$.ListI32", `$.MapStringString`, + "$.MapI32I64", "$.Binary", + ) + if err != nil { + b.Fatal(err) + } + data := make([]byte, obj.BLength()) + ret := obj.FastWriteNocopy(data, nil) + if ret != len(data) { + b.Fatal(ret) + } + obj = baseline.NewNesting() + obj.Set_FieldMask(fm) + n, err := obj.FastRead(data) + if n != len(data) { + b.Fatal(err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + obj.Set_FieldMask(fm) + _, _ = obj.FastRead(data) + } + }) +} diff --git a/codegen/fieldmask/main_test.go b/codegen/fieldmask/main_test.go new file mode 100644 index 0000000..4648d1f --- /dev/null +++ b/codegen/fieldmask/main_test.go @@ -0,0 +1,553 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "runtime" + "sync" + hbase "test/halfway_gen/base" + nbase "test/kitex_gen/base" + obase "test/old_gen/base" + zbase "test/zero_gen/base" + "testing" + + "github.com/cloudwego/thriftgo/fieldmask" + "github.com/stretchr/testify/require" +) + +var fieldmaskCache sync.Map + +func init() { + // new a obj to get its TypeDescriptor + obj := nbase.NewBase() + + // construct a fieldmask with TypeDescriptor and thrift paths + fm, err := fieldmask.NewFieldMask(obj.GetTypeDescriptor(), + "$.Enum", "$.EnumMap{1}", "$.LogID", "$.TrafficEnv.Name", "$.TrafficEnv.Code", "$.Meta.IntMap{1}", "$.Meta.StrMap{\"1234\"}", + "$.Meta.List[1]", "$.Meta.Set[0].id", "$.Meta.Set[1].name") + if err != nil { + panic(err) + } + + // cache it for future usage of nbase.Base + fieldmaskCache.Store("Mask1ForBase", fm) + + fm2, err := fieldmask.Options{BlackListMode: true}.NewFieldMask(obj.GetTypeDescriptor(), + "$.Enum", "$.EnumMap{1}", "$.LogID", "$.TrafficEnv.Name", "$.TrafficEnv.Code", "$.Meta.IntMap{1}", "$.Meta.StrMap{\"1234\"}", + "$.Meta.List[1]", "$.Meta.Set[0].id", "$.Meta.Set[1].name") + if err != nil { + panic(err) + } + fieldmaskCache.Store("Mask1ForBase-Black", fm2) +} + +func SampleNewBase() *nbase.Base { + obj := nbase.NewBase() + obj.Addr = "abcd" + obj.Caller = "abcd" + obj.LogID = "abcd" + obj.Enum = nbase.Ex_A + obj.EnumMap = map[nbase.Ex]string{ + nbase.Ex_A: "a", + nbase.Ex_B: "b", + } + obj.Meta = nbase.NewMetaInfo() + obj.Meta.StrMap = map[string]*nbase.Val{ + "abcd": nbase.NewVal(), + "1234": nbase.NewVal(), + } + obj.Meta.IntMap = map[int64]*nbase.Val{ + 1: nbase.NewVal(), + 2: nbase.NewVal(), + } + v0 := nbase.NewVal() + v0.Id = "a" + v0.Name = "a" + v1 := nbase.NewVal() + v1.Id = "b" + v1.Name = "b" + obj.Meta.List = []*nbase.Val{v0, v1} + // v0 = nbase.NewVal() + // v0.ID = "a" + // v0.Name = "a" + // v1 = nbase.NewVal() + // v1.ID = "b" + // v1.Name = "b" + obj.Meta.Set = []*nbase.Val{v0, v1} + // obj.Extra = nbase.NewExtraInfo() + obj.TrafficEnv = nbase.NewTrafficEnv() + obj.TrafficEnv.Code = 1 + obj.TrafficEnv.Env = "abcd" + obj.TrafficEnv.Name = "abcd" + obj.TrafficEnv.Open = true + obj.Meta.Base = nbase.NewBase() + return obj +} + +func SampleOldBase() *obase.Base { + obj := obase.NewBase() + obj.Addr = "abcd" + obj.Caller = "abcd" + obj.LogID = "abcd" + obj.Meta = obase.NewMetaInfo() + obj.Meta.StrMap = map[string]*obase.Val{ + "abcd": obase.NewVal(), + "1234": obase.NewVal(), + } + obj.Meta.IntMap = map[int64]*obase.Val{ + 1: obase.NewVal(), + 2: obase.NewVal(), + } + v0 := obase.NewVal() + v0.Id = "a" + v0.Name = "a" + v1 := obase.NewVal() + v1.Id = "b" + v1.Name = "b" + obj.Meta.List = []*obase.Val{v0, v1} + obj.Meta.Set = []*obase.Val{v0, v1} + // obj.Extra = obase.NewExtraInfo() + obj.TrafficEnv = obase.NewTrafficEnv() + obj.TrafficEnv.Code = 1 + obj.TrafficEnv.Env = "abcd" + obj.TrafficEnv.Name = "abcd" + obj.TrafficEnv.Open = true + obj.Meta.Base = obase.NewBase() + return obj +} + +func TestFastWrite_WhiteList(t *testing.T) { + var obj = SampleNewBase() + // Load fieldmask from cache + fm, _ := fieldmaskCache.Load("Mask1ForBase") + if fm != nil { + // load ok, set fieldmask onto the object using codegen API + obj.Set_FieldMask(fm.(*fieldmask.FieldMask)) + } + out := make([]byte, obj.BLength()) + // out := make([]byte, 24000000) + e := obj.FastWriteNocopy(out, nil) + var obj2 = nbase.NewBase() + n, err := obj2.FastRead(out) + if err != nil { + t.Fatal(err) + } + require.Equal(t, e, n) + + require.Equal(t, obj.Addr, obj2.Addr) + require.Equal(t, obj.Enum, obj2.Enum) + require.Equal(t, map[nbase.Ex]string{nbase.Ex_A: "a"}, obj2.EnumMap) + require.Equal(t, obj.LogID, obj2.LogID) + require.Equal(t, "", obj2.Caller) + require.Equal(t, obj.TrafficEnv.Name, obj2.TrafficEnv.Name) + require.Equal(t, false, obj2.TrafficEnv.Open) + require.Equal(t, "", obj2.TrafficEnv.Env) + require.Equal(t, obj.TrafficEnv.Code, obj2.TrafficEnv.Code) + require.Equal(t, obj.Meta.IntMap[1].Id, obj2.Meta.IntMap[1].Id) + require.Equal(t, (*nbase.Val)(nil), obj2.Meta.IntMap[0]) + require.Equal(t, obj.Meta.StrMap["1234"].Id, obj2.Meta.StrMap["1234"].Id) + require.Equal(t, (*nbase.Val)(nil), obj2.Meta.StrMap["abcd"]) + require.Equal(t, 1, len(obj2.Meta.List)) + require.Equal(t, "b", obj2.Meta.List[0].Id) + require.Equal(t, "b", obj2.Meta.List[0].Name) + require.Equal(t, 2, len(obj2.Meta.Set)) + require.Equal(t, "a", obj2.Meta.Set[0].Id) + require.Equal(t, "", obj2.Meta.Set[0].Name) + require.Equal(t, "", obj2.Meta.Set[1].Id) + require.Equal(t, "b", obj2.Meta.Set[1].Name) +} + +func TestFastWrite_BlackList(t *testing.T) { + var obj = SampleNewBase() + // Load fieldmask from cache + fm, _ := fieldmaskCache.Load("Mask1ForBase-Black") + if fm != nil { + // load ok, set fieldmask onto the object using codegen API + obj.Set_FieldMask(fm.(*fieldmask.FieldMask)) + } + out := make([]byte, obj.BLength()) + // out := make([]byte, 24000000) + e := obj.FastWriteNocopy(out, nil) + var obj2 = nbase.NewBase() + n, err := obj2.FastRead(out) + if err != nil { + t.Fatal(err) + } + require.Equal(t, e, n) + + require.Equal(t, obj.Addr, obj2.Addr) + require.Equal(t, nbase.Ex(0), obj2.Enum) + require.Equal(t, map[nbase.Ex]string{nbase.Ex_B: "b"}, obj2.EnumMap) + require.Equal(t, "", obj2.LogID) + require.Equal(t, obj.Caller, obj2.Caller) + require.Equal(t, "", obj2.TrafficEnv.Name) + require.Equal(t, obj.TrafficEnv.Open, obj2.TrafficEnv.Open) + require.Equal(t, obj.TrafficEnv.Env, obj2.TrafficEnv.Env) + require.Equal(t, obj.TrafficEnv.Code, obj2.TrafficEnv.Code) // required + require.Equal(t, (*nbase.Val)(nil), obj2.Meta.IntMap[1]) + require.Equal(t, obj.Meta.IntMap[0], obj2.Meta.IntMap[0]) + require.Equal(t, (*nbase.Val)(nil), obj2.Meta.StrMap["1234"]) + require.Equal(t, obj.Meta.StrMap["abcd"], obj2.Meta.StrMap["abcd"]) + require.Equal(t, 1, len(obj2.Meta.List)) + require.Equal(t, "a", obj2.Meta.List[0].Id) + require.Equal(t, "a", obj2.Meta.List[0].Name) + require.Equal(t, 2, len(obj2.Meta.Set)) + require.Equal(t, "", obj2.Meta.Set[0].Id) + require.Equal(t, "a", obj2.Meta.Set[0].Name) + require.Equal(t, "b", obj2.Meta.Set[1].Id) + require.Equal(t, "", obj2.Meta.Set[1].Name) +} + +func TestFastRead_WhiteList(t *testing.T) { + obj := SampleNewBase() + buf := make([]byte, obj.BLength()) + e := obj.FastWriteNocopy(buf, nil) + obj2 := nbase.NewBase() + fm, _ := fieldmaskCache.Load("Mask1ForBase") + if fm != nil { + obj2.Set_FieldMask(fm.(*fieldmask.FieldMask)) + } + + n, err := obj2.FastRead(buf) + if err != nil { + t.Fatal(err) + } + require.Equal(t, e, n) + + require.Equal(t, "", obj2.Addr) + require.Equal(t, obj.Enum, obj2.Enum) + require.Equal(t, map[nbase.Ex]string{nbase.Ex_A: "a"}, obj2.EnumMap) + require.Equal(t, obj.LogID, obj2.LogID) + require.Equal(t, "", obj2.Caller) + require.Equal(t, obj.TrafficEnv.Name, obj2.TrafficEnv.Name) + require.Equal(t, false, obj2.TrafficEnv.Open) + require.Equal(t, "", obj2.TrafficEnv.Env) + require.Equal(t, obj.TrafficEnv.Code, obj2.TrafficEnv.Code) + require.Equal(t, obj.Meta.IntMap[1].Id, obj2.Meta.IntMap[1].Id) + require.Equal(t, (*nbase.Val)(nil), obj2.Meta.IntMap[0]) + require.Equal(t, obj.Meta.StrMap["1234"].Id, obj2.Meta.StrMap["1234"].Id) + require.Equal(t, (*nbase.Val)(nil), obj2.Meta.StrMap["abcd"]) + require.Equal(t, 1, len(obj2.Meta.List)) + require.Equal(t, "b", obj2.Meta.List[0].Id) + require.Equal(t, "b", obj2.Meta.List[0].Name) + require.Equal(t, 2, len(obj2.Meta.Set)) + require.Equal(t, "a", obj2.Meta.Set[0].Id) + require.Equal(t, "", obj2.Meta.Set[0].Name) + require.Equal(t, "", obj2.Meta.Set[1].Id) + require.Equal(t, "b", obj2.Meta.Set[1].Name) +} + +func TestFastRead_BlackList(t *testing.T) { + obj := SampleNewBase() + buf := make([]byte, obj.BLength()) + e := obj.FastWriteNocopy(buf, nil) + obj2 := nbase.NewBase() + fm, _ := fieldmaskCache.Load("Mask1ForBase-Black") + if fm != nil { + obj2.Set_FieldMask(fm.(*fieldmask.FieldMask)) + } + + n, err := obj2.FastRead(buf) + if err != nil { + t.Fatal(err) + } + require.Equal(t, e, n) + + require.Equal(t, obj.Addr, obj2.Addr) + require.Equal(t, nbase.Ex(0), obj2.Enum) + require.Equal(t, map[nbase.Ex]string{nbase.Ex_B: "b"}, obj2.EnumMap) + require.Equal(t, "", obj2.LogID) + require.Equal(t, obj.Caller, obj2.Caller) + require.Equal(t, "", obj2.TrafficEnv.Name) + require.Equal(t, obj.TrafficEnv.Open, obj2.TrafficEnv.Open) + require.Equal(t, obj.TrafficEnv.Env, obj2.TrafficEnv.Env) + require.Equal(t, int64(0), obj2.TrafficEnv.Code) + require.Equal(t, (*nbase.Val)(nil), obj2.Meta.IntMap[1]) + require.Equal(t, obj.Meta.IntMap[0], obj2.Meta.IntMap[0]) + require.Equal(t, (*nbase.Val)(nil), obj2.Meta.StrMap["1234"]) + require.Equal(t, obj.Meta.StrMap["abcd"], obj2.Meta.StrMap["abcd"]) + require.Equal(t, 1, len(obj2.Meta.List)) + require.Equal(t, "a", obj2.Meta.List[0].Id) + require.Equal(t, "a", obj2.Meta.List[0].Name) + require.Equal(t, 2, len(obj2.Meta.Set)) + require.Equal(t, "", obj2.Meta.Set[0].Id) + require.Equal(t, "a", obj2.Meta.Set[0].Name) + require.Equal(t, "b", obj2.Meta.Set[1].Id) + require.Equal(t, "", obj2.Meta.Set[1].Name) +} + +func TestMaskRequired(t *testing.T) { + fm, err := fieldmask.NewFieldMask(nbase.NewBaseResp().GetTypeDescriptor(), "$.F1", "$.F8", "$.R12.Env") + if err != nil { + t.Fatal(err) + } + j, err := fm.MarshalJSON() + if err != nil { + t.Fatal(err) + } + println(string(j)) + nf, ex := fm.Field(111) + if !ex { + t.Fatal(nf) + } + + t.Run("read", func(t *testing.T) { + obj := nbase.NewBaseResp() + obj.F1 = map[nbase.Str]nbase.Str{"a": "b"} + obj.F8 = map[float64][]nbase.Str{1.0: []nbase.Str{"a"}} + obj.R12 = nbase.NewTrafficEnv() + obj.R12.Name = "a" + obj.R12.Env = "a" + buf := make([]byte, obj.BLength()) + if err := obj.FastWriteNocopy(buf, nil); err != len(buf) { + t.Fatal(err) + } + obj2 := nbase.NewBaseResp() + obj2.Set_FieldMask(fm) + if _, err := obj2.FastRead(buf); err != nil { + t.Fatal(err) + } + require.Equal(t, obj.F1, obj2.F1) + require.Equal(t, obj.F8, obj2.F8) + require.Equal(t, "", obj2.R12.Name) + require.Equal(t, obj.R12.Env, obj2.R12.Env) + }) + + t.Run("write current", func(t *testing.T) { + obj := nbase.NewBaseResp() + obj.StatusCode = 1 + obj.R3 = true + obj.R4 = 1 + obj.R5 = 1 + obj.R6 = 1 + obj.R7 = 1 + obj.R8 = "R8" + obj.R9 = nbase.Ex_B + v := nbase.NewVal() + v.Id = "a" + obj.R10 = []*nbase.Val{v} + obj.R11 = []*nbase.Val{v} + obj.R12 = nbase.NewTrafficEnv() + obj.R12.Name = "a" + obj.R12.Env = "a" + obj.R13 = map[string]*nbase.Key{"a": v} + obj.F1 = map[nbase.Str]nbase.Str{"a": "b"} + obj.F8 = map[float64][]nbase.Str{1.0: []nbase.Str{"a"}} + obj.Set_FieldMask(fm) + buf := make([]byte, obj.BLength()) + if err := obj.FastWriteNocopy(buf, nil); err != len(buf) { + t.Fatal(err) + } + obj2 := nbase.NewBaseResp() + if _, err := obj2.FastRead(buf); err != nil { + t.Fatal(err) + } + + require.Equal(t, obj.F1, obj2.F1) + require.Equal(t, obj.F8, obj2.F8) + require.Equal(t, obj.StatusCode, obj2.StatusCode) + require.Equal(t, obj.R3, obj2.R3) + require.Equal(t, obj.R4, obj2.R4) + require.Equal(t, obj.R5, obj2.R5) + require.Equal(t, obj.R6, obj2.R6) + require.Equal(t, obj.R7, obj2.R7) + require.Equal(t, obj.R8, obj2.R8) + require.Equal(t, obj.R9, obj2.R9) + require.Equal(t, obj.R10, obj2.R10) + require.Equal(t, obj.R11, obj2.R11) + require.Equal(t, "", obj2.R12.Name) + require.Equal(t, obj.R12.Env, obj2.R12.Env) + require.Equal(t, obj.R13, obj2.R13) + }) + + t.Run("write zero", func(t *testing.T) { + fm, err := fieldmask.NewFieldMask(nbase.NewBaseResp().GetTypeDescriptor(), "$.F1", "$.F8", "$.R12") + if err != nil { + t.Fatal(err) + } + obj := zbase.NewBaseResp() + obj.F1 = map[zbase.Str]zbase.Str{"a": "b"} + obj.F8 = map[float64][]zbase.Str{1.0: []zbase.Str{"a"}} + obj.StatusCode = 1 + obj.R3 = true + obj.R4 = 1 + obj.R5 = 1 + obj.R6 = 1 + obj.R7 = 1 + obj.R8 = "R8" + obj.R9 = zbase.Ex_B + v := zbase.NewVal() + v.Id = "a" + obj.R10 = []*zbase.Val{v} + obj.R11 = []*zbase.Val{v} + obj.R12 = zbase.NewTrafficEnv() + obj.R12.Name = "a" + obj.R12.Env = "a" + obj.R13 = map[string]*zbase.Key{"a": v} + + obj.Set_FieldMask(fm) + buf := make([]byte, obj.BLength()) + if err := obj.FastWriteNocopy(buf, nil); err != len(buf) { + t.Fatal(err) + } + obj2 := zbase.NewBaseResp() + if _, err := obj2.FastRead(buf); err != nil { + t.Fatal(err) + } + + require.Equal(t, obj.F1, obj2.F1) + require.Equal(t, obj.F8, obj2.F8) + require.Equal(t, int32(0), obj2.StatusCode) + require.Equal(t, false, obj2.R3) + require.Equal(t, int8(0), obj2.R4) + require.Equal(t, int16(0), obj2.R5) + require.Equal(t, int64(0), obj2.R6) + require.Equal(t, float64(0), obj2.R7) + require.Equal(t, "", obj2.R8) + require.Equal(t, zbase.Ex(0), obj2.R9) + require.Equal(t, []*zbase.Val{}, obj2.R10) + require.Equal(t, []*zbase.Val{}, obj2.R11) + obj.R12.Set_FieldMask(nil) + require.Equal(t, obj.R12, obj2.R12) + require.Equal(t, map[string]*zbase.Key{}, obj2.R13) + }) +} + +func TestSetMaskHalfway(t *testing.T) { + obj := hbase.NewBase() + obj.Extra = hbase.NewExtraInfo() + obj.Extra.F1 = map[string]string{"a": "b"} + obj.Extra.F8 = map[int64][]*hbase.Key{1: []*hbase.Key{hbase.NewKey()}} + + fm, err := fieldmask.NewFieldMask(obj.Extra.GetTypeDescriptor(), "$.F1", "$.F8") + if err != nil { + t.Fatal(err) + } + obj.Extra.Set_FieldMask(fm) + buf := make([]byte, obj.BLength()) + if err := obj.FastWriteNocopy(buf, nil); err != len(buf) { + t.Fatal(err) + } + obj2 := hbase.NewBase() + if _, err := obj2.FastRead(buf); err != nil { + t.Fatal(err) + } + require.Equal(t, obj.Extra.F1, obj2.Extra.F1) + require.Equal(t, obj.Extra.F8, obj2.Extra.F8) +} + +func BenchmarkFastWriteWithFieldMask(b *testing.B) { + b.Run("old", func(b *testing.B) { + obj := SampleOldBase() + buf := make([]byte, obj.BLength()) + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf := buf[:obj.BLength()] + if err := obj.FastWriteNocopy(buf, nil); err == 0 { + b.Fatal(err) + } + } + }) + + runtime.GC() + + b.Run("new", func(b *testing.B) { + obj := SampleNewBase() + buf := make([]byte, obj.BLength()) + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf := buf[:obj.BLength()] + if err := obj.FastWriteNocopy(buf, nil); err == 0 { + b.Fatal(err) + } + } + }) + + runtime.GC() + + b.Run("new-mask-half", func(b *testing.B) { + obj := SampleNewBase() + buf := make([]byte, obj.BLength()) + fm, err := fieldmask.NewFieldMask(obj.GetTypeDescriptor(), "$.Addr", "$.LogID", "$.TrafficEnv.Code", "$.Meta.IntMap", "$.Meta.List") + if err != nil { + b.Fatal(err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + obj.Set_FieldMask(fm) + buf := buf[:obj.BLength()] + if err := obj.FastWriteNocopy(buf, nil); err == 0 { + b.Fatal(err) + } + } + }) +} + +func BenchmarkFastReadWithFieldMask(b *testing.B) { + b.Run("old", func(b *testing.B) { + obj := SampleOldBase() + buf := make([]byte, obj.BLength()) + if err := obj.FastWriteNocopy(buf, nil); err == 0 { + b.Fatal(err) + } + obj = obase.NewBase() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := obj.FastRead(buf); err != nil { + b.Fatal(err) + } + } + }) + + runtime.GC() + + b.Run("new", func(b *testing.B) { + obj := SampleNewBase() + buf := make([]byte, obj.BLength()) + if err := obj.FastWriteNocopy(buf, nil); err == 0 { + b.Fatal(err) + } + obj = nbase.NewBase() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := obj.FastRead(buf); err != nil { + b.Fatal(err) + } + } + }) + + runtime.GC() + + b.Run("new-mask-half", func(b *testing.B) { + obj := SampleNewBase() + buf := make([]byte, obj.BLength()) + if err := obj.FastWriteNocopy(buf, nil); err == 0 { + b.Fatal(err) + } + obj = nbase.NewBase() + + fm, err := fieldmask.NewFieldMask(obj.GetTypeDescriptor(), "$.Addr", "$.LogID", "$.TrafficEnv.Code", "$.Meta.IntMap", "$.Meta.List") + if err != nil { + b.Fatal(err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + obj.Set_FieldMask(fm) + if _, err := obj.FastRead(buf); err != nil { + b.Fatal(err) + } + } + }) +} diff --git a/codegen/fieldmask/nesting.prof b/codegen/fieldmask/nesting.prof new file mode 100644 index 0000000..0ad3ffe Binary files /dev/null and b/codegen/fieldmask/nesting.prof differ diff --git a/codegen/fieldmask/run.sh b/codegen/fieldmask/run.sh new file mode 100644 index 0000000..be03718 --- /dev/null +++ b/codegen/fieldmask/run.sh @@ -0,0 +1,25 @@ +# Copyright 2023 CloudWeGo Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/bin/bash + +kitex -module test -gen-path old_gen test_fieldmask.thrift +kitex -module test -thrift with_field_mask -thrift with_reflection test_fieldmask.thrift +cat test_fieldmask.thrift > test_fieldmask2.thrift +cat test_fieldmask.thrift > test_fieldmask3.thrift +kitex -module test -thrift with_field_mask -thrift field_mask_zero_required -thrift with_reflection -gen-path zero_gen test_fieldmask2.thrift +kitex -module test -thrift with_field_mask -thrift field_mask_halfway -thrift with_reflection -gen-path halfway_gen test_fieldmask3.thrift +kitex -module test -thrift with_field_mask -thrift with_reflection baseline.thrift +go mod tidy +go test ./... diff --git a/codegen/fieldmask/test_fieldmask.thrift b/codegen/fieldmask/test_fieldmask.thrift new file mode 100644 index 0000000..b085e00 --- /dev/null +++ b/codegen/fieldmask/test_fieldmask.thrift @@ -0,0 +1,106 @@ +#! /bin/bash -e + +# Copyright 2022 CloudWeGo Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +namespace go base + +struct TrafficEnv { + 0: string Name = "", + 1: bool Open = false, + 2: string Env = "", + 256: required i64 Code, +} + +struct Base { + 0: required string Addr = "", + 1: string LogID = "", + 2: string Caller = "", + 5: optional TrafficEnv TrafficEnv, + 9: Ex Enum, + 10: map EnumMap, + 255: optional ExtraInfo Extra, + 256: MetaInfo Meta, +} + +struct ExtraInfo { + 1: map F1 + 2: map F2, + 3: list F3 + 4: set F4, + 5: map F5 + 6: map F6 + 7: map> F7 + 8: map> F8 + 9: map>> F9 + 10: map F10 +} + +struct MetaInfo { + 1: map IntMap, + 2: map StrMap, + 3: list List, + 4: set Set, + 11: map> MapList + 12: list>> ListMapList + 255: Base Base, +} + +typedef Val Key + +struct Val { + 1: string id + 2: string name +} + +typedef double Float + +typedef i64 Int + +typedef string Str + +enum Ex { + A = 1, + B = 2, + C = 3 +} + +struct BaseResp { + 1: optional string StatusMessage = "", + 2: required i32 StatusCode = 0, + 3: required bool R3, + 4: required byte R4, + 5: required i16 R5, + 6: required i64 R6, + 7: required double R7, + 8: required string R8, + 9: required Ex R9, + 10: required list R10, + 11: required set R11, + 12: required TrafficEnv R12, + 13: required map R13, + 0: required Key R0, + + 14: map F1 + 15: map F2, + 16: list F3 + 17: set F4, + 18: map F5 + 19: map F6 + 110: map F7 + 111: map> F8 + 112: list>> F9 + 113: map F10 +} + diff --git a/codegen/fieldmask/test_fieldmask2.thrift b/codegen/fieldmask/test_fieldmask2.thrift new file mode 100644 index 0000000..b085e00 --- /dev/null +++ b/codegen/fieldmask/test_fieldmask2.thrift @@ -0,0 +1,106 @@ +#! /bin/bash -e + +# Copyright 2022 CloudWeGo Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +namespace go base + +struct TrafficEnv { + 0: string Name = "", + 1: bool Open = false, + 2: string Env = "", + 256: required i64 Code, +} + +struct Base { + 0: required string Addr = "", + 1: string LogID = "", + 2: string Caller = "", + 5: optional TrafficEnv TrafficEnv, + 9: Ex Enum, + 10: map EnumMap, + 255: optional ExtraInfo Extra, + 256: MetaInfo Meta, +} + +struct ExtraInfo { + 1: map F1 + 2: map F2, + 3: list F3 + 4: set F4, + 5: map F5 + 6: map F6 + 7: map> F7 + 8: map> F8 + 9: map>> F9 + 10: map F10 +} + +struct MetaInfo { + 1: map IntMap, + 2: map StrMap, + 3: list List, + 4: set Set, + 11: map> MapList + 12: list>> ListMapList + 255: Base Base, +} + +typedef Val Key + +struct Val { + 1: string id + 2: string name +} + +typedef double Float + +typedef i64 Int + +typedef string Str + +enum Ex { + A = 1, + B = 2, + C = 3 +} + +struct BaseResp { + 1: optional string StatusMessage = "", + 2: required i32 StatusCode = 0, + 3: required bool R3, + 4: required byte R4, + 5: required i16 R5, + 6: required i64 R6, + 7: required double R7, + 8: required string R8, + 9: required Ex R9, + 10: required list R10, + 11: required set R11, + 12: required TrafficEnv R12, + 13: required map R13, + 0: required Key R0, + + 14: map F1 + 15: map F2, + 16: list F3 + 17: set F4, + 18: map F5 + 19: map F6 + 110: map F7 + 111: map> F8 + 112: list>> F9 + 113: map F10 +} + diff --git a/codegen/fieldmask/test_fieldmask3.thrift b/codegen/fieldmask/test_fieldmask3.thrift new file mode 100644 index 0000000..b085e00 --- /dev/null +++ b/codegen/fieldmask/test_fieldmask3.thrift @@ -0,0 +1,106 @@ +#! /bin/bash -e + +# Copyright 2022 CloudWeGo Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +namespace go base + +struct TrafficEnv { + 0: string Name = "", + 1: bool Open = false, + 2: string Env = "", + 256: required i64 Code, +} + +struct Base { + 0: required string Addr = "", + 1: string LogID = "", + 2: string Caller = "", + 5: optional TrafficEnv TrafficEnv, + 9: Ex Enum, + 10: map EnumMap, + 255: optional ExtraInfo Extra, + 256: MetaInfo Meta, +} + +struct ExtraInfo { + 1: map F1 + 2: map F2, + 3: list F3 + 4: set F4, + 5: map F5 + 6: map F6 + 7: map> F7 + 8: map> F8 + 9: map>> F9 + 10: map F10 +} + +struct MetaInfo { + 1: map IntMap, + 2: map StrMap, + 3: list List, + 4: set Set, + 11: map> MapList + 12: list>> ListMapList + 255: Base Base, +} + +typedef Val Key + +struct Val { + 1: string id + 2: string name +} + +typedef double Float + +typedef i64 Int + +typedef string Str + +enum Ex { + A = 1, + B = 2, + C = 3 +} + +struct BaseResp { + 1: optional string StatusMessage = "", + 2: required i32 StatusCode = 0, + 3: required bool R3, + 4: required byte R4, + 5: required i16 R5, + 6: required i64 R6, + 7: required double R7, + 8: required string R8, + 9: required Ex R9, + 10: required list R10, + 11: required set R11, + 12: required TrafficEnv R12, + 13: required map R13, + 0: required Key R0, + + 14: map F1 + 15: map F2, + 16: list F3 + 17: set F4, + 18: map F5 + 19: map F6 + 110: map F7 + 111: map> F8 + 112: list>> F9 + 113: map F10 +} + diff --git a/idl/fieldmask.thrift b/idl/fieldmask.thrift new file mode 100644 index 0000000..2bcf48c --- /dev/null +++ b/idl/fieldmask.thrift @@ -0,0 +1,31 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace go fieldmask + +struct BizRequest { + 1: string A + 2: required string B + 3: optional binary RespMask +} + +struct BizResponse { + 1: string A + 2: required string B + 3: string C +} + +service BizService { + BizResponse BizMethod1(1: BizRequest req) +} \ No newline at end of file diff --git a/run.sh b/run.sh index ce46ca1..919da1b 100755 --- a/run.sh +++ b/run.sh @@ -67,9 +67,6 @@ which protoc-gen-go || go_install google.golang.org/protobuf/cmd/protoc-gen-go@l # install protoc-gen-go and protoc-gen-go-kitexgrpc which protoc-gen-go-grpc || go_install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest -# Install thriftgo -which thriftgo || go_install github.com/cloudwego/thriftgo@latest - # Install kitex and generate codes LOCAL_REPO=$1 @@ -78,10 +75,22 @@ if [[ -n $LOCAL_REPO ]]; then go_install ${LOCAL_REPO}/tool/cmd/kitex cd - else - go_install github.com/cloudwego/kitex/tool/cmd/kitex@latest + which kitex || go_install github.com/cloudwego/kitex/tool/cmd/kitex@latest +fi + +# Install thriftgo +THRIFTGO_REPO=$2 + +if [[ -n $THRIFTGO_REPO ]]; then + cd ${THRIFTGO_REPO} + go_install ${THRIFTGO_REPO} + cd - +else + which thriftgo || go_install github.com/cloudwego/thriftgo@latest fi test -d kitex_gen && rm -rf kitex_gen +kitex -module kitex -module github.com/cloudwego/kitex-tests -thrift with_field_mask -thrift with_reflection ./idl/fieldmask.thrift kitex -module github.com/cloudwego/kitex-tests ./idl/stability.thrift kitex -module github.com/cloudwego/kitex-tests ./idl/http.thrift kitex -module github.com/cloudwego/kitex-tests ./idl/tenant.thrift @@ -104,7 +113,6 @@ mkdir grpc_gen protoc --go_out=grpc_gen/. ./idl/grpc_demo_2.proto protoc --go-grpc_out=grpc_gen/. ./idl/grpc_demo_2.proto - # Init dependencies go get github.com/apache/thrift@v0.13.0 go get github.com/cloudwego/kitex@develop @@ -113,6 +121,10 @@ if [[ -n $LOCAL_REPO ]]; then go mod edit -replace github.com/cloudwego/kitex=${LOCAL_REPO} fi +if [[ -n $THRIFTGO_REPO ]]; then + go mod edit -replace github.com/cloudwego/thriftgo=${THRIFTGO_REPO} +fi + go mod tidy # static check @@ -122,6 +134,7 @@ go vet -stdmethods=false $(go list ./...) # run tests packages=( +./thriftrpc/fieldmask/... ./thriftrpc/normalcall/... ./thriftrpc/muxcall/... ./thriftrpc/retrycall/... diff --git a/thriftrpc/fieldmask/handler.go b/thriftrpc/fieldmask/handler.go new file mode 100644 index 0000000..62c3c79 --- /dev/null +++ b/thriftrpc/fieldmask/handler.go @@ -0,0 +1,55 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fieldmask + +import ( + "context" + "errors" + + fieldmask0 "github.com/cloudwego/kitex-tests/kitex_gen/fieldmask" + "github.com/cloudwego/thriftgo/fieldmask" +) + +// BizServiceImpl implements the last service interface defined in the IDL. +type BizServiceImpl struct{} + +// BizMethod1 implements the BizServiceImpl interface. +func (s *BizServiceImpl) BizMethod1(ctx context.Context, req *fieldmask0.BizRequest) (resp *fieldmask0.BizResponse, err error) { + // check if request has been masked + if req.A != "" { // req.A must be filtered + return nil, errors.New("request must filter BizRequest.A!") + } + if req.B == "" { // req.B must not be filtered + return nil, errors.New("request must not filter BizRequest.B!") + } + + resp = fieldmask0.NewBizResponse() + + // check if request carries a fieldmask + if req.RespMask != nil { + println("got fieldmask", string(req.RespMask)) + fm, err := fieldmask.Unmarshal(req.RespMask) + if err != nil { + return nil, err + } + // set fieldmask for response + resp.Set_FieldMask(fm) + } + + resp.A = "A" + resp.B = "B" + resp.C = "C" + return +} diff --git a/thriftrpc/fieldmask/main_test.go b/thriftrpc/fieldmask/main_test.go new file mode 100644 index 0000000..4eadd2a --- /dev/null +++ b/thriftrpc/fieldmask/main_test.go @@ -0,0 +1,163 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fieldmask + +import ( + "context" + "fmt" + "net" + "sync" + "testing" + "time" + + fieldmask0 "github.com/cloudwego/kitex-tests/kitex_gen/fieldmask" + "github.com/cloudwego/kitex-tests/kitex_gen/fieldmask/bizservice" + "github.com/cloudwego/kitex/client" + "github.com/cloudwego/kitex/server" + "github.com/cloudwego/thriftgo/fieldmask" +) + +func TestMain(m *testing.M) { + addr, err := net.ResolveTCPAddr("tcp", ":8999") + if err != nil { + panic(err.Error()) + } + svr := bizservice.NewServer(new(BizServiceImpl), server.WithServiceAddr(addr)) + + go func() { + err = svr.Run() + if err != nil { + panic(err.Error()) + } + }() + + // initialize request and response fieldmasks and cache them + respMask, err := fieldmask.NewFieldMask((*fieldmask0.BizResponse)(nil).GetTypeDescriptor(), "$.A") + if err != nil { + panic(err) + } + fmCache.Store("BizResponse", respMask) + reqMask, err := fieldmask.NewFieldMask((*fieldmask0.BizRequest)(nil).GetTypeDescriptor(), "$.B", "$.RespMask") + if err != nil { + panic(err) + } + fmCache.Store("BizRequest", reqMask) + + // black list mod + respMaskBlack, err := fieldmask.Options{BlackListMode: true}.NewFieldMask((*fieldmask0.BizResponse)(nil).GetTypeDescriptor(), "$.A", "$.B") + if err != nil { + panic(err) + } + fmCache.Store("BizResponse-Black", respMaskBlack) + reqMaskBlack, err := fieldmask.Options{BlackListMode: true}.NewFieldMask((*fieldmask0.BizRequest)(nil).GetTypeDescriptor(), "$.A") + if err != nil { + panic(err) + } + fmCache.Store("BizRequest-Black", reqMaskBlack) + + time.Sleep(time.Second) + m.Run() + svr.Stop() +} + +var fmCache sync.Map + +func TestFieldMask(t *testing.T) { + cli, err := bizservice.NewClient("BizService", client.WithHostPorts(":8999")) + if err != nil { + t.Fatal(err) + } + + req := fieldmask0.NewBizRequest() + req.A = "A" + req.B = "B" + // try load request's fieldmask + reqMask, ok := fmCache.Load("BizRequest") + if ok { + req.Set_FieldMask(reqMask.(*fieldmask.FieldMask)) + } + + // try get response's fieldmask + respMask, ok := fmCache.Load("BizResponse") + if ok { + // serialize the respMask + fm, err := fieldmask.Marshal(respMask.(*fieldmask.FieldMask)) + if err != nil { + t.Fatal(err) + } + // let request carry fm + req.RespMask = fm + } + + resp, err := cli.BizMethod1(context.Background(), req) + if err != nil { + t.Fatal(err) + } + fmt.Printf("%#v\n", resp) + + if resp.A == "" { // resp.A in mask + t.Fail() + } + if resp.B == "" { // resp.B not in mask, but it's required, so still written + t.Fail() + } + if resp.C != "" { // resp.C not in mask + t.Fail() + } +} + +func TestFieldMask_BlackList(t *testing.T) { + cli, err := bizservice.NewClient("BizService", client.WithHostPorts(":8999")) + if err != nil { + t.Fatal(err) + } + + req := fieldmask0.NewBizRequest() + req.A = "A" + req.B = "B" + // try load request's fieldmask + reqMask, ok := fmCache.Load("BizRequest-Black") + if ok { + req.Set_FieldMask(reqMask.(*fieldmask.FieldMask)) + } + + // try get reponse's fieldmask + respMask, ok := fmCache.Load("BizResponse-Black") + if ok { + // serialize the respMask + fm, err := fieldmask.Marshal(respMask.(*fieldmask.FieldMask)) + if err != nil { + t.Fatal(err) + } + // let request carry fm + req.RespMask = fm + } + + resp, err := cli.BizMethod1(context.Background(), req) + if err != nil { + t.Fatal(err) + } + fmt.Printf("%#v\n", resp) + + if resp.A != "" { // resp.A in mask + t.Fail() + } + if resp.B == "" { // resp.B not in mask, but it's required, so still written + t.Fail() + } + if resp.C == "" { // resp.C not in mask + t.Fail() + } +}