Skip to content

Commit

Permalink
copier支持类型转换
Browse files Browse the repository at this point in the history
  • Loading branch information
hookokoko committed Jul 26, 2023
1 parent 4a29bd6 commit 285fe77
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 15 deletions.
28 changes: 28 additions & 0 deletions bean/copier/converter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package copier

import (
"fmt"
"time"
)

type Converter interface {
Convert(src any) (any, error)
}

type ConverterFunc func(src any) (any, error)

func (cf ConverterFunc) Convert(src any) (any, error) {
return cf(src)
}

type Time2String struct {
Pattern string
}

func (t Time2String) Convert(src any) (any, error) {
tm, ok := src.(time.Time)
if !ok {
return nil, fmt.Errorf("convert type is not time")
}
return tm.Format(t.Pattern), nil
}
14 changes: 14 additions & 0 deletions bean/copier/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type Copier[Src any, Dst any] interface {
type options struct {
// ignoreFields 执行复制操作时,需要忽略的字段
ignoreFields *set.MapSet[string]
// convertFields 执行转换的field和转化接口
convertFields map[string]Converter
}

func newOptions() *options {
Expand Down Expand Up @@ -65,3 +67,15 @@ func IgnoreFields(fields ...string) option.Option[options] {
}
}
}

func ConvertField(field string, converter Converter) option.Option[options] {
return func(opt *options) {
if field == "" || converter == nil {
return
}
if opt.convertFields == nil {
opt.convertFields = make(map[string]Converter, 16)
}
opt.convertFields[field] = converter
}
}
39 changes: 31 additions & 8 deletions bean/copier/reflect_copier.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package copier

import (
"reflect"
"time"

"github.com/ecodeclub/ekit/bean/option"
)
Expand Down Expand Up @@ -50,7 +51,7 @@ type fieldNode struct {
}

// NewReflectCopier 如果类型不匹配, 创建时直接检查报错.
func NewReflectCopier[Src any, Dst any]() (*ReflectCopier[Src, Dst], error) {
func NewReflectCopier[Src any, Dst any](opts ...option.Option[options]) (*ReflectCopier[Src, Dst], error) {
src := new(Src)
srcTyp := reflect.TypeOf(src).Elem()
dst := new(Dst)
Expand All @@ -72,6 +73,11 @@ func NewReflectCopier[Src any, Dst any]() (*ReflectCopier[Src, Dst], error) {
copier := &ReflectCopier[Src, Dst]{
rootField: root,
}

opt := newOptions()
option.Apply(opt, opts...)
copier.options = opt

return copier, nil
}

Expand All @@ -98,9 +104,6 @@ func createFieldNodes(root *fieldNode, srcTyp, dstTyp reflect.Type) error {
continue
}
srcFieldTypStruct := srcTyp.Field(srcIndex)
if srcFieldTypStruct.Type.Kind() != dstFieldTypStruct.Type.Kind() {
return newErrKindNotMatchError(srcFieldTypStruct.Type.Kind(), dstFieldTypStruct.Type.Kind(), dstFieldTypStruct.Name)
}

if srcFieldTypStruct.Type.Kind() == reflect.Pointer {
if srcFieldTypStruct.Type.Elem().Kind() != dstFieldTypStruct.Type.Elem().Kind() {
Expand Down Expand Up @@ -133,6 +136,8 @@ func createFieldNodes(root *fieldNode, srcTyp, dstTyp reflect.Type) error {
}
// 说明当前节点是叶子节点, 直接拷贝
child.isLeaf = true
} else if fieldSrcTyp == reflect.TypeOf(time.Time{}) {
child.isLeaf = true
} else if fieldSrcTyp.Kind() == reflect.Struct {
if err := createFieldNodes(&child, fieldSrcTyp, fieldDstTyp); err != nil {
return err
Expand Down Expand Up @@ -160,9 +165,13 @@ func (r *ReflectCopier[Src, Dst]) Copy(src *Src, opts ...option.Option[options])
// 3. 如果 Src 和 Dst 中匹配的字段,其类型都是结构体,或者都是结构体指针,则会深入复制
// 4. 否则,忽略字段
func (r *ReflectCopier[Src, Dst]) CopyTo(src *Src, dst *Dst, opts ...option.Option[options]) error {
opt := newOptions()
option.Apply(opt, opts...)
r.options = opt
if r.options == nil {
opt := newOptions()
option.Apply(opt, opts...)
r.options = opt
} else {
option.Apply(r.options, opts...)
}

return r.copyToWithTree(src, dst)
}
Expand Down Expand Up @@ -192,9 +201,23 @@ func (r *ReflectCopier[Src, Dst]) copyTreeNode(srcTyp reflect.Type, srcValue ref
}
// 执行拷贝
if root.isLeaf {
if dstValue.CanSet() {
convert, ok := r.options.convertFields[root.name]
if !ok && srcTyp.Kind() != dstType.Kind() {
return nil
}
if !ok && srcTyp.Kind() == dstType.Kind() && dstValue.CanSet() {
dstValue.Set(srcValue)
return nil
}

convSrcVal, err := convert.Convert(srcValue.Interface())
if err != nil {
return err
}
if dstValue.CanSet() {
dstValue.Set(reflect.ValueOf(convSrcVal))
}

return nil
}

Expand Down
87 changes: 80 additions & 7 deletions bean/copier/reflect_copier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
package copier

import (
"fmt"
"reflect"
"testing"
"time"

"github.com/ecodeclub/ekit"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -267,7 +269,7 @@ func TestReflectCopier_Copy(t *testing.T) {
S: struct{ A string }{A: "a"},
})
},
wantErr: newErrKindNotMatchError(reflect.String, reflect.Int, "A"),
wantErr: newErrTypeNotMatchError(reflect.TypeOf("a"), reflect.TypeOf(0), "A"),
},
{
name: "多重指针",
Expand Down Expand Up @@ -1040,6 +1042,75 @@ func TestReflectCopier_Copy(t *testing.T) {
},
},
},
{
name: "指定convert time2string",
copyFunc: func() (any, error) {
copier, err := NewReflectCopier[SimpleSrc, SimpleDst]()
if err != nil {
return nil, err
}
return copier.Copy(&SimpleSrc{
Name: "大明",
BirthDay: time.Date(2023, time.July, 26, 9, 15, 22, 213, time.UTC),
Friends: []string{"Tom", "Jerry"},
}, ConvertField("BirthDay", Time2String{Pattern: "2006-01-02 15:04:05"}))
},
wantDst: &SimpleDst{
Name: "大明",
BirthDay: "2023-07-26 09:15:22",
Friends: []string{"Tom", "Jerry"},
},
},
{
name: "指定convert func",
copyFunc: func() (any, error) {
copier, err := NewReflectCopier[SimpleSrc, SimpleDst]()
if err != nil {
return nil, err
}
return copier.Copy(&SimpleSrc{
Name: "大明",
Friends: []string{"Tom", "Jerry"},
}, ConvertField("Name", ConverterFunc(func(src any) (any, error) {
var newS string
s, ok := src.(string)
if ok {
newS = fmt.Sprintf("%s plus", s)
}
return newS, nil
})))
},
wantDst: &SimpleDst{
Name: "大明 plus",
Friends: []string{"Tom", "Jerry"},
},
},
{
name: "创建时指定默认converter",
copyFunc: func() (any, error) {
copier, err := NewReflectCopier[SimpleSrc, SimpleDst](ConvertField("BirthDay", Time2String{Pattern: "2006-01-02 15:04:05"}))
if err != nil {
return nil, err
}
return copier.Copy(&SimpleSrc{
Name: "大明",
BirthDay: time.Date(2023, time.July, 26, 9, 15, 22, 213, time.UTC),
Friends: []string{"Tom", "Jerry"},
}, ConvertField("Name", ConverterFunc(func(src any) (any, error) {
var newS string
s, ok := src.(string)
if ok {
newS = fmt.Sprintf("%s plus", s)
}
return newS, nil
})))
},
wantDst: &SimpleDst{
Name: "大明 plus",
BirthDay: "2023-07-26 09:15:22",
Friends: []string{"Tom", "Jerry"},
},
},
}

for _, tc := range testCases {
Expand Down Expand Up @@ -1067,15 +1138,17 @@ type BasicDst struct {
}

type SimpleSrc struct {
Name string
Age *int
Friends []string
Name string
Age *int
BirthDay time.Time
Friends []string
}

type SimpleDst struct {
Name string
Age *int
Friends []string
Name string
Age *int
BirthDay string
Friends []string
}

type EmbedSrc struct {
Expand Down

0 comments on commit 285fe77

Please sign in to comment.