Skip to content

Commit

Permalink
feat: support form array in three notations (#4498)
Browse files Browse the repository at this point in the history
Signed-off-by: kevin <[email protected]>
  • Loading branch information
kevwan authored Dec 22, 2024
1 parent 2159d11 commit 1d9159e
Show file tree
Hide file tree
Showing 6 changed files with 452 additions and 81 deletions.
145 changes: 100 additions & 45 deletions core/mapping/unmarshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
)

const (
comma = ","
defaultKeyName = "key"
delimiter = '.'
ignoreKey = "-"
Expand All @@ -36,6 +37,7 @@ var (
defaultCacheLock sync.Mutex
emptyMap = map[string]any{}
emptyValue = reflect.ValueOf(lang.Placeholder)
stringSliceType = reflect.TypeOf([]string{})
)

type (
Expand Down Expand Up @@ -80,40 +82,11 @@ func (u *Unmarshaler) Unmarshal(i, v any) error {
return u.unmarshal(i, v, "")
}

func (u *Unmarshaler) unmarshal(i, v any, fullName string) error {
valueType := reflect.TypeOf(v)
if valueType.Kind() != reflect.Ptr {
return errValueNotSettable
}

elemType := Deref(valueType)
switch iv := i.(type) {
case map[string]any:
if elemType.Kind() != reflect.Struct {
return errTypeMismatch
}

return u.unmarshalValuer(mapValuer(iv), v, fullName)
case []any:
if elemType.Kind() != reflect.Slice {
return errTypeMismatch
}

return u.fillSlice(elemType, reflect.ValueOf(v).Elem(), iv, fullName)
default:
return errUnsupportedType
}
}

// UnmarshalValuer unmarshals m into v.
func (u *Unmarshaler) UnmarshalValuer(m Valuer, v any) error {
return u.unmarshalValuer(simpleValuer{current: m}, v, "")
}

func (u *Unmarshaler) unmarshalValuer(m Valuer, v any, fullName string) error {
return u.unmarshalWithFullName(simpleValuer{current: m}, v, fullName)
}

func (u *Unmarshaler) fillMap(fieldType reflect.Type, value reflect.Value,
mapValue any, fullName string) error {
if !value.CanSet() {
Expand Down Expand Up @@ -173,13 +146,18 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value,
baseType := fieldType.Elem()
dereffedBaseType := Deref(baseType)
dereffedBaseKind := dereffedBaseType.Kind()
conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap())
if refValue.Len() == 0 {
value.Set(conv)
value.Set(reflect.MakeSlice(reflect.SliceOf(baseType), 0, 0))
return nil
}

if u.opts.fromArray {
refValue = makeStringSlice(refValue)
}

var valid bool
conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap())

for i := 0; i < refValue.Len(); i++ {
ithValue := refValue.Index(i).Interface()
if ithValue == nil {
Expand All @@ -191,17 +169,9 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value,

switch dereffedBaseKind {
case reflect.Struct:
target := reflect.New(dereffedBaseType)
val, ok := ithValue.(map[string]any)
if !ok {
return errTypeMismatch
}

if err := u.unmarshal(val, target.Interface(), sliceFullName); err != nil {
if err := u.fillStructElement(baseType, conv.Index(i), ithValue, sliceFullName); err != nil {
return err
}

SetValue(fieldType.Elem(), conv.Index(i), target.Elem())
case reflect.Slice:
if err := u.fillSlice(dereffedBaseType, conv.Index(i), ithValue, sliceFullName); err != nil {
return err
Expand Down Expand Up @@ -236,7 +206,7 @@ func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.
return errUnsupportedType
}

baseFieldType := Deref(fieldType.Elem())
baseFieldType := fieldType.Elem()
baseFieldKind := baseFieldType.Kind()
conv := reflect.MakeSlice(reflect.SliceOf(baseFieldType), len(slice), cap(slice))

Expand All @@ -257,29 +227,39 @@ func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
}

ithVal := slice.Index(index)
ithValType := ithVal.Type()

switch v := value.(type) {
case fmt.Stringer:
return setValueFromString(baseKind, ithVal, v.String())
case string:
return setValueFromString(baseKind, ithVal, v)
case map[string]any:
return u.fillMap(ithVal.Type(), ithVal, value, fullName)
// deref to handle both pointer and non-pointer types.
switch Deref(ithValType).Kind() {
case reflect.Struct:
return u.fillStructElement(ithValType, ithVal, v, fullName)
case reflect.Map:
return u.fillMap(ithValType, ithVal, value, fullName)
default:
return errTypeMismatch
}
default:
// don't need to consider the difference between int, int8, int16, int32, int64,
// uint, uint8, uint16, uint32, uint64, because they're handled as json.Number.
if ithVal.Kind() == reflect.Ptr {
baseType := Deref(ithVal.Type())
baseType := Deref(ithValType)
if !reflect.TypeOf(value).AssignableTo(baseType) {
return errTypeMismatch
}

target := reflect.New(baseType).Elem()
target.Set(reflect.ValueOf(value))
SetValue(ithVal.Type(), ithVal, target)
SetValue(ithValType, ithVal, target)
return nil
}

if !reflect.TypeOf(value).AssignableTo(ithVal.Type()) {
if !reflect.TypeOf(value).AssignableTo(ithValType) {
return errTypeMismatch
}

Expand Down Expand Up @@ -310,6 +290,23 @@ func (u *Unmarshaler) fillSliceWithDefault(derefedType reflect.Type, value refle
return u.fillSlice(derefedType, value, slice, fullName)
}

func (u *Unmarshaler) fillStructElement(baseType reflect.Type, target reflect.Value,
value any, fullName string) error {
val, ok := value.(map[string]any)
if !ok {
return errTypeMismatch
}

// use Deref(baseType) to get the base type in case the type is a pointer type.
ptr := reflect.New(Deref(baseType))
if err := u.unmarshal(val, ptr.Interface(), fullName); err != nil {
return err
}

SetValue(baseType, target, ptr.Elem())
return nil
}

func (u *Unmarshaler) fillUnmarshalerStruct(fieldType reflect.Type,
value reflect.Value, targetValue string) error {
if !value.CanSet() {
Expand Down Expand Up @@ -952,6 +949,35 @@ func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, valu
return nil
}

func (u *Unmarshaler) unmarshal(i, v any, fullName string) error {
valueType := reflect.TypeOf(v)
if valueType.Kind() != reflect.Ptr {
return errValueNotSettable
}

elemType := Deref(valueType)
switch iv := i.(type) {
case map[string]any:
if elemType.Kind() != reflect.Struct {
return errTypeMismatch
}

return u.unmarshalValuer(mapValuer(iv), v, fullName)
case []any:
if elemType.Kind() != reflect.Slice {
return errTypeMismatch
}

return u.fillSlice(elemType, reflect.ValueOf(v).Elem(), iv, fullName)
default:
return errUnsupportedType
}
}

func (u *Unmarshaler) unmarshalValuer(m Valuer, v any, fullName string) error {
return u.unmarshalWithFullName(simpleValuer{current: m}, v, fullName)
}

func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v any, fullName string) error {
rv := reflect.ValueOf(v)
if err := ValidatePtr(rv); err != nil {
Expand Down Expand Up @@ -1146,6 +1172,35 @@ func join(elem ...string) string {
return builder.String()
}

func makeStringSlice(refValue reflect.Value) reflect.Value {
if refValue.Len() != 1 {
return refValue
}

element := refValue.Index(0)
if element.Kind() != reflect.String {
return refValue
}

val, ok := element.Interface().(string)
if !ok {
return refValue
}

splits := strings.Split(val, comma)
if len(splits) <= 1 {
return refValue
}

slice := reflect.MakeSlice(stringSliceType, len(splits), len(splits))
for i, split := range splits {
// allow empty strings
slice.Index(i).Set(reflect.ValueOf(split))
}

return slice
}

func newInitError(name string) error {
return fmt.Errorf("field %q is not set", name)
}
Expand Down
Loading

0 comments on commit 1d9159e

Please sign in to comment.