From 0b2555c70ec1a81934d761be4dd6208549a3fbb0 Mon Sep 17 00:00:00 2001 From: lingbo Date: Tue, 4 Jul 2023 19:23:20 +0800 Subject: [PATCH] add feature nested pointer support(#21) --- README.md | 77 ++++++++++++++++++++++++++++++++++++++++++++++++ defaults.go | 56 ++++++++++++++++++++++++++++------- defaults_test.go | 55 ++++++++++++++++++++++++++++++++++ filler.go | 20 +++++++++---- 4 files changed, 192 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index ce0a5c6..a176f05 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,83 @@ fmt.Println(example.Bar) //Prints: 0 ``` +Pointer Set +------- + +Pointer field struct is a tricky usage to avoid covering existed values. + +Take the basic example in the above section and change it slightly: +```go + +type ExamplePointer struct { + Foo *bool `default:"true"` //<-- StructTag with a default key + Bar *string `default:"example"` + Qux *int `default:"22"` + Oed *int64 `default:"64"` +} + +... + +boolZero := false +stringZero := "" +intZero := 0 +example := &ExamplePointer{ + Foo: &boolZero, + Bar: &stringZero, + Qux: &intZero, +} +defaults.SetDefaults(example) + +fmt.Println(*example.Foo) //Prints: false (zero value `false` for bool but not for bool ptr) +fmt.Println(*example.Bar) //Prints: "" (print "" which set in advance, not "example" for default) +fmt.Println(*example.Qux) //Prints: 0 (0 instead of 22) +fmt.Println(*example.Oed) //Prints: 64 (64, because the ptr addr is nil when SetDefaults) + +``` + +It's also a very useful feature for web application which default values are needed while binding request json. + +For example: +```go +type ExamplePostBody struct { + Foo *bool `json:"foo" default:"true"` //<-- StructTag with a default key + Bar *string `json:"bar" default:"example"` + Qux *int `json:"qux" default:"22"` + Oed *int64 `json:"oed" default:"64"` +} +``` + +HTTP request seems like this: +```bash +curl --location --request POST ... \ +... \ +--header 'Content-Type: application/json' \ +--data-raw '{ + "foo": false, + "bar": "", + "qux": 0 +}' +``` + +Request handler: +```go +func PostExampleHandler(c *gin.Context) { + var reqBody ExamplePostBody + if err := c.ShouldBindJSON(&reqBody); err != nil { + c.JSON(http.StatusBadRequest, nil) + return + } + defaults.SetDefaults(&reqBody) + + fmt.Println(*reqBody.Foo) //Prints: false (zero value `false` for bool but not for bool ptr) + fmt.Println(*reqBody.Bar) //Prints: "" (print "" which set in advance, not "example" for default) + fmt.Println(*reqBody.Qux) //Prints: 0 (0 instead of 22, did not confused from whether zero value is in json or not) + fmt.Println(*reqBody.Oed) //Prints: 64 (In this case "oed" is not in req json, so set default 64) + + ... +} +``` + License ------- diff --git a/defaults.go b/defaults.go index 5513b98..c42e0a5 100644 --- a/defaults.go +++ b/defaults.go @@ -12,15 +12,16 @@ import ( // the StructTag with name "default" and the directed value. // // Usage -// type ExampleBasic struct { -// Foo bool `default:"true"` -// Bar string `default:"33"` -// Qux int8 -// Dur time.Duration `default:"2m3s"` -// } // -// foo := &ExampleBasic{} -// SetDefaults(foo) +// type ExampleBasic struct { +// Foo bool `default:"true"` +// Bar string `default:"33"` +// Qux int8 +// Dur time.Duration `default:"2m3s"` +// } +// +// foo := &ExampleBasic{} +// SetDefaults(foo) func SetDefaults(variable interface{}) { getDefaultFiller().Fill(variable) } @@ -90,7 +91,11 @@ func newDefaultFiller() *Filler { types := make(map[TypeHash]FillerFunc, 1) types["time.Duration"] = func(field *FieldData) { d, _ := time.ParseDuration(field.TagValue) - field.Value.Set(reflect.ValueOf(d)) + if field.Value.Kind() == reflect.Ptr { + field.Value.Set(reflect.ValueOf(&d)) + } else { + field.Value.Set(reflect.ValueOf(d)) + } } funcs[reflect.Slice] = func(field *FieldData) { @@ -107,6 +112,16 @@ func newDefaultFiller() *Filler { fields := getDefaultFiller().GetFieldsFromValue(field.Value.Index(i), nil) getDefaultFiller().SetDefaultValues(fields) } + case reflect.Ptr: + count := field.Value.Len() + for i := 0; i < count; i++ { + if field.Value.Index(i).IsZero() { + newValue := reflect.New(field.Value.Index(i).Type().Elem()) + field.Value.Index(i).Set(newValue) + } + fields := getDefaultFiller().GetFieldsFromValue(field.Value.Index(i).Elem(), nil) + getDefaultFiller().SetDefaultValues(fields) + } default: //处理形如 [1,2,3,4] reg := regexp.MustCompile(`^\[(.*)\]$`) @@ -134,6 +149,27 @@ func newDefaultFiller() *Filler { } } + funcs[reflect.Ptr] = func(field *FieldData) { + k := field.Value.Type().Elem().Kind() + if k != reflect.Struct && field.TagValue == "" { + return + } + if field.Value.IsNil() { + v := reflect.New(field.Value.Type().Elem()) + field.Value.Set(v) + } + elemField := &FieldData{ + Value: field.Value.Elem(), + Field: reflect.StructField{ + Type: field.Field.Type.Elem(), + Tag: field.Field.Tag, + }, + TagValue: field.TagValue, + Parent: nil, + } + funcs[field.Value.Elem().Kind()](elemField) + } + return &Filler{FuncByKind: funcs, FuncByType: types, Tag: "default"} } @@ -159,13 +195,11 @@ func parseDateTimeString(data string) string { case "date": str := time.Now().AddDate(values[0], values[1], values[2]).Format("2006-01-02") data = strings.Replace(data, match[0], str, -1) - break case "time": str := time.Now().Add((time.Duration(values[0]) * time.Hour) + (time.Duration(values[1]) * time.Minute) + (time.Duration(values[2]) * time.Second)).Format("15:04:05") data = strings.Replace(data, match[0], str, -1) - break } } } diff --git a/defaults_test.go b/defaults_test.go index c86a960..35db296 100644 --- a/defaults_test.go +++ b/defaults_test.go @@ -32,6 +32,11 @@ type Child struct { Age int `default:"10"` } +type ChildPtr struct { + Name *string + Age *int `default:"10"` +} + type ExampleBasic struct { Bool bool `default:"true"` Integer int `default:"33"` @@ -61,6 +66,27 @@ type ExampleBasic struct { StringSliceSlice [][]string `default:"[[1],[]]"` DateTime string `default:"{{date:1,-10,0}} {{time:1,-5,10}}"` + + BoolPtr *bool `default:"false"` + IntPtr *int `default:"33"` + Int8Ptr *int8 `default:"8"` + Int16Ptr *int16 `default:"16"` + Int32Ptr *int32 `default:"32"` + Int64Ptr *int64 `default:"64"` + UIntPtr *uint `default:"11"` + UInt8Ptr *uint8 `default:"18"` + UInt16Ptr *uint16 `default:"116"` + UInt32Ptr *uint32 `default:"132"` + UInt64Ptr *uint64 `default:"164"` + Float32Ptr *float32 `default:"3.2"` + Float64Ptr *float64 `default:"6.4"` + DurationPtr *time.Duration `default:"1s"` + SecondPtr *time.Duration `default:"1s"` + StructPtr *struct { + Bool bool `default:"true"` + Integer *int `default:"33"` + } + ChildrenPtr []*ChildPtr } func (s *DefaultsSuite) TestSetDefaultsBasic(c *C) { @@ -106,6 +132,24 @@ func (s *DefaultsSuite) assertTypes(c *C, foo *ExampleBasic) { c.Assert(foo.IntSliceSlice, DeepEquals, [][]int{[]int{1}, []int{2}, []int{3}, []int{4}}) c.Assert(foo.StringSliceSlice, DeepEquals, [][]string{[]string{"1"}, []string{}}) c.Assert(foo.DateTime, Equals, "2020-08-10 12:55:10") + c.Assert(*foo.BoolPtr, Equals, false) + c.Assert(*foo.IntPtr, Equals, 33) + c.Assert(*foo.Int8Ptr, Equals, int8(8)) + c.Assert(*foo.Int16Ptr, Equals, int16(16)) + c.Assert(*foo.Int32Ptr, Equals, int32(32)) + c.Assert(*foo.Int64Ptr, Equals, int64(64)) + c.Assert(*foo.UIntPtr, Equals, uint(11)) + c.Assert(*foo.UInt8Ptr, Equals, uint8(18)) + c.Assert(*foo.UInt16Ptr, Equals, uint16(116)) + c.Assert(*foo.UInt32Ptr, Equals, uint32(132)) + c.Assert(*foo.UInt64Ptr, Equals, uint64(164)) + c.Assert(*foo.Float32Ptr, Equals, float32(3.2)) + c.Assert(*foo.Float64Ptr, Equals, 6.4) + c.Assert(*foo.DurationPtr, Equals, time.Second) + c.Assert(*foo.SecondPtr, Equals, time.Second) + c.Assert(foo.StructPtr.Bool, Equals, true) + c.Assert(*foo.StructPtr.Integer, Equals, 33) + c.Assert(foo.ChildrenPtr, IsNil) } func (s *DefaultsSuite) TestSetDefaultsWithValues(c *C) { @@ -118,6 +162,13 @@ func (s *DefaultsSuite) TestSetDefaultsWithValues(c *C) { Children: []Child{{Name: "alice"}, {Name: "bob", Age: 2}}, } + intzero := 0 + foo.IntPtr = &intzero + + ageZero := 0 + childPtr := &ChildPtr{Age: &ageZero} + foo.ChildrenPtr = append(foo.ChildrenPtr, childPtr) + SetDefaults(foo) c.Assert(foo.Integer, Equals, 55) @@ -127,6 +178,10 @@ func (s *DefaultsSuite) TestSetDefaultsWithValues(c *C) { c.Assert(string(foo.Bytes), Equals, "foo") c.Assert(foo.Children[0].Age, Equals, 10) c.Assert(foo.Children[1].Age, Equals, 2) + c.Assert(*foo.ChildrenPtr[0].Age, Equals, 0) + c.Assert(foo.ChildrenPtr[0].Name, IsNil) + + c.Assert(*foo.IntPtr, Equals, 0) } func (s *DefaultsSuite) BenchmarkLogic(c *C) { diff --git a/filler.go b/filler.go index abacefa..f32e2ec 100644 --- a/filler.go +++ b/filler.go @@ -82,11 +82,23 @@ func (f *Filler) isEmpty(field *FieldData) bool { // always assume the structs in the slice is empty and can be filled // the actually struct filling logic should take care of the rest return true + case reflect.Ptr: + switch field.Value.Type().Elem().Elem().Kind() { + case reflect.Struct: + return true + default: + return field.Value.Len() == 0 + } default: return field.Value.Len() == 0 } case reflect.String: return field.Value.String() == "" + case reflect.Ptr: + if field.Value.Type().Elem().Kind() == reflect.Struct { + return true + } + return field.Value.IsNil() } return true } @@ -105,12 +117,10 @@ func (f *Filler) SetDefaultValue(field *FieldData) { return } } - - return } func (f *Filler) getFunctionByName(field *FieldData) FillerFunc { - if f, ok := f.FuncByName[field.Field.Name]; ok == true { + if f, ok := f.FuncByName[field.Field.Name]; ok { return f } @@ -118,7 +128,7 @@ func (f *Filler) getFunctionByName(field *FieldData) FillerFunc { } func (f *Filler) getFunctionByType(field *FieldData) FillerFunc { - if f, ok := f.FuncByType[GetTypeHash(field.Field.Type)]; ok == true { + if f, ok := f.FuncByType[GetTypeHash(field.Field.Type)]; ok { return f } @@ -126,7 +136,7 @@ func (f *Filler) getFunctionByType(field *FieldData) FillerFunc { } func (f *Filler) getFunctionByKind(field *FieldData) FillerFunc { - if f, ok := f.FuncByKind[field.Field.Type.Kind()]; ok == true { + if f, ok := f.FuncByKind[field.Field.Type.Kind()]; ok { return f }