Skip to content

Commit

Permalink
Assignment tracking for many-to-one assignments (#181)
Browse files Browse the repository at this point in the history
This PR adds assignment tracking for many-to-one assignments for
printing informative error messages, following suit of one-to-one
assignment tracking (PR #87 ).
  • Loading branch information
sonalmahajan15 authored Jan 26, 2024
1 parent 8b7423a commit 6b83875
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 25 deletions.
52 changes: 27 additions & 25 deletions assertion/function/assertiontree/backprop.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"go.uber.org/nilaway/annotation"
"go.uber.org/nilaway/config"
"go.uber.org/nilaway/util"
"go.uber.org/nilaway/util/asthelper"
"golang.org/x/exp/slices"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/cfg"
Expand Down Expand Up @@ -591,20 +590,10 @@ buildShadowMask:
if ok && lhsNode != nil {
// Add assignment entries to the consumers of lhsNode for informative printing of errors
for _, c := range lhsNode.ConsumeTriggers() {
var lhsExprStr, rhsExprStr string
var err error
if lhsExprStr, err = asthelper.PrintExpr(lhsVal, rootNode.Pass(), true /* isShortenExpr */); err != nil {
err := addAssignmentToConsumer(lhsVal, rhsVal, rootNode.Pass(), c.Annotation)
if err != nil {
return err
}
if rhsExprStr, err = asthelper.PrintExpr(rhsVal, rootNode.Pass(), true /* isShortenExpr */); err != nil {
return err
}

c.Annotation.AddAssignment(annotation.Assignment{
LHSExprStr: lhsExprStr,
RHSExprStr: rhsExprStr,
Position: util.TruncatePosition(util.PosToLocation(lhsVal.Pos(), rootNode.Pass())),
})
}

// If the lhsVal path is not only trackable but tracked, we add it as
Expand Down Expand Up @@ -645,20 +634,10 @@ buildShadowMask:
continue
}
for _, t := range rootNode.triggers[beforeTriggersLastIndex:len(rootNode.triggers)] {
var lhsExprStr, rhsExprStr string
var err error
if lhsExprStr, err = asthelper.PrintExpr(lhsVal, rootNode.Pass(), true /* isShortenExpr */); err != nil {
return err
}
if rhsExprStr, err = asthelper.PrintExpr(rhsVal, rootNode.Pass(), true /* isShortenExpr */); err != nil {
err := addAssignmentToConsumer(lhsVal, rhsVal, rootNode.Pass(), t.Consumer.Annotation)
if err != nil {
return err
}

t.Consumer.Annotation.AddAssignment(annotation.Assignment{
LHSExprStr: lhsExprStr,
RHSExprStr: rhsExprStr,
Position: util.TruncatePosition(util.PosToLocation(lhsVal.Pos(), rootNode.Pass())),
})
}
default:
return errors.New("rhs expression in a 1-1 assignment was multiply returning - " +
Expand Down Expand Up @@ -734,18 +713,36 @@ func backpropAcrossManyToOneAssignment(rootNode *RootAssertionNode, lhs, rhs []a
rootNode.addProductionsForAssignmentFields(fieldProducers, lhsVal)
}

// beforeTriggersLastIndex is used to find the newly added triggers on the next line
beforeTriggersLastIndex := len(rootNode.triggers)

rootNode.AddGuardMatch(lhsVal, ContinueTracking)
rootNode.AddProduction(&annotation.ProduceTrigger{
Annotation: producers[i].GetShallow().Annotation,
Expr: lhsVal,
}, producers[i].GetDeepSlice()...)

// Update consumers of newly added triggers with assignment entries for informative printing of errors
if len(rootNode.triggers) > 0 {
for _, t := range rootNode.triggers[beforeTriggersLastIndex:len(rootNode.triggers)] {
err := addAssignmentToConsumer(lhsVal, rhsVal, rootNode.Pass(), t.Consumer.Annotation)
if err != nil {
return err
}
}
}

// Phase 2
consumeTrigger, err := exprAsAssignmentConsumer(rootNode, lhsVal, rhsVal)
if err != nil {
return err
}
if consumeTrigger != nil {
// Update consumeTrigger with assignment entries for informative printing of errors
if err = addAssignmentToConsumer(lhsVal, rhsVal, rootNode.Pass(), consumeTrigger); err != nil {
return err
}

// lhsVal is a field read, so this is a field assignment
// since multiple return functions aren't trackable, this is a completed trigger
// as long as the type of the expression being assigned doesn't bar nilness
Expand All @@ -767,6 +764,11 @@ func backpropAcrossManyToOneAssignment(rootNode *RootAssertionNode, lhs, rhs []a
}

if consumer := exprAsConsumedByAssignment(rootNode, lhsVal); consumer != nil {
// Update consumeTrigger with assignment entries for informative printing of errors
if err = addAssignmentToConsumer(lhsVal, rhsVal, rootNode.Pass(), consumer.Annotation); err != nil {
return err
}

rootNode.AddConsumption(consumer)
}
}
Expand Down
22 changes: 22 additions & 0 deletions assertion/function/assertiontree/backprop_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"go.uber.org/nilaway/annotation"
"go.uber.org/nilaway/util"
"go.uber.org/nilaway/util/asthelper"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/cfg"
)
Expand Down Expand Up @@ -749,3 +750,24 @@ func CheckGuardOnFullTrigger(trigger annotation.FullTrigger) annotation.FullTrig
}
return trigger
}

// addAssignmentToConsumer updates the consumer with assignment entries for informative printing of errors
func addAssignmentToConsumer(lhs, rhs ast.Expr, pass *analysis.Pass, consumer annotation.ConsumingAnnotationTrigger) error {
var lhsExprStr, rhsExprStr string
var err error

if lhsExprStr, err = asthelper.PrintExpr(lhs, pass, true /* isShortenExpr */); err != nil {
return fmt.Errorf("converting LHS of assignment to string: %w", err)
}
if rhsExprStr, err = asthelper.PrintExpr(rhs, pass, true /* isShortenExpr */); err != nil {
return fmt.Errorf("converting RHS of assignment to string: %w", err)
}

consumer.AddAssignment(annotation.Assignment{
LHSExprStr: lhsExprStr,
RHSExprStr: rhsExprStr,
Position: util.TruncatePosition(util.PosToLocation(lhs.Pos(), pass)),
})

return nil
}
93 changes: 93 additions & 0 deletions testdata/src/go.uber.org/errormessage/errormessage.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,96 @@ func test21() {
var 世界 *int = nil
print(*世界) //want "`nil` to `世界`"
}

// below tests check assignment flow tracking across many-to-one assignments

// nilable(result 0)
func retPtrErr() (*int, error) {
return nil, nil
}

func test22(i int) {
switch i {
case 0:
x, err := retPtrErr()
if err != nil {
return
}
print(*x) //want "`retPtrErr\\(\\)` to `x`"

case 1:
if x, err := retPtrErr(); err == nil {
y := x
print(*y) //want "`retPtrErr\\(\\)` to `x`"
}

case 2:
var x *int
var err error
x, err = retPtrErr()
if err != nil {
return
}
print(*x) //want "`retPtrErr\\(\\)` to `x`"

case 3:
var x, err = retPtrErr()
if err != nil {
return
}
print(*x) //want "`retPtrErr\\(\\)` to `x`"
}
}

// nilable(mp[])
func test23(mp map[int]*int, i int) {
switch i {
case 0:
v, ok := mp[0]
if ok {
print(*v) //want "`mp\\[0\\]` to `v`"
}

case 1:
if v, ok := mp[0]; ok {
print(*v) //want "`mp\\[0\\]` to `v`"
}
case 2:
var v *int
var ok bool
v, ok = mp[0]
if ok {
print(*v) //want "`mp\\[0\\]` to `v`"
}
case 3:
var v, ok = mp[0]
if ok {
print(*v) //want "`mp\\[0\\]` to `v`"
}
}
}

// nilable(result 0, result 2)
func retMultiple() (*int, *int, *int) {
return nil, new(int), nil
}

func test24() {
a, b, c := retMultiple()
if dummy {
b = a
}
print(*a) //want "`retMultiple\\(\\)` to `a`"
print(*b) //want "`a` to `b`"
print(*c) //want "`retMultiple\\(\\)` to `c`"
}

// nilable(A[])
type A []*int

// nonnil(a)
func test25(a A) {
a[0], a[1], _ = retMultiple()
print(*a[0]) //want "`retMultiple\\(\\)` to `a\\[0\\]`"
print(*a[1])
}

0 comments on commit 6b83875

Please sign in to comment.