diff --git a/assertion/function/assertiontree/backprop.go b/assertion/function/assertiontree/backprop.go index d4022f64..9383d866 100644 --- a/assertion/function/assertiontree/backprop.go +++ b/assertion/function/assertiontree/backprop.go @@ -27,6 +27,7 @@ import ( "go.uber.org/nilaway/annotation" "go.uber.org/nilaway/config" "go.uber.org/nilaway/util" + "go.uber.org/nilaway/util/asthelper" "golang.org/x/exp/slices" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/cfg" @@ -590,9 +591,18 @@ buildShadowMask: if ok && lhsNode != nil { // Add assignment entries to the consumers of lhsNode for informative printing of errors for _, c := range lhsNode.ConsumeTriggers() { + var lhsExprStr, rhsExprStr string + var err error + if lhsExprStr, err = asthelper.PrintExpr(lhsVal, rootNode.Pass(), true /* isShortenExpr */); err != nil { + return err + } + if rhsExprStr, err = asthelper.PrintExpr(rhsVal, rootNode.Pass(), true /* isShortenExpr */); err != nil { + return err + } + c.Annotation.AddAssignment(annotation.Assignment{ - LHSExprStr: util.ExprToString(lhsVal, rootNode.Pass()), - RHSExprStr: util.ExprToString(rhsVal, rootNode.Pass()), + LHSExprStr: lhsExprStr, + RHSExprStr: rhsExprStr, Position: util.TruncatePosition(util.PosToLocation(lhsVal.Pos(), rootNode.Pass())), }) } @@ -635,9 +645,18 @@ buildShadowMask: continue } for _, t := range rootNode.triggers[beforeTriggersLastIndex:len(rootNode.triggers)] { + var lhsExprStr, rhsExprStr string + var err error + if lhsExprStr, err = asthelper.PrintExpr(lhsVal, rootNode.Pass(), true /* isShortenExpr */); err != nil { + return err + } + if rhsExprStr, err = asthelper.PrintExpr(rhsVal, rootNode.Pass(), true /* isShortenExpr */); err != nil { + return err + } + t.Consumer.Annotation.AddAssignment(annotation.Assignment{ - LHSExprStr: util.ExprToString(lhsVal, rootNode.Pass()), - RHSExprStr: util.ExprToString(rhsVal, rootNode.Pass()), + LHSExprStr: lhsExprStr, + RHSExprStr: rhsExprStr, Position: util.TruncatePosition(util.PosToLocation(lhsVal.Pos(), rootNode.Pass())), }) } diff --git a/testdata/src/go.uber.org/errormessage/errormessage.go b/testdata/src/go.uber.org/errormessage/errormessage.go index 40fa7dc3..d424d7f0 100644 --- a/testdata/src/go.uber.org/errormessage/errormessage.go +++ b/testdata/src/go.uber.org/errormessage/errormessage.go @@ -100,6 +100,7 @@ func test9(m map[int]*int) { y := x print(*y) //want "`m\\[0\\]` to `x`" } + func test10(ch chan *int) { x := <-ch //want "nil channel accessed" y := x @@ -147,3 +148,117 @@ func test13() *int { } return new(int) } + +// below tests check shortening of expressions in assignment messages + +// nilable(s, result 0) +func (s *S) bar(i int) *int { + return nil +} + +// nilable(result 0) +func (s *S) foo(a int, b *int, c string, d bool) *S { + return nil +} + +func test14(x *int, i int) { + s := &S{} + x = s.foo(1, + new(int), + "abc", + true).bar(i) + y := x + print(*y) //want "`s.foo\\(...\\).bar\\(i\\)` to `x`" +} + +func test15(x *int) { + var longVarName, anotherLongVarName, yetAnotherLongName int + s := &S{} + x = s.foo(longVarName, &anotherLongVarName, "abc", true).bar(yetAnotherLongName) + y := x + print(*y) //want "`s.foo\\(...\\).bar\\(...\\)` to `x`" +} + +func test16(mp map[int]*int) { + var aVeryVeryVeryLongIndexVar int + x := mp[aVeryVeryVeryLongIndexVar] + y := x + print(*y) //want "`mp\\[...\\]` to `x`" +} + +func test17(x *int, mp map[int]*int) { + var aVeryVeryVeryLongIndexVar int + s := &S{} + + x = s.foo(1, mp[aVeryVeryVeryLongIndexVar], "abc", true).bar(2) //want "deep read" + y := x + print(*y) //want "`s.foo\\(...\\).bar\\(2\\)` to `x`" +} + +func test18(x *int, mp map[int]*int) { + s := &S{} + x = mp[*(s.foo(1, new(int), "abc", true).bar(2))] //want "dereferenced" + y := x + print(*y) //want "`mp\\[...\\]` to `x`" +} + +func test19() { + mp := make(map[string]*string) + x := mp["("] + y := x + print(*y) //want "`mp\\[\"\\(\"\\]` to `x`" + + x = mp[")"] + y = x + print(*y) //want "`mp\\[\"\\)\"\\]` to `x`" + + x = mp["))"] + y = x + print(*y) //want "`mp\\[...\\]` to `x`" + + x = mp["(("] + y = x + print(*y) //want "`mp\\[...\\]` to `x`" + + x = mp[")))((("] + y = x + print(*y) //want "`mp\\[...\\]` to `x`" + + x = mp[")))((("] + y = x + print(*y) //want "`mp\\[...\\]` to `x`" + + x = mp["(((()"] + y = x + print(*y) //want "`mp\\[...\\]` to `x`" + + x = mp["())))"] + y = x + print(*y) //want "`mp\\[...\\]` to `x`" + + s := &S{} + i := 0 + a := s.foo(1, + new(int), + "({[", + true).bar(i) + b := a + print(*b) //want "`s.foo\\(...\\).bar\\(i\\)` to `a`" +} + +func test20() { + mp := make(map[rune]*rune) + x := mp['('] + y := x + print(*y) //want "`mp\\['\\('\\]` to `x`" + + x = mp[')'] + y = x + print(*y) //want "`mp\\['\\)'\\]` to `x`" +} + +// below test checks that NilAway can handle non-English (non-ASCII) identifiers +func test21() { + var 世界 *int = nil + print(*世界) //want "`nil` to `世界`" +} diff --git a/util/asthelper/asthelper.go b/util/asthelper/asthelper.go new file mode 100644 index 00000000..7d757e57 --- /dev/null +++ b/util/asthelper/asthelper.go @@ -0,0 +1,104 @@ +// Copyright (c) 2023 Uber Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package asthelper implements utility functions for AST. +package asthelper + +import ( + "go/ast" + "go/printer" + "go/token" + "io" + "strings" + + "golang.org/x/tools/go/analysis" +) + +// PrintExpr converts AST expression to string, and shortens long expressions if isShortenExpr is true +func PrintExpr(e ast.Expr, pass *analysis.Pass, isShortenExpr bool) (string, error) { + builder := &strings.Builder{} + var err error + + if !isShortenExpr { + err = printer.Fprint(builder, pass.Fset, e) + } else { + // traverse over the AST expression's subtree and shorten long expressions + // (e.g., s.foo(longVarName, anotherLongVarName, someOtherLongVarName) --> s.foo(...)) + err = printExpr(builder, pass.Fset, e) + } + + return builder.String(), err +} + +func printExpr(writer io.Writer, fset *token.FileSet, e ast.Expr) (err error) { + // _shortenExprLen is the maximum length of an expression to be printed in full. The value is set to 3 to account for + // the length of the ellipsis ("..."), which is used to shorten long expressions. + const _shortenExprLen = 3 + + // fullExpr returns true if the expression is short enough (<= _shortenExprLen) to be printed in full + fullExpr := func(node ast.Node) (string, bool) { + switch n := node.(type) { + case *ast.Ident: + if len(n.Name) <= _shortenExprLen { + return n.Name, true + } + case *ast.BasicLit: + if len(n.Value) <= _shortenExprLen { + return n.Value, true + } + } + return "", false + } + + switch node := e.(type) { + case *ast.Ident: + _, err = io.WriteString(writer, node.Name) + + case *ast.SelectorExpr: + if err = printExpr(writer, fset, node.X); err != nil { + return + } + _, err = io.WriteString(writer, "."+node.Sel.Name) + + case *ast.CallExpr: + if err = printExpr(writer, fset, node.Fun); err != nil { + return + } + var argStr string + if len(node.Args) > 0 { + argStr = "..." + if len(node.Args) == 1 { + if a, ok := fullExpr(node.Args[0]); ok { + argStr = a + } + } + } + _, err = io.WriteString(writer, "("+argStr+")") + + case *ast.IndexExpr: + if err = printExpr(writer, fset, node.X); err != nil { + return + } + + indexExpr := "..." + if v, ok := fullExpr(node.Index); ok { + indexExpr = v + } + _, err = io.WriteString(writer, "["+indexExpr+"]") + + default: + err = printer.Fprint(writer, fset, e) + } + return +} diff --git a/util/util.go b/util/util.go index 216f8cbe..cb296a76 100644 --- a/util/util.go +++ b/util/util.go @@ -16,10 +16,8 @@ package util import ( - "bytes" "fmt" "go/ast" - "go/printer" "go/token" "go/types" "regexp" @@ -487,13 +485,3 @@ func truncatePosition(position token.Position) token.Position { func PosToLocation(pos token.Pos, pass *analysis.Pass) token.Position { return truncatePosition(pass.Fset.Position(pos)) } - -// ExprToString converts AST expression to string -func ExprToString(e ast.Expr, pass *analysis.Pass) string { - var buf bytes.Buffer - err := printer.Fprint(&buf, pass.Fset, e) - if err != nil { - panic(fmt.Sprintf("Failed to convert AST expression to string: %v\n", err)) - } - return buf.String() -}