From d977d7b28ee85411c0085a260e55e678abf90428 Mon Sep 17 00:00:00 2001 From: kevin Date: Sun, 15 Dec 2024 22:16:44 +0800 Subject: [PATCH] feat: support form array in three notations Signed-off-by: kevin --- core/mapping/unmarshaler.go | 66 +++++++++++++--- core/mapping/unmarshaler_test.go | 17 ++++- rest/httpx/requests_test.go | 75 +++++++++++++++++++ rest/httpx/util.go | 9 ++- tools/goctl/pkg/parser/api/parser/parser.go | 14 ++-- .../pkg/parser/api/parser/parser_test.go | 2 +- 6 files changed, 162 insertions(+), 21 deletions(-) diff --git a/core/mapping/unmarshaler.go b/core/mapping/unmarshaler.go index b4fb356ee319..682c51991153 100644 --- a/core/mapping/unmarshaler.go +++ b/core/mapping/unmarshaler.go @@ -18,6 +18,7 @@ import ( ) const ( + comma = "," defaultKeyName = "key" delimiter = '.' ignoreKey = "-" @@ -36,6 +37,7 @@ var ( defaultCacheLock sync.Mutex emptyMap = map[string]any{} emptyValue = reflect.ValueOf(lang.Placeholder) + stringSliceType = reflect.TypeOf([]string{}) ) type ( @@ -173,13 +175,18 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, baseType := fieldType.Elem() dereffedBaseType := Deref(baseType) dereffedBaseKind := dereffedBaseType.Kind() - conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap()) if refValue.Len() == 0 { - value.Set(conv) + value.Set(reflect.MakeSlice(reflect.SliceOf(baseType), 0, 0)) return nil } + if u.opts.fromArray { + refValue = makeStringSlice(refValue) + } + var valid bool + conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap()) + for i := 0; i < refValue.Len(); i++ { ithValue := refValue.Index(i).Interface() if ithValue == nil { @@ -191,17 +198,9 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, switch dereffedBaseKind { case reflect.Struct: - target := reflect.New(dereffedBaseType) - val, ok := ithValue.(map[string]any) - if !ok { - return errTypeMismatch - } - - if err := u.unmarshal(val, target.Interface(), sliceFullName); err != nil { + if err := u.fillStructElement(baseType, conv.Index(i), ithValue, sliceFullName); err != nil { return err } - - SetValue(fieldType.Elem(), conv.Index(i), target.Elem()) case reflect.Slice: if err := u.fillSlice(dereffedBaseType, conv.Index(i), ithValue, sliceFullName); err != nil { return err @@ -310,6 +309,23 @@ func (u *Unmarshaler) fillSliceWithDefault(derefedType reflect.Type, value refle return u.fillSlice(derefedType, value, slice, fullName) } +func (u *Unmarshaler) fillStructElement(baseType reflect.Type, target reflect.Value, + value any, fullName string) error { + val, ok := value.(map[string]any) + if !ok { + return errTypeMismatch + } + + // use Deref(baseType) to get the base type in case the type is a pointer type. + ptr := reflect.New(Deref(baseType)) + if err := u.unmarshal(val, ptr.Interface(), fullName); err != nil { + return err + } + + SetValue(baseType, target, ptr.Elem()) + return nil +} + func (u *Unmarshaler) fillUnmarshalerStruct(fieldType reflect.Type, value reflect.Value, targetValue string) error { if !value.CanSet() { @@ -1146,6 +1162,34 @@ func join(elem ...string) string { return builder.String() } +func makeStringSlice(refValue reflect.Value) reflect.Value { + if refValue.Len() != 1 { + return refValue + } + + element := refValue.Index(0) + if element.Kind() != reflect.String { + return refValue + } + + val, ok := element.Interface().(string) + if !ok { + return refValue + } + + splits := strings.Split(val, comma) + if len(splits) <= 1 { + return refValue + } + + slice := reflect.MakeSlice(stringSliceType, len(splits), len(splits)) + for i, split := range splits { + slice.Index(i).Set(reflect.ValueOf(split)) + } + + return slice +} + func newInitError(name string) error { return fmt.Errorf("field %q is not set", name) } diff --git a/core/mapping/unmarshaler_test.go b/core/mapping/unmarshaler_test.go index 3270632dc447..944a575d851f 100644 --- a/core/mapping/unmarshaler_test.go +++ b/core/mapping/unmarshaler_test.go @@ -351,7 +351,7 @@ func TestUnmarshalIntSliceOfPtr(t *testing.T) { assert.Error(t, UnmarshalKey(m, &in)) }) - t.Run("int slice with nil", func(t *testing.T) { + t.Run("int slice with nil element", func(t *testing.T) { type inner struct { Ints []int `key:"ints"` } @@ -365,6 +365,21 @@ func TestUnmarshalIntSliceOfPtr(t *testing.T) { assert.Empty(t, in.Ints) } }) + + t.Run("int slice with nil", func(t *testing.T) { + type inner struct { + Ints []int `key:"ints"` + } + + m := map[string]any{ + "ints": []any(nil), + } + + var in inner + if assert.NoError(t, UnmarshalKey(m, &in)) { + assert.Empty(t, in.Ints) + } + }) } func TestUnmarshalIntWithDefault(t *testing.T) { diff --git a/rest/httpx/requests_test.go b/rest/httpx/requests_test.go index 437b9a136fb5..f543292d8735 100644 --- a/rest/httpx/requests_test.go +++ b/rest/httpx/requests_test.go @@ -88,6 +88,21 @@ func TestParseFormArray(t *testing.T) { } }) + t.Run("slice with empty", func(t *testing.T) { + var v struct { + Name []string `form:"name,optional"` + } + + r, err := http.NewRequest( + http.MethodGet, + "/a?name=", + http.NoBody) + assert.NoError(t, err) + if assert.NoError(t, Parse(r, &v)) { + assert.ElementsMatch(t, []string{}, v.Name) + } + }) + t.Run("slice with empty and non-empty", func(t *testing.T) { var v struct { Name []string `form:"name"` @@ -102,6 +117,66 @@ func TestParseFormArray(t *testing.T) { assert.ElementsMatch(t, []string{"1"}, v.Name) } }) + + t.Run("slice with one value on array format", func(t *testing.T) { + var v struct { + Names []string `form:"names"` + } + + r, err := http.NewRequest( + http.MethodGet, + "/a?names=1,2,3", + http.NoBody) + assert.NoError(t, err) + if assert.NoError(t, Parse(r, &v)) { + assert.ElementsMatch(t, []string{"1", "2", "3"}, v.Names) + } + }) + + t.Run("slice with one value on combined array format", func(t *testing.T) { + var v struct { + Names []string `form:"names"` + } + + r, err := http.NewRequest( + http.MethodGet, + "/a?names=[1,2,3]&names=4", + http.NoBody) + assert.NoError(t, err) + if assert.NoError(t, Parse(r, &v)) { + assert.ElementsMatch(t, []string{"[1,2,3]", "4"}, v.Names) + } + }) + + t.Run("slice with one value on integer array format", func(t *testing.T) { + var v struct { + Numbers []int `form:"numbers"` + } + + r, err := http.NewRequest( + http.MethodGet, + "/a?numbers=1,2,3", + http.NoBody) + assert.NoError(t, err) + if assert.NoError(t, Parse(r, &v)) { + assert.ElementsMatch(t, []int{1, 2, 3}, v.Numbers) + } + }) + + t.Run("slice with one value on array format brackets", func(t *testing.T) { + var v struct { + Names []string `form:"names"` + } + + r, err := http.NewRequest( + http.MethodGet, + "/a?names[]=1&names[]=2&names[]=3", + http.NoBody) + assert.NoError(t, err) + if assert.NoError(t, Parse(r, &v)) { + assert.ElementsMatch(t, []string{"1", "2", "3"}, v.Names) + } + }) } func TestParseForm_Error(t *testing.T) { diff --git a/rest/httpx/util.go b/rest/httpx/util.go index 19248ae74bb3..5540f53e83a7 100644 --- a/rest/httpx/util.go +++ b/rest/httpx/util.go @@ -3,9 +3,13 @@ package httpx import ( "errors" "net/http" + "strings" ) -const xForwardedFor = "X-Forwarded-For" +const ( + xForwardedFor = "X-Forwarded-For" + arraySuffix = "[]" +) // GetFormValues returns the form values. func GetFormValues(r *http.Request) (map[string]any, error) { @@ -29,6 +33,9 @@ func GetFormValues(r *http.Request) (map[string]any, error) { } if len(filtered) > 0 { + if strings.HasSuffix(name, arraySuffix) { + name = name[:len(name)-2] + } params[name] = filtered } } diff --git a/tools/goctl/pkg/parser/api/parser/parser.go b/tools/goctl/pkg/parser/api/parser/parser.go index 9faa00026b19..d4c267d6c1e7 100644 --- a/tools/goctl/pkg/parser/api/parser/parser.go +++ b/tools/goctl/pkg/parser/api/parser/parser.go @@ -13,13 +13,13 @@ import ( ) const ( - idAPI = "api" - groupKeyText = "group" - infoTitleKey = "Title" - infoDescKey = "Desc" - infoVersionKey = "Version" - infoAuthorKey = "Author" - infoEmailKey = "Email" + idAPI = "api" + groupKeyText = "group" + infoTitleKey = "Title" + infoDescKey = "Desc" + infoVersionKey = "Version" + infoAuthorKey = "Author" + infoEmailKey = "Email" ) // Parser is the parser for api file. diff --git a/tools/goctl/pkg/parser/api/parser/parser_test.go b/tools/goctl/pkg/parser/api/parser/parser_test.go index 361903ee8eb5..adc7b3acfa26 100644 --- a/tools/goctl/pkg/parser/api/parser/parser_test.go +++ b/tools/goctl/pkg/parser/api/parser/parser_test.go @@ -305,7 +305,7 @@ func TestParser_Parse_atServerStmt(t *testing.T) { "prefix3:": "v1/v2_", "prefix4:": "a-b-c", "summary:": `"test"`, - "key:": `"bar"`, + "key:": `"bar"`, } p := New("foo.api", atServerTestAPI)