Skip to content

Commit

Permalink
Merge pull request #56 from wenj91/dev
Browse files Browse the repository at this point in the history
a-support struct field declare as ptr type
  • Loading branch information
wenj91 authored Feb 4, 2021
2 parents 741a89c + fa06de9 commit 7c67715
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 11 deletions.
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
1 change: 1 addition & 0 deletions parser_xml_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gobatis

import (
"github.com/stretchr/testify/assert"
"strings"
"testing"
)
Expand Down
47 changes: 36 additions & 11 deletions proc_res.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,12 @@ func rowsToMaps(rows *sql.Rows) ([]interface{}, error) {
scanArgs[i] = &vals[i]
}

rows.Scan(scanArgs...)
err = rows.Scan(scanArgs...)
if nil != err {
LOG.Error("rows scan err:%v", err)
return nil, err
}

for i := 0; i < len(cols); i++ {
val := vals[i]
if nil != val {
Expand Down Expand Up @@ -309,7 +314,12 @@ func rowsToSlices(rows *sql.Rows) ([]interface{}, error) {
scanArgs[i] = &vals[i]
}

rows.Scan(scanArgs...)
err = rows.Scan(scanArgs...)
if nil != err {
LOG.Error("rows scan err:%v", err)
return nil, err
}

for i := 0; i < len(cols); i++ {
val := vals[i]
if nil != val {
Expand Down Expand Up @@ -367,18 +377,33 @@ func rowsToStructs(rows *sql.Rows, resultType reflect.Type) ([]interface{}, erro
// 设置相关字段的值,并判断是否可设值
if field.CanSet() && vals[i] != nil {
//获取字段类型并设值
data := dataToFieldVal(vals[i], field.Type(), fieldName)

// 数据库返回类型与字段类型不符合的情况下通知用户
if reflect.TypeOf(data).Name() != field.Type().Name() {
warnInfo := "[WARN] fieldType != dataType, filedName:" + fieldName +
" fieldType:" + field.Type().Name() +
" dataType:" + reflect.TypeOf(data).Name()
LOG.Warn(warnInfo)
ft := field.Type()
isPtr := false
if ft.Kind() == reflect.Ptr {
isPtr = true
ft = ft.Elem()
}

data := dataToFieldVal(vals[i], ft, fieldName)
if nil != data {
field.Set(reflect.ValueOf(data))
// 数据库返回类型与字段类型不符合的情况下提醒用户
dt := reflect.TypeOf(data)
if dt.Name() != ft.Name() {
warnInfo := "[WARN] fieldType != dataType, filedName:" + fieldName +
" fieldType:" + ft.Name() +
" dataType:" + dt.Name()
LOG.Warn(warnInfo)
}

if isPtr {
data = dataToPtr(data, ft, fieldName)
val := reflect.ValueOf(data)
field.Set(val)
} else {
val := reflect.ValueOf(data)
field.Set(val)
}

}
}
}
Expand Down
146 changes: 146 additions & 0 deletions val.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,152 @@ func valUpcast(data interface{}, typeName string) interface{} {
return d
}

func dataToPtr(data interface{}, tp reflect.Type, fieldName string) interface{} {
defer func() {
if err := recover(); nil != err {
LOG.Warn("[WARN] data to field val panic, fieldName:", fieldName, " err:", err)
}
}()

typeName := tp.Name()
switch {
case typeName == "bool":
d := data.(bool)
data = &d
case typeName == "int":
d := data.(int)
data = &d
case typeName == "int8":
d := data.(int8)
data = &d
case typeName == "int16":
d := data.(int16)
data = &d
case typeName == "int32":
d := data.(int32)
data = &d
case typeName == "int64":
d := data.(int64)
data = &d
case typeName == "uint":
d := data.(uint)
data = &d
case typeName == "uint8":
d := data.(uint8)
data = &d
case typeName == "uint16":
d := data.(uint16)
data = &d
case typeName == "uint32":
d := data.(uint32)
data = &d
case typeName == "uint64":
d := data.(uint64)
data = &d
case typeName == "uintptr":
d := data.(uintptr)
data = &d
case typeName == "float32":
d := data.(float32)
data = &d
case typeName == "float64":
d := data.(float64)
data = &d
case typeName == "complex64":
d := data.(complex64)
data = &d
case typeName == "complex128":
d := data.(complex128)
data = &d
case typeName == "string":
d := data.(string)
data = &d
case typeName == "Time":
d := data.(time.Time)
data = &d
case typeName == "NullString":
if nil != data {
if reflect.TypeOf(data).Kind() == reflect.Slice ||
reflect.TypeOf(data).Kind() == reflect.Array {
data = string(data.([]byte))
} else {
data = valToString(data)
}
data = &NullString{String: data.(string), Valid: true}
}
case typeName == "NullInt64":
if nil != data {
if reflect.TypeOf(data).Kind() == reflect.Slice ||
reflect.TypeOf(data).Kind() == reflect.Array {
data = string(data.([]byte))
} else {
data = valToString(data)
}

i, err := strconv.ParseInt(data.(string), 10, 64)
if err != nil {
panic("ParseInt err:" + err.Error())
}
data = &NullInt64{Int64: i, Valid: true}
}
case typeName == "NullBool":
if nil != data {
if reflect.TypeOf(data).Kind() == reflect.Slice ||
reflect.TypeOf(data).Kind() == reflect.Array {
data = string(data.([]byte))
} else {
data = valToString(data)
}
if data.(string) == "true" {
return NullBool{Bool: true, Valid: true}
}
data = &NullBool{Bool: false, Valid: true}
}
case typeName == "NullFloat64":
if nil != data {
if reflect.TypeOf(data).Kind() == reflect.Slice ||
reflect.TypeOf(data).Kind() == reflect.Array {
data = string(data.([]byte))
} else {
data = valToString(data)
}

f64, err := strconv.ParseFloat(data.(string), 64)
if err != nil {
panic("ParseFloat err:" + err.Error())
}

data = &NullFloat64{Float64: f64, Valid: true}
}
case typeName == "NullTime":
if nil != data {
var t time.Time
dt, ok := data.(time.Time)
if !ok {
if reflect.TypeOf(data).Kind() == reflect.Slice ||
reflect.TypeOf(data).Kind() == reflect.Array {
data = string(data.([]byte))
} else {
data = valToString(data)
}

tt, err := time.Parse("2006-01-02 15:04:05", data.(string))
if err != nil {
panic("time.Parse err:" + err.Error())
}

t = tt
} else {
t = dt
}

data = &NullTime{Time: t, Valid: true}
}
}

return data
}

func dataToFieldVal(data interface{}, tp reflect.Type, fieldName string) interface{} {
defer func() {
if err := recover(); nil != err {
Expand Down

0 comments on commit 7c67715

Please sign in to comment.