From 0f343dad0b52977f0d4d4785209f2684ac28a43f Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Wed, 8 Feb 2023 14:54:37 +0300 Subject: [PATCH] logger/unwrap: fix for nested tagged/untagged Signed-off-by: Vasiliy Tolstov --- logger/unwrap/unwrap.go | 73 ++++++++++++++++++++++++++++++----------- 1 file changed, 53 insertions(+), 20 deletions(-) diff --git a/logger/unwrap/unwrap.go b/logger/unwrap/unwrap.go index 7d219e6e..93be44d4 100644 --- a/logger/unwrap/unwrap.go +++ b/logger/unwrap/unwrap.go @@ -46,6 +46,11 @@ var ( closeMapBytes = []byte("}") ) +type protoMessage interface { + Reset() + ProtoMessage() +} + type Wrapper struct { val interface{} s fmt.State @@ -53,7 +58,7 @@ type Wrapper struct { opts *Options depth int ignoreNextType bool - takeAll map[int]bool + takeMap map[int]bool protoWrapperType bool sqlWrapperType bool } @@ -111,7 +116,7 @@ func Tagged(b bool) Option { func Unwrap(val interface{}, opts ...Option) *Wrapper { options := NewOptions(opts...) - return &Wrapper{val: val, opts: &options, pointers: make(map[uintptr]int), takeAll: make(map[int]bool)} + return &Wrapper{val: val, opts: &options, pointers: make(map[uintptr]int), takeMap: make(map[int]bool)} } func (w *Wrapper) unpackValue(v reflect.Value) reflect.Value { @@ -237,9 +242,6 @@ func (w *Wrapper) format(v reflect.Value) { _, _ = w.s.Write(buf) return } - if w.opts.Tagged { - w.checkTakeAll(v, 1) - } // Handle invalid reflect values immediately. kind := v.Kind() @@ -256,6 +258,10 @@ func (w *Wrapper) format(v reflect.Value) { w.protoWrapperType = true } else if strings.HasPrefix(reflect.Indirect(v).Type().String(), "sql.Null") { w.sqlWrapperType = true + } else if v.CanInterface() { + if _, ok := v.Interface().(protoMessage); ok { + w.protoWrapperType = true + } } } w.formatPtr(v) @@ -378,6 +384,12 @@ func (w *Wrapper) format(v reflect.Value) { prevSkip := false for i := 0; i < numFields; i++ { + switch vt.Field(i).Type.PkgPath() { + case "google.golang.org/protobuf/internal/impl", "google.golang.org/protobuf/internal/pragma": + w.protoWrapperType = true + prevSkip = true + continue + } if w.protoWrapperType && !vt.Field(i).IsExported() { prevSkip = true continue @@ -385,6 +397,9 @@ func (w *Wrapper) format(v reflect.Value) { prevSkip = true continue } + if _, ok := vt.Field(i).Tag.Lookup("protobuf"); ok && !w.protoWrapperType { + w.protoWrapperType = true + } sv, ok := vt.Field(i).Tag.Lookup("logger") switch { case ok: @@ -395,11 +410,16 @@ func (w *Wrapper) format(v reflect.Value) { case "take": break } - case w.takeAll[w.depth]: - break case !ok && w.opts.Tagged: - prevSkip = true - continue + // skip top level untagged + if w.depth == 1 { + prevSkip = true + continue + } + if tv, ok := w.takeMap[w.depth]; ok && !tv { + prevSkip = true + continue + } } if prevSkip { @@ -416,9 +436,7 @@ func (w *Wrapper) format(v reflect.Value) { _, _ = w.s.Write([]byte(vt.Name)) _, _ = w.s.Write(colonBytes) } - unpackValue := w.unpackValue(v.Field(i)) - w.checkTakeAll(unpackValue, w.depth) - w.format(unpackValue) + w.format(w.unpackValue(v.Field(i))) numWritten++ } w.depth-- @@ -461,6 +479,10 @@ func (w *Wrapper) Format(s fmt.State, verb rune) { return } + if w.opts.Tagged { + w.buildTakeMap(reflect.ValueOf(w.val), 1) + } + w.format(reflect.ValueOf(w.val)) } @@ -615,24 +637,28 @@ func (w *Wrapper) constructOrigFormat(verb rune) string { return buf.String() } -func (w *Wrapper) checkTakeAll(v reflect.Value, depth int) { - if _, ok := w.takeAll[depth]; ok { - return - } +func (w *Wrapper) buildTakeMap(v reflect.Value, depth int) { if !v.IsValid() || v.IsZero() { return } + switch v.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < v.Len(); i++ { + w.buildTakeMap(v.Index(i), depth+1) + } + w.takeMap[depth] = true + return case reflect.Struct: break case reflect.Ptr: v = v.Elem() if v.Kind() != reflect.Struct { - w.takeAll[depth] = true + w.takeMap[depth] = true return } default: - w.takeAll[depth] = true + w.takeMap[depth] = true return } @@ -641,8 +667,15 @@ func (w *Wrapper) checkTakeAll(v reflect.Value, depth int) { for i := 0; i < v.NumField(); i++ { sv, ok := vt.Field(i).Tag.Lookup("logger") if ok && sv == "take" { - w.takeAll[depth] = false + w.takeMap[depth] = false } - w.checkTakeAll(v.Field(i), depth+1) + if v.Kind() == reflect.Struct || + (v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Struct) { + w.buildTakeMap(v.Field(i), depth+1) + } + } + + if _, ok := w.takeMap[depth]; !ok { + w.takeMap[depth] = true } }