diff --git a/assertion/function/assertiontree/root_assertion_node.go b/assertion/function/assertiontree/root_assertion_node.go index 3fb61423..73eb7028 100644 --- a/assertion/function/assertiontree/root_assertion_node.go +++ b/assertion/function/assertiontree/root_assertion_node.go @@ -498,12 +498,7 @@ func (r *RootAssertionNode) AddGuardMatch(expr ast.Expr, behavior GuardMatchBeha case ContinueTracking: for i, consumer := range consumers { if consumer.Guards.Contains(guard) && !consumer.GuardMatched { - consumers[i] = &annotation.ConsumeTrigger{ - Annotation: consumer.Annotation, - Expr: consumer.Expr, - Guards: consumer.Guards, - GuardMatched: true, - } + consumers[i].GuardMatched = true } } case ProduceAsNonnil: diff --git a/assertion/function/testdata/src/go.uber.org/backprop/fixpoint.go b/assertion/function/testdata/src/go.uber.org/backprop/fixpoint.go index e72da237..9696bf10 100644 --- a/assertion/function/testdata/src/go.uber.org/backprop/fixpoint.go +++ b/assertion/function/testdata/src/go.uber.org/backprop/fixpoint.go @@ -125,3 +125,48 @@ func testNonBuiltinNestedIndex(msgSet []*MessageBlock) { // expect_fixpoint: 3 1 _ = *msgBlock.Messages()[len(msgBlock.Messages())-1] } } + +// test for validating that only the necessary number of triggers are created, and +// no extra triggers (e.g., deep triggers) are created. + +func foo(x *int) *int { // expect_fixpoint: 2 1 1 + if x == nil { + return nil + } + return new(int) +} + +func testContract() { // expect_fixpoint: 2 1 2 + a1 := new(int) + b1 := foo(a1) + print(*b1) +} + +func foo2(a *A) *A { // expect_fixpoint: 2 1 15 + if a == nil { + return nil + } + return a.aptr +} + +func testContract2() { // expect_fixpoint: 2 1 8 + b1 := foo2(&A{}) + print(*b1) +} + +type myString []*string + +// nilable(s[]) +func (s *myString) testNamedType() { // expect_fixpoint: 2 1 3 + x := *s + _ = *x[0] +} + +func testNestedPointer() { // expect_fixpoint: 4 2 4 + a1 := &A{} + for i := 0; i < 10; i++ { + a2 := &a1 + (*a2).ptr = new(int) + *a2 = nil + } +} diff --git a/util/util.go b/util/util.go index 8f47a32a..5d45735f 100644 --- a/util/util.go +++ b/util/util.go @@ -42,30 +42,23 @@ var BuiltinAppend = types.Universe.Lookup("append") // BuiltinNew is the builtin "new" function object. var BuiltinNew = types.Universe.Lookup("new") -// TypeIsDeep checks if a type is an expression that directly admits a deep nilability annotation - deep -// nilability annotations on all other types are ignored +// TypeIsDeep checks if a type is an expression that admits deep nilability, such as maps, slices, arrays, etc. +// Only consider pointers to deep types (e.g., `var x *[]int`) as deep type, +// not pointers to basic types (e.g., `var x *int`) or struct types (e.g., `var x *S`) func TypeIsDeep(t types.Type) bool { - _, isDeep := TypeAsDeepType(t) - return isDeep -} - -// TypeAsDeepType checks if a type is an expression that directly admits a deep nilability annotation, -// returning true as its boolean param if so, along with the element type as its `types.Type` param -// nilable(result 0) -func TypeAsDeepType(t types.Type) (types.Type, bool) { - switch t := t.(type) { - case *types.Slice: - return t.Elem(), true - case *types.Array: - return t.Elem(), true - case *types.Map: - return t.Elem(), true - case *types.Chan: - return t.Elem(), true - case *types.Pointer: - return t.Elem(), true + switch UnwrapPtr(t).(type) { + case *types.Slice, *types.Array, *types.Map, *types.Chan, *types.Struct: + return true + case *types.Basic: + return false } - return nil, false + if t, ok := t.(*types.Pointer); ok { + if TypeAsDeeplyStruct(t.Underlying()) == nil { + return true + } + } + + return false } // TypeIsSlice returns true if `t` is of slice type