From fa06de9d5cc6c50584f4eae7c7fdeca021e797c0 Mon Sep 17 00:00:00 2001 From: chenwenjie Date: Thu, 4 Feb 2021 14:51:14 +0800 Subject: [PATCH] a-support struct field declare as ptr type ps: type St struct { Id *int Name *string } --- go.sum | 1 + parser_xml_test.go | 1 + proc_res.go | 47 +++++++++++---- val.go | 146 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 184 insertions(+), 11 deletions(-) diff --git a/go.sum b/go.sum index 25b74ef..0b9e693 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/parser_xml_test.go b/parser_xml_test.go index 891466c..55f52d2 100644 --- a/parser_xml_test.go +++ b/parser_xml_test.go @@ -1,6 +1,7 @@ package gobatis import ( + "github.com/stretchr/testify/assert" "strings" "testing" ) diff --git a/proc_res.go b/proc_res.go index 4e6b6fe..66ebe15 100644 --- a/proc_res.go +++ b/proc_res.go @@ -275,7 +275,12 @@ func rowsToMaps(rows *sql.Rows) ([]interface{}, error) { scanArgs[i] = &vals[i] } - rows.Scan(scanArgs...) + err = rows.Scan(scanArgs...) + if nil != err { + LOG.Error("rows scan err:%v", err) + return nil, err + } + for i := 0; i < len(cols); i++ { val := vals[i] if nil != val { @@ -309,7 +314,12 @@ func rowsToSlices(rows *sql.Rows) ([]interface{}, error) { scanArgs[i] = &vals[i] } - rows.Scan(scanArgs...) + err = rows.Scan(scanArgs...) + if nil != err { + LOG.Error("rows scan err:%v", err) + return nil, err + } + for i := 0; i < len(cols); i++ { val := vals[i] if nil != val { @@ -367,18 +377,33 @@ func rowsToStructs(rows *sql.Rows, resultType reflect.Type) ([]interface{}, erro // 设置相关字段的值,并判断是否可设值 if field.CanSet() && vals[i] != nil { //获取字段类型并设值 - data := dataToFieldVal(vals[i], field.Type(), fieldName) - - // 数据库返回类型与字段类型不符合的情况下通知用户 - if reflect.TypeOf(data).Name() != field.Type().Name() { - warnInfo := "[WARN] fieldType != dataType, filedName:" + fieldName + - " fieldType:" + field.Type().Name() + - " dataType:" + reflect.TypeOf(data).Name() - LOG.Warn(warnInfo) + ft := field.Type() + isPtr := false + if ft.Kind() == reflect.Ptr { + isPtr = true + ft = ft.Elem() } + data := dataToFieldVal(vals[i], ft, fieldName) if nil != data { - field.Set(reflect.ValueOf(data)) + // 数据库返回类型与字段类型不符合的情况下提醒用户 + dt := reflect.TypeOf(data) + if dt.Name() != ft.Name() { + warnInfo := "[WARN] fieldType != dataType, filedName:" + fieldName + + " fieldType:" + ft.Name() + + " dataType:" + dt.Name() + LOG.Warn(warnInfo) + } + + if isPtr { + data = dataToPtr(data, ft, fieldName) + val := reflect.ValueOf(data) + field.Set(val) + } else { + val := reflect.ValueOf(data) + field.Set(val) + } + } } } diff --git a/val.go b/val.go index 5d097f9..3d413c2 100644 --- a/val.go +++ b/val.go @@ -617,6 +617,152 @@ func valUpcast(data interface{}, typeName string) interface{} { return d } +func dataToPtr(data interface{}, tp reflect.Type, fieldName string) interface{} { + defer func() { + if err := recover(); nil != err { + LOG.Warn("[WARN] data to field val panic, fieldName:", fieldName, " err:", err) + } + }() + + typeName := tp.Name() + switch { + case typeName == "bool": + d := data.(bool) + data = &d + case typeName == "int": + d := data.(int) + data = &d + case typeName == "int8": + d := data.(int8) + data = &d + case typeName == "int16": + d := data.(int16) + data = &d + case typeName == "int32": + d := data.(int32) + data = &d + case typeName == "int64": + d := data.(int64) + data = &d + case typeName == "uint": + d := data.(uint) + data = &d + case typeName == "uint8": + d := data.(uint8) + data = &d + case typeName == "uint16": + d := data.(uint16) + data = &d + case typeName == "uint32": + d := data.(uint32) + data = &d + case typeName == "uint64": + d := data.(uint64) + data = &d + case typeName == "uintptr": + d := data.(uintptr) + data = &d + case typeName == "float32": + d := data.(float32) + data = &d + case typeName == "float64": + d := data.(float64) + data = &d + case typeName == "complex64": + d := data.(complex64) + data = &d + case typeName == "complex128": + d := data.(complex128) + data = &d + case typeName == "string": + d := data.(string) + data = &d + case typeName == "Time": + d := data.(time.Time) + data = &d + case typeName == "NullString": + if nil != data { + if reflect.TypeOf(data).Kind() == reflect.Slice || + reflect.TypeOf(data).Kind() == reflect.Array { + data = string(data.([]byte)) + } else { + data = valToString(data) + } + data = &NullString{String: data.(string), Valid: true} + } + case typeName == "NullInt64": + if nil != data { + if reflect.TypeOf(data).Kind() == reflect.Slice || + reflect.TypeOf(data).Kind() == reflect.Array { + data = string(data.([]byte)) + } else { + data = valToString(data) + } + + i, err := strconv.ParseInt(data.(string), 10, 64) + if err != nil { + panic("ParseInt err:" + err.Error()) + } + data = &NullInt64{Int64: i, Valid: true} + } + case typeName == "NullBool": + if nil != data { + if reflect.TypeOf(data).Kind() == reflect.Slice || + reflect.TypeOf(data).Kind() == reflect.Array { + data = string(data.([]byte)) + } else { + data = valToString(data) + } + if data.(string) == "true" { + return NullBool{Bool: true, Valid: true} + } + data = &NullBool{Bool: false, Valid: true} + } + case typeName == "NullFloat64": + if nil != data { + if reflect.TypeOf(data).Kind() == reflect.Slice || + reflect.TypeOf(data).Kind() == reflect.Array { + data = string(data.([]byte)) + } else { + data = valToString(data) + } + + f64, err := strconv.ParseFloat(data.(string), 64) + if err != nil { + panic("ParseFloat err:" + err.Error()) + } + + data = &NullFloat64{Float64: f64, Valid: true} + } + case typeName == "NullTime": + if nil != data { + var t time.Time + dt, ok := data.(time.Time) + if !ok { + if reflect.TypeOf(data).Kind() == reflect.Slice || + reflect.TypeOf(data).Kind() == reflect.Array { + data = string(data.([]byte)) + } else { + data = valToString(data) + } + + tt, err := time.Parse("2006-01-02 15:04:05", data.(string)) + if err != nil { + panic("time.Parse err:" + err.Error()) + } + + t = tt + } else { + t = dt + } + + data = &NullTime{Time: t, Valid: true} + } + } + + return data +} + func dataToFieldVal(data interface{}, tp reflect.Type, fieldName string) interface{} { defer func() { if err := recover(); nil != err {