From b2995c608c55a3b6603e673e087e225567907920 Mon Sep 17 00:00:00 2001 From: Sonal Mahajan <101232472+sonalmahajan15@users.noreply.github.com> Date: Wed, 17 Jan 2024 13:37:16 -0800 Subject: [PATCH] Add deep copy for consumers (#86) This PR modifies the consumers to add support for deep copy when the consumers are being copied in backpropagation. [depends on #80 ] --- annotation/consume_trigger.go | 813 +++++++++++++++--- annotation/consume_trigger_test.go | 108 ++- annotation/copy_test.go | 121 +++ annotation/equals_test.go | 211 +---- annotation/helper_test.go | 250 ++++++ annotation/key.go | 58 ++ annotation/key_test.go | 60 +- assertion/function/assertiontree/backprop.go | 49 +- .../assertiontree/root_assertion_node.go | 2 - assertion/function/assertiontree/util.go | 6 +- nilaway_test.go | 7 + .../go.uber.org/errormessage/errormessage.go | 264 ++++++ util/asthelper/asthelper.go | 104 +++ 13 files changed, 1641 insertions(+), 412 deletions(-) create mode 100644 annotation/copy_test.go create mode 100644 annotation/helper_test.go create mode 100644 testdata/src/go.uber.org/errormessage/errormessage.go create mode 100644 util/asthelper/asthelper.go diff --git a/annotation/consume_trigger.go b/annotation/consume_trigger.go index ae12353f..161eb51c 100644 --- a/annotation/consume_trigger.go +++ b/annotation/consume_trigger.go @@ -22,6 +22,7 @@ import ( "strings" "go.uber.org/nilaway/util" + "go.uber.org/nilaway/util/orderedmap" ) // A ConsumingAnnotationTrigger indicated a possible reason that a nil flow to this site would indicate @@ -50,6 +51,15 @@ type ConsumingAnnotationTrigger interface { // equals returns true if the passed ConsumingAnnotationTrigger is equal to this one equals(ConsumingAnnotationTrigger) bool + + // Copy returns a deep copy of this ConsumingAnnotationTrigger + Copy() ConsumingAnnotationTrigger + + // AddAssignment adds an assignment to the trigger for tracking and printing informative error message. + // NilAway's `backpropAcrossOneToOneAssignment()` lifts consumer triggers from the RHS of an assignment to the LHS. + // This implies loss of information about the assignment. This method is used to track such assignments and print + // a more informative error message. + AddAssignment(Assignment) } // customPos has the below default implementations, in which case ConsumeTrigger.Pos() will return a default value. @@ -65,9 +75,85 @@ type Prestring interface { String() string } +// Assignment is a struct that represents an assignment to an expression +type Assignment struct { + LHSExprStr string + RHSExprStr string + Position token.Position +} + +func (a *Assignment) String() string { + return fmt.Sprintf("`%s` to `%s` at %s", a.RHSExprStr, a.LHSExprStr, a.Position) +} + +// assignmentFlow is a struct that represents a flow of assignments. +// Note that we implement a copy method for this struct, since we want to deep copy the assignments map when we copy +// ConsumerTriggers. However, we don't implement an `equals` method for this struct, since it would incur a performance +// penalty in situations where multiple nilable flows reach a dereference site by creating more full triggers and possibly +// more rounds through backpropagation fix point. Consider the following example: +// +// func f(m map[int]*int) { +// var v *int +// var ok1, ok2 bool +// if cond { +// v, ok1 = m[0] // nilable flow 1, ok1 is false +// } else { +// v, ok2 = m[1] // nilable flow 2, ok2 is false +// } +// _, _ = ok1, ok2 +// _ = *v // nil panic! +// } +// +// Here `v` can be potentiall nilable from two flows: ok1 or ok2 is false. We would like to print only one error message +// for this situation with one representative flow printed in the error message. However, with an `equals` method, we would +// report multiple error messages, one for each flow, by creating multiple full triggers, thereby affecting performance. +type assignmentFlow struct { + // We use ordered map for `assignments` to maintain the order of assignments in the flow, and also to avoid + // duplicates that can get introduced due to fix point convergence in backpropagation. + assignments *orderedmap.OrderedMap[Assignment, bool] +} + +func (a *assignmentFlow) addEntry(entry Assignment) { + if a.assignments == nil { + a.assignments = orderedmap.New[Assignment, bool]() + } + a.assignments.Store(entry, true) +} + +func (a *assignmentFlow) copy() assignmentFlow { + if a.assignments == nil { + return assignmentFlow{} + } + assignments := orderedmap.New[Assignment, bool]() + for _, p := range a.assignments.Pairs { + assignments.Store(p.Key, true) + } + return assignmentFlow{assignments: assignments} +} + +func (a *assignmentFlow) String() string { + if a.assignments == nil || len(a.assignments.Pairs) == 0 { + return "" + } + + // backprop algorithm populates assignment entries in backward order. Reverse entries to get forward order of + // assignments, and store in `strs` slice. + strs := make([]string, 0, len(a.assignments.Pairs)) + for i := len(a.assignments.Pairs) - 1; i >= 0; i-- { + strs = append(strs, a.assignments.Pairs[i].Key.String()) + } + + // build the informative print string tracking the assignments + var sb strings.Builder + sb.WriteString(" via the assignment(s):\n\t\t-> ") + sb.WriteString(strings.Join(strs, ",\n\t\t-> ")) + return sb.String() +} + // TriggerIfNonNil is triggered if the contained Annotation is non-nil type TriggerIfNonNil struct { Ann Key + assignmentFlow } // Kind returns Conditional. @@ -90,21 +176,42 @@ func (t *TriggerIfNonNil) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (t *TriggerIfNonNil) Copy() ConsumingAnnotationTrigger { + copyConsumer := *t + copyConsumer.Ann = t.Ann.copy() + copyConsumer.assignmentFlow = t.assignmentFlow.copy() + return ©Consumer +} + +// AddAssignment adds an assignment to the trigger. +func (t *TriggerIfNonNil) AddAssignment(e Assignment) { + t.assignmentFlow.addEntry(e) +} + // Prestring returns this Prestring as a Prestring -func (*TriggerIfNonNil) Prestring() Prestring { - return TriggerIfNonNilPrestring{} +func (t *TriggerIfNonNil) Prestring() Prestring { + return TriggerIfNonNilPrestring{ + AssignmentStr: t.assignmentFlow.String(), + } } // TriggerIfNonNilPrestring is a Prestring storing the needed information to compactly encode a TriggerIfNonNil -type TriggerIfNonNilPrestring struct{} +type TriggerIfNonNilPrestring struct { + AssignmentStr string +} -func (TriggerIfNonNilPrestring) String() string { - return "nonnil value" +func (t TriggerIfNonNilPrestring) String() string { + var sb strings.Builder + sb.WriteString("nonnil value") + sb.WriteString(t.AssignmentStr) + return sb.String() } // TriggerIfDeepNonNil is triggered if the contained Annotation is deeply non-nil type TriggerIfDeepNonNil struct { Ann Key + assignmentFlow } // Kind returns DeepConditional. @@ -127,20 +234,42 @@ func (t *TriggerIfDeepNonNil) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (t *TriggerIfDeepNonNil) Copy() ConsumingAnnotationTrigger { + copyConsumer := *t + copyConsumer.Ann = t.Ann.copy() + copyConsumer.assignmentFlow = t.assignmentFlow.copy() + return ©Consumer +} + +// AddAssignment adds an assignment to the trigger. +func (t *TriggerIfDeepNonNil) AddAssignment(e Assignment) { + t.assignmentFlow.addEntry(e) +} + // Prestring returns this Prestring as a Prestring -func (*TriggerIfDeepNonNil) Prestring() Prestring { - return TriggerIfDeepNonNilPrestring{} +func (t *TriggerIfDeepNonNil) Prestring() Prestring { + return TriggerIfDeepNonNilPrestring{ + AssignmentStr: t.assignmentFlow.String(), + } } // TriggerIfDeepNonNilPrestring is a Prestring storing the needed information to compactly encode a TriggerIfDeepNonNil -type TriggerIfDeepNonNilPrestring struct{} +type TriggerIfDeepNonNilPrestring struct { + AssignmentStr string +} -func (TriggerIfDeepNonNilPrestring) String() string { - return "deeply nonnil value" +func (t TriggerIfDeepNonNilPrestring) String() string { + var sb strings.Builder + sb.WriteString("deeply nonnil value") + sb.WriteString(t.AssignmentStr) + return sb.String() } // ConsumeTriggerTautology is used at consumption sites were consuming nil is always an error -type ConsumeTriggerTautology struct{} +type ConsumeTriggerTautology struct { + assignmentFlow +} // Kind returns Always. func (*ConsumeTriggerTautology) Kind() TriggerKind { return Always } @@ -157,16 +286,35 @@ func (*ConsumeTriggerTautology) equals(other ConsumingAnnotationTrigger) bool { return ok } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (t *ConsumeTriggerTautology) Copy() ConsumingAnnotationTrigger { + copyConsumer := *t + copyConsumer.assignmentFlow = t.assignmentFlow.copy() + return ©Consumer +} + +// AddAssignment adds an assignment to the trigger. +func (t *ConsumeTriggerTautology) AddAssignment(e Assignment) { + t.assignmentFlow.addEntry(e) +} + // Prestring returns this Prestring as a Prestring -func (*ConsumeTriggerTautology) Prestring() Prestring { - return ConsumeTriggerTautologyPrestring{} +func (t *ConsumeTriggerTautology) Prestring() Prestring { + return ConsumeTriggerTautologyPrestring{ + AssignmentStr: t.assignmentFlow.String(), + } } // ConsumeTriggerTautologyPrestring is a Prestring storing the needed information to compactly encode a ConsumeTriggerTautology -type ConsumeTriggerTautologyPrestring struct{} +type ConsumeTriggerTautologyPrestring struct { + AssignmentStr string +} -func (ConsumeTriggerTautologyPrestring) String() string { - return "must be nonnil" +func (c ConsumeTriggerTautologyPrestring) String() string { + var sb strings.Builder + sb.WriteString("must be nonnil") + sb.WriteString(c.AssignmentStr) + return sb.String() } // PtrLoad is when a value flows to a point where it is loaded as a pointer @@ -182,16 +330,30 @@ func (p *PtrLoad) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (p *PtrLoad) Copy() ConsumingAnnotationTrigger { + copyConsumer := *p + copyConsumer.ConsumeTriggerTautology = p.ConsumeTriggerTautology.Copy().(*ConsumeTriggerTautology) + return ©Consumer +} + // Prestring returns this PtrLoad as a Prestring func (p *PtrLoad) Prestring() Prestring { - return PtrLoadPrestring{} + return PtrLoadPrestring{ + AssignmentStr: p.assignmentFlow.String(), + } } // PtrLoadPrestring is a Prestring storing the needed information to compactly encode a PtrLoad -type PtrLoadPrestring struct{} +type PtrLoadPrestring struct { + AssignmentStr string +} -func (PtrLoadPrestring) String() string { - return "dereferenced" +func (p PtrLoadPrestring) String() string { + var sb strings.Builder + sb.WriteString("dereferenced") + sb.WriteString(p.AssignmentStr) + return sb.String() } // MapAccess is when a map value flows to a point where it is indexed, and thus must be non-nil @@ -209,16 +371,30 @@ func (i *MapAccess) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (i *MapAccess) Copy() ConsumingAnnotationTrigger { + copyConsumer := *i + copyConsumer.ConsumeTriggerTautology = i.ConsumeTriggerTautology.Copy().(*ConsumeTriggerTautology) + return ©Consumer +} + // Prestring returns this MapAccess as a Prestring func (i *MapAccess) Prestring() Prestring { - return MapAccessPrestring{} + return MapAccessPrestring{ + AssignmentStr: i.assignmentFlow.String(), + } } // MapAccessPrestring is a Prestring storing the needed information to compactly encode a MapAccess -type MapAccessPrestring struct{} +type MapAccessPrestring struct { + AssignmentStr string +} -func (MapAccessPrestring) String() string { - return "keyed into" +func (i MapAccessPrestring) String() string { + var sb strings.Builder + sb.WriteString("keyed into") + sb.WriteString(i.AssignmentStr) + return sb.String() } // MapWrittenTo is when a map value flows to a point where one of its indices is written to, and thus @@ -235,16 +411,30 @@ func (m *MapWrittenTo) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (m *MapWrittenTo) Copy() ConsumingAnnotationTrigger { + copyConsumer := *m + copyConsumer.ConsumeTriggerTautology = m.ConsumeTriggerTautology.Copy().(*ConsumeTriggerTautology) + return ©Consumer +} + // Prestring returns this MapWrittenTo as a Prestring func (m *MapWrittenTo) Prestring() Prestring { - return MapWrittenToPrestring{} + return MapWrittenToPrestring{ + AssignmentStr: m.assignmentFlow.String(), + } } // MapWrittenToPrestring is a Prestring storing the needed information to compactly encode a MapWrittenTo -type MapWrittenToPrestring struct{} +type MapWrittenToPrestring struct { + AssignmentStr string +} -func (MapWrittenToPrestring) String() string { - return "written to at an index" +func (m MapWrittenToPrestring) String() string { + var sb strings.Builder + sb.WriteString("written to at an index") + sb.WriteString(m.AssignmentStr) + return sb.String() } // SliceAccess is when a slice value flows to a point where it is sliced, and thus must be non-nil @@ -260,16 +450,30 @@ func (s *SliceAccess) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (s *SliceAccess) Copy() ConsumingAnnotationTrigger { + copyConsumer := *s + copyConsumer.ConsumeTriggerTautology = s.ConsumeTriggerTautology.Copy().(*ConsumeTriggerTautology) + return ©Consumer +} + // Prestring returns this SliceAccess as a Prestring func (s *SliceAccess) Prestring() Prestring { - return SliceAccessPrestring{} + return SliceAccessPrestring{ + AssignmentStr: s.assignmentFlow.String(), + } } // SliceAccessPrestring is a Prestring storing the needed information to compactly encode a SliceAccess -type SliceAccessPrestring struct{} +type SliceAccessPrestring struct { + AssignmentStr string +} -func (SliceAccessPrestring) String() string { - return "sliced into" +func (s SliceAccessPrestring) String() string { + var sb strings.Builder + sb.WriteString("sliced into") + sb.WriteString(s.AssignmentStr) + return sb.String() } // FldAccess is when a value flows to a point where a field of it is accessed, and so it must be non-nil @@ -287,6 +491,13 @@ func (f *FldAccess) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (f *FldAccess) Copy() ConsumingAnnotationTrigger { + copyConsumer := *f + copyConsumer.ConsumeTriggerTautology = f.ConsumeTriggerTautology.Copy().(*ConsumeTriggerTautology) + return ©Consumer +} + // Prestring returns this FldAccess as a Prestring func (f *FldAccess) Prestring() Prestring { fieldName, methodName := "", "" @@ -300,22 +511,28 @@ func (f *FldAccess) Prestring() Prestring { } return FldAccessPrestring{ - FieldName: fieldName, - MethodName: methodName, + FieldName: fieldName, + MethodName: methodName, + AssignmentStr: f.assignmentFlow.String(), } } // FldAccessPrestring is a Prestring storing the needed information to compactly encode a FldAccess type FldAccessPrestring struct { - FieldName string - MethodName string + FieldName string + MethodName string + AssignmentStr string } func (f FldAccessPrestring) String() string { + var sb strings.Builder if f.MethodName != "" { - return fmt.Sprintf("called `%s()`", f.MethodName) + sb.WriteString(fmt.Sprintf("called `%s()`", f.MethodName)) + } else { + sb.WriteString(fmt.Sprintf("accessed field `%s`", f.FieldName)) } - return fmt.Sprintf("accessed field `%s`", f.FieldName) + sb.WriteString(f.AssignmentStr) + return sb.String() } // UseAsErrorResult is when a value flows to the error result of a function, where it is expected to be non-nil @@ -336,6 +553,13 @@ func (u *UseAsErrorResult) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (u *UseAsErrorResult) Copy() ConsumingAnnotationTrigger { + copyConsumer := *u + copyConsumer.TriggerIfNonNil = u.TriggerIfNonNil.Copy().(*TriggerIfNonNil) + return ©Consumer +} + // Prestring returns this UseAsErrorResult as a Prestring func (u *UseAsErrorResult) Prestring() Prestring { retAnn := u.Ann.(*RetAnnotationKey) @@ -344,6 +568,7 @@ func (u *UseAsErrorResult) Prestring() Prestring { ReturningFuncStr: retAnn.FuncDecl.Name(), IsNamedReturn: u.IsNamedReturn, RetName: retAnn.FuncDecl.Type().(*types.Signature).Results().At(retAnn.RetNum).Name(), + AssignmentStr: u.assignmentFlow.String(), } } @@ -353,13 +578,18 @@ type UseAsErrorResultPrestring struct { ReturningFuncStr string IsNamedReturn bool RetName string + AssignmentStr string } func (u UseAsErrorResultPrestring) String() string { + var sb strings.Builder if u.IsNamedReturn { - return fmt.Sprintf("returned as named error result `%s` of `%s()`", u.RetName, u.ReturningFuncStr) + sb.WriteString(fmt.Sprintf("returned as named error result `%s` of `%s()`", u.RetName, u.ReturningFuncStr)) + } else { + sb.WriteString(fmt.Sprintf("returned as error result %d of `%s()`", u.Pos, u.ReturningFuncStr)) } - return fmt.Sprintf("returned as error result %d of `%s()`", u.Pos, u.ReturningFuncStr) + sb.WriteString(u.AssignmentStr) + return sb.String() } // overriding position value to point to the raw return statement, which is the source of the potential error @@ -383,21 +613,33 @@ func (f *FldAssign) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (f *FldAssign) Copy() ConsumingAnnotationTrigger { + copyConsumer := *f + copyConsumer.TriggerIfNonNil = f.TriggerIfNonNil.Copy().(*TriggerIfNonNil) + return ©Consumer +} + // Prestring returns this FldAssign as a Prestring func (f *FldAssign) Prestring() Prestring { fldAnn := f.Ann.(*FieldAnnotationKey) return FldAssignPrestring{ - FieldName: fldAnn.FieldDecl.Name(), + FieldName: fldAnn.FieldDecl.Name(), + AssignmentStr: f.assignmentFlow.String(), } } // FldAssignPrestring is a Prestring storing the needed information to compactly encode a FldAssign type FldAssignPrestring struct { - FieldName string + FieldName string + AssignmentStr string } func (f FldAssignPrestring) String() string { - return fmt.Sprintf("assigned into field `%s`", f.FieldName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("assigned into field `%s`", f.FieldName)) + sb.WriteString(f.AssignmentStr) + return sb.String() } // ArgFldPass is when a struct field value (A.f) flows to a point where it is passed to a function with a param of @@ -415,6 +657,13 @@ func (f *ArgFldPass) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (f *ArgFldPass) Copy() ConsumingAnnotationTrigger { + copyConsumer := *f + copyConsumer.TriggerIfNonNil = f.TriggerIfNonNil.Copy().(*TriggerIfNonNil) + return ©Consumer +} + // Prestring returns this ArgFldPass as a Prestring func (f *ArgFldPass) Prestring() Prestring { ann := f.Ann.(*ParamFieldAnnotationKey) @@ -424,33 +673,40 @@ func (f *ArgFldPass) Prestring() Prestring { } return ArgFldPassPrestring{ - FieldName: ann.FieldDecl.Name(), - FuncName: ann.FuncDecl.Name(), - ParamNum: ann.ParamNum, - RecvName: recvName, - IsPassed: f.IsPassed, + FieldName: ann.FieldDecl.Name(), + FuncName: ann.FuncDecl.Name(), + ParamNum: ann.ParamNum, + RecvName: recvName, + IsPassed: f.IsPassed, + AssignmentStr: f.assignmentFlow.String(), } } // ArgFldPassPrestring is a Prestring storing the needed information to compactly encode a ArgFldPass type ArgFldPassPrestring struct { - FieldName string - FuncName string - ParamNum int - RecvName string - IsPassed bool + FieldName string + FuncName string + ParamNum int + RecvName string + IsPassed bool + AssignmentStr string } func (f ArgFldPassPrestring) String() string { + var sb strings.Builder prefix := "" if f.IsPassed { prefix = "assigned to " } if len(f.RecvName) > 0 { - return fmt.Sprintf("%sfield `%s` of method receiver `%s`", prefix, f.FieldName, f.RecvName) + sb.WriteString(fmt.Sprintf("%sfield `%s` of method receiver `%s`", prefix, f.FieldName, f.RecvName)) + } else { + sb.WriteString(fmt.Sprintf("%sfield `%s` of argument %d to `%s()`", prefix, f.FieldName, f.ParamNum, f.FuncName)) } - return fmt.Sprintf("%sfield `%s` of argument %d to `%s()`", prefix, f.FieldName, f.ParamNum, f.FuncName) + + sb.WriteString(f.AssignmentStr) + return sb.String() } // GlobalVarAssign is when a value flows to a point where it is assigned into a global variable @@ -466,21 +722,33 @@ func (g *GlobalVarAssign) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (g *GlobalVarAssign) Copy() ConsumingAnnotationTrigger { + copyConsumer := *g + copyConsumer.TriggerIfNonNil = g.TriggerIfNonNil.Copy().(*TriggerIfNonNil) + return ©Consumer +} + // Prestring returns this GlobalVarAssign as a Prestring func (g *GlobalVarAssign) Prestring() Prestring { varAnn := g.Ann.(*GlobalVarAnnotationKey) return GlobalVarAssignPrestring{ - VarName: varAnn.VarDecl.Name(), + VarName: varAnn.VarDecl.Name(), + AssignmentStr: g.assignmentFlow.String(), } } // GlobalVarAssignPrestring is a Prestring storing the needed information to compactly encode a GlobalVarAssign type GlobalVarAssignPrestring struct { - VarName string + VarName string + AssignmentStr string } func (g GlobalVarAssignPrestring) String() string { - return fmt.Sprintf("assigned into global variable `%s`", g.VarName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("assigned into global variable `%s`", g.VarName)) + sb.WriteString(g.AssignmentStr) + return sb.String() } // ArgPass is when a value flows to a point where it is passed as an argument to a function. This @@ -501,20 +769,29 @@ func (a *ArgPass) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (a *ArgPass) Copy() ConsumingAnnotationTrigger { + copyConsumer := *a + copyConsumer.TriggerIfNonNil = a.TriggerIfNonNil.Copy().(*TriggerIfNonNil) + return ©Consumer +} + // Prestring returns this ArgPass as a Prestring func (a *ArgPass) Prestring() Prestring { switch key := a.Ann.(type) { case *ParamAnnotationKey: return ArgPassPrestring{ - ParamName: key.MinimalString(), - FuncName: key.FuncDecl.Name(), - Location: "", + ParamName: key.MinimalString(), + FuncName: key.FuncDecl.Name(), + Location: "", + AssignmentStr: a.assignmentFlow.String(), } case *CallSiteParamAnnotationKey: return ArgPassPrestring{ - ParamName: key.MinimalString(), - FuncName: key.FuncDecl.Name(), - Location: key.Location.String(), + ParamName: key.MinimalString(), + FuncName: key.FuncDecl.Name(), + Location: key.Location.String(), + AssignmentStr: a.assignmentFlow.String(), } default: panic(fmt.Sprintf( @@ -528,7 +805,8 @@ type ArgPassPrestring struct { FuncName string // Location points to the code location of the argument pass at the call site for a ArgPass // enclosing CallSiteParamAnnotationKey; Location is empty for a ArgPass enclosing ParamAnnotationKey. - Location string + Location string + AssignmentStr string } func (a ArgPassPrestring) String() string { @@ -537,6 +815,7 @@ func (a ArgPassPrestring) String() string { if a.Location != "" { sb.WriteString(fmt.Sprintf(" at %s", a.Location)) } + sb.WriteString(a.AssignmentStr) return sb.String() } @@ -554,21 +833,33 @@ func (a *RecvPass) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (a *RecvPass) Copy() ConsumingAnnotationTrigger { + copyConsumer := *a + copyConsumer.TriggerIfNonNil = a.TriggerIfNonNil.Copy().(*TriggerIfNonNil) + return ©Consumer +} + // Prestring returns this RecvPass as a Prestring func (a *RecvPass) Prestring() Prestring { recvAnn := a.Ann.(*RecvAnnotationKey) return RecvPassPrestring{ - FuncName: recvAnn.FuncDecl.Name(), + FuncName: recvAnn.FuncDecl.Name(), + AssignmentStr: a.assignmentFlow.String(), } } // RecvPassPrestring is a Prestring storing the needed information to compactly encode a RecvPass type RecvPassPrestring struct { - FuncName string + FuncName string + AssignmentStr string } func (a RecvPassPrestring) String() string { - return fmt.Sprintf("used as receiver to call `%s()`", a.FuncName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("used as receiver to call `%s()`", a.FuncName)) + sb.WriteString(a.AssignmentStr) + return sb.String() } // InterfaceResultFromImplementation is when a result is determined to flow from a concrete method to an interface method via implementation @@ -587,6 +878,13 @@ func (i *InterfaceResultFromImplementation) equals(other ConsumingAnnotationTrig return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (i *InterfaceResultFromImplementation) Copy() ConsumingAnnotationTrigger { + copyConsumer := *i + copyConsumer.TriggerIfNonNil = i.TriggerIfNonNil.Copy().(*TriggerIfNonNil) + return ©Consumer +} + // Prestring returns this InterfaceResultFromImplementation as a Prestring func (i *InterfaceResultFromImplementation) Prestring() Prestring { retAnn := i.Ann.(*RetAnnotationKey) @@ -594,19 +892,24 @@ func (i *InterfaceResultFromImplementation) Prestring() Prestring { retAnn.RetNum, util.PartiallyQualifiedFuncName(retAnn.FuncDecl), util.PartiallyQualifiedFuncName(i.ImplementingMethod), + i.assignmentFlow.String(), } } // InterfaceResultFromImplementationPrestring is a Prestring storing the needed information to compactly encode a InterfaceResultFromImplementation type InterfaceResultFromImplementationPrestring struct { - RetNum int - IntName string - ImplName string + RetNum int + IntName string + ImplName string + AssignmentStr string } func (i InterfaceResultFromImplementationPrestring) String() string { - return fmt.Sprintf("returned as result %d from interface method `%s()` (implemented by `%s()`)", - i.RetNum, i.IntName, i.ImplName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("returned as result %d from interface method `%s()` (implemented by `%s()`)", + i.RetNum, i.IntName, i.ImplName)) + sb.WriteString(i.AssignmentStr) + return sb.String() } // MethodParamFromInterface is when a param flows from an interface method to a concrete method via implementation @@ -625,6 +928,13 @@ func (m *MethodParamFromInterface) equals(other ConsumingAnnotationTrigger) bool return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (m *MethodParamFromInterface) Copy() ConsumingAnnotationTrigger { + copyConsumer := *m + copyConsumer.TriggerIfNonNil = m.TriggerIfNonNil.Copy().(*TriggerIfNonNil) + return ©Consumer +} + // Prestring returns this MethodParamFromInterface as a Prestring func (m *MethodParamFromInterface) Prestring() Prestring { paramAnn := m.Ann.(*ParamAnnotationKey) @@ -632,19 +942,24 @@ func (m *MethodParamFromInterface) Prestring() Prestring { paramAnn.ParamNameString(), util.PartiallyQualifiedFuncName(paramAnn.FuncDecl), util.PartiallyQualifiedFuncName(m.InterfaceMethod), + m.assignmentFlow.String(), } } // MethodParamFromInterfacePrestring is a Prestring storing the needed information to compactly encode a MethodParamFromInterface type MethodParamFromInterfacePrestring struct { - ParamName string - ImplName string - IntName string + ParamName string + ImplName string + IntName string + AssignmentStr string } func (m MethodParamFromInterfacePrestring) String() string { - return fmt.Sprintf("passed as parameter `%s` to `%s()` (implementing `%s()`)", - m.ParamName, m.ImplName, m.IntName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("passed as parameter `%s` to `%s()` (implementing `%s()`)", + m.ParamName, m.ImplName, m.IntName)) + sb.WriteString(m.AssignmentStr) + return sb.String() } // DuplicateReturnConsumer duplicates a given consume trigger, assuming the given consumer trigger @@ -686,6 +1001,13 @@ func (u *UseAsReturn) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (u *UseAsReturn) Copy() ConsumingAnnotationTrigger { + copyConsumer := *u + copyConsumer.TriggerIfNonNil = u.TriggerIfNonNil.Copy().(*TriggerIfNonNil) + return ©Consumer +} + // Prestring returns this UseAsReturn as a Prestring func (u *UseAsReturn) Prestring() Prestring { switch key := u.Ann.(type) { @@ -696,6 +1018,7 @@ func (u *UseAsReturn) Prestring() Prestring { u.IsNamedReturn, key.FuncDecl.Type().(*types.Signature).Results().At(key.RetNum).Name(), "", + u.assignmentFlow.String(), } case *CallSiteRetAnnotationKey: return UseAsReturnPrestring{ @@ -704,6 +1027,7 @@ func (u *UseAsReturn) Prestring() Prestring { u.IsNamedReturn, key.FuncDecl.Type().(*types.Signature).Results().At(key.RetNum).Name(), key.Location.String(), + u.assignmentFlow.String(), } default: panic(fmt.Sprintf("Expected RetAnnotationKey or CallSiteRetAnnotationKey but got: %T", key)) @@ -719,7 +1043,8 @@ type UseAsReturnPrestring struct { // Location is empty for a UseAsReturn enclosing RetAnnotationKey. Location points to the // location of the result at the call site for a UseAsReturn enclosing // CallSiteRetAnnotationKey. - Location string + Location string + AssignmentStr string } func (u UseAsReturnPrestring) String() string { @@ -733,6 +1058,7 @@ func (u UseAsReturnPrestring) String() string { if u.Location != "" { sb.WriteString(fmt.Sprintf(" at %s", u.Location)) } + sb.WriteString(u.AssignmentStr) return sb.String() } @@ -758,6 +1084,13 @@ func (u *UseAsFldOfReturn) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (u *UseAsFldOfReturn) Copy() ConsumingAnnotationTrigger { + copyConsumer := *u + copyConsumer.TriggerIfNonNil = u.TriggerIfNonNil.Copy().(*TriggerIfNonNil) + return ©Consumer +} + // Prestring returns this UseAsFldOfReturn as a Prestring func (u *UseAsFldOfReturn) Prestring() Prestring { retAnn := u.Ann.(*RetFieldAnnotationKey) @@ -765,18 +1098,23 @@ func (u *UseAsFldOfReturn) Prestring() Prestring { retAnn.FuncDecl.Name(), retAnn.FieldDecl.Name(), retAnn.RetNum, + u.assignmentFlow.String(), } } // UseAsFldOfReturnPrestring is a Prestring storing the needed information to compactly encode a UseAsFldOfReturn type UseAsFldOfReturnPrestring struct { - FuncName string - FieldName string - RetNum int + FuncName string + FieldName string + RetNum int + AssignmentStr string } func (u UseAsFldOfReturnPrestring) String() string { - return fmt.Sprintf("field `%s` returned by result %d of `%s()`", u.FieldName, u.RetNum, u.FuncName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("field `%s` returned by result %d of `%s()`", u.FieldName, u.RetNum, u.FuncName)) + sb.WriteString(u.AssignmentStr) + return sb.String() } // GetRetFldConsumer returns the UseAsFldOfReturn consume trigger with given retKey and expr @@ -828,21 +1166,33 @@ func (f *SliceAssign) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (f *SliceAssign) Copy() ConsumingAnnotationTrigger { + copyConsumer := *f + copyConsumer.TriggerIfDeepNonNil = f.TriggerIfDeepNonNil.Copy().(*TriggerIfDeepNonNil) + return ©Consumer +} + // Prestring returns this SliceAssign as a Prestring func (f *SliceAssign) Prestring() Prestring { fldAnn := f.Ann.(*TypeNameAnnotationKey) return SliceAssignPrestring{ fldAnn.TypeDecl.Name(), + f.assignmentFlow.String(), } } // SliceAssignPrestring is a Prestring storing the needed information to compactly encode a SliceAssign type SliceAssignPrestring struct { - TypeName string + TypeName string + AssignmentStr string } func (f SliceAssignPrestring) String() string { - return fmt.Sprintf("assigned into a slice of deeply nonnil type `%s`", f.TypeName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("assigned into a slice of deeply nonnil type `%s`", f.TypeName)) + sb.WriteString(f.AssignmentStr) + return sb.String() } // ArrayAssign is when a value flows to a point where it is assigned into an array @@ -858,21 +1208,33 @@ func (a *ArrayAssign) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (a *ArrayAssign) Copy() ConsumingAnnotationTrigger { + copyConsumer := *a + copyConsumer.TriggerIfDeepNonNil = a.TriggerIfDeepNonNil.Copy().(*TriggerIfDeepNonNil) + return ©Consumer +} + // Prestring returns this ArrayAssign as a Prestring func (a *ArrayAssign) Prestring() Prestring { fldAnn := a.Ann.(*TypeNameAnnotationKey) return ArrayAssignPrestring{ fldAnn.TypeDecl.Name(), + a.assignmentFlow.String(), } } // ArrayAssignPrestring is a Prestring storing the needed information to compactly encode a SliceAssign type ArrayAssignPrestring struct { - TypeName string + TypeName string + AssignmentStr string } func (a ArrayAssignPrestring) String() string { - return fmt.Sprintf("assigned into an array of deeply nonnil type `%s`", a.TypeName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("assigned into an array of deeply nonnil type `%s`", a.TypeName)) + sb.WriteString(a.AssignmentStr) + return sb.String() } // PtrAssign is when a value flows to a point where it is assigned into a pointer @@ -888,21 +1250,33 @@ func (f *PtrAssign) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (f *PtrAssign) Copy() ConsumingAnnotationTrigger { + copyConsumer := *f + copyConsumer.TriggerIfDeepNonNil = f.TriggerIfDeepNonNil.Copy().(*TriggerIfDeepNonNil) + return ©Consumer +} + // Prestring returns this PtrAssign as a Prestring func (f *PtrAssign) Prestring() Prestring { fldAnn := f.Ann.(*TypeNameAnnotationKey) return PtrAssignPrestring{ fldAnn.TypeDecl.Name(), + f.assignmentFlow.String(), } } // PtrAssignPrestring is a Prestring storing the needed information to compactly encode a PtrAssign type PtrAssignPrestring struct { - TypeName string + TypeName string + AssignmentStr string } func (f PtrAssignPrestring) String() string { - return fmt.Sprintf("assigned into a pointer of deeply nonnil type `%s`", f.TypeName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("assigned into a pointer of deeply nonnil type `%s`", f.TypeName)) + sb.WriteString(f.AssignmentStr) + return sb.String() } // MapAssign is when a value flows to a point where it is assigned into an annotated map @@ -918,21 +1292,33 @@ func (f *MapAssign) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (f *MapAssign) Copy() ConsumingAnnotationTrigger { + copyConsumer := *f + copyConsumer.TriggerIfDeepNonNil = f.TriggerIfDeepNonNil.Copy().(*TriggerIfDeepNonNil) + return ©Consumer +} + // Prestring returns this MapAssign as a Prestring func (f *MapAssign) Prestring() Prestring { fldAnn := f.Ann.(*TypeNameAnnotationKey) return MapAssignPrestring{ fldAnn.TypeDecl.Name(), + f.assignmentFlow.String(), } } // MapAssignPrestring is a Prestring storing the needed information to compactly encode a MapAssign type MapAssignPrestring struct { - TypeName string + TypeName string + AssignmentStr string } func (f MapAssignPrestring) String() string { - return fmt.Sprintf("assigned into a map of deeply nonnil type `%s`", f.TypeName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("assigned into a map of deeply nonnil type `%s`", f.TypeName)) + sb.WriteString(f.AssignmentStr) + return sb.String() } // DeepAssignPrimitive is when a value flows to a point where it is assigned @@ -949,16 +1335,30 @@ func (d *DeepAssignPrimitive) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (d *DeepAssignPrimitive) Copy() ConsumingAnnotationTrigger { + copyConsumer := *d + copyConsumer.ConsumeTriggerTautology = d.ConsumeTriggerTautology.Copy().(*ConsumeTriggerTautology) + return ©Consumer +} + // Prestring returns this Prestring as a Prestring -func (*DeepAssignPrimitive) Prestring() Prestring { - return DeepAssignPrimitivePrestring{} +func (d *DeepAssignPrimitive) Prestring() Prestring { + return DeepAssignPrimitivePrestring{ + AssignmentStr: d.assignmentFlow.String(), + } } // DeepAssignPrimitivePrestring is a Prestring storing the needed information to compactly encode a DeepAssignPrimitive -type DeepAssignPrimitivePrestring struct{} +type DeepAssignPrimitivePrestring struct { + AssignmentStr string +} -func (DeepAssignPrimitivePrestring) String() string { - return "assigned into a deep type expecting nonnil element type" +func (d DeepAssignPrimitivePrestring) String() string { + var sb strings.Builder + sb.WriteString("assigned into a deep type expecting nonnil element type") + sb.WriteString(d.AssignmentStr) + return sb.String() } // ParamAssignDeep is when a value flows to a point where it is assigned deeply into a function parameter @@ -974,18 +1374,32 @@ func (p *ParamAssignDeep) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (p *ParamAssignDeep) Copy() ConsumingAnnotationTrigger { + copyConsumer := *p + copyConsumer.TriggerIfDeepNonNil = p.TriggerIfDeepNonNil.Copy().(*TriggerIfDeepNonNil) + return ©Consumer +} + // Prestring returns this ParamAssignDeep as a Prestring func (p *ParamAssignDeep) Prestring() Prestring { - return ParamAssignDeepPrestring{p.Ann.(*ParamAnnotationKey).MinimalString()} + return ParamAssignDeepPrestring{ + p.Ann.(*ParamAnnotationKey).MinimalString(), + p.assignmentFlow.String(), + } } // ParamAssignDeepPrestring is a Prestring storing the needed information to compactly encode a ParamAssignDeep type ParamAssignDeepPrestring struct { - ParamName string + ParamName string + AssignmentStr string } func (p ParamAssignDeepPrestring) String() string { - return fmt.Sprintf("assigned deeply into parameter %s", p.ParamName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("assigned deeply into parameter %s", p.ParamName)) + sb.WriteString(p.AssignmentStr) + return sb.String() } // FuncRetAssignDeep is when a value flows to a point where it is assigned deeply into a function return @@ -1001,23 +1415,35 @@ func (f *FuncRetAssignDeep) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (f *FuncRetAssignDeep) Copy() ConsumingAnnotationTrigger { + copyConsumer := *f + copyConsumer.TriggerIfDeepNonNil = f.TriggerIfDeepNonNil.Copy().(*TriggerIfDeepNonNil) + return ©Consumer +} + // Prestring returns this FuncRetAssignDeep as a Prestring func (f *FuncRetAssignDeep) Prestring() Prestring { retAnn := f.Ann.(*RetAnnotationKey) return FuncRetAssignDeepPrestring{ retAnn.FuncDecl.Name(), retAnn.RetNum, + f.assignmentFlow.String(), } } // FuncRetAssignDeepPrestring is a Prestring storing the needed information to compactly encode a FuncRetAssignDeep type FuncRetAssignDeepPrestring struct { - FuncName string - RetNum int + FuncName string + RetNum int + AssignmentStr string } func (f FuncRetAssignDeepPrestring) String() string { - return fmt.Sprintf("assigned deeply into the result %d of `%s()`", f.RetNum, f.FuncName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("assigned deeply into the result %d of `%s()`", f.RetNum, f.FuncName)) + sb.WriteString(f.AssignmentStr) + return sb.String() } // VariadicParamAssignDeep is when a value flows to a point where it is assigned deeply into a variadic @@ -1034,21 +1460,33 @@ func (v *VariadicParamAssignDeep) equals(other ConsumingAnnotationTrigger) bool return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (v *VariadicParamAssignDeep) Copy() ConsumingAnnotationTrigger { + copyConsumer := *v + copyConsumer.TriggerIfNonNil = v.TriggerIfNonNil.Copy().(*TriggerIfNonNil) + return ©Consumer +} + // Prestring returns this VariadicParamAssignDeep as a Prestring func (v *VariadicParamAssignDeep) Prestring() Prestring { paramAnn := v.Ann.(*ParamAnnotationKey) return VariadicParamAssignDeepPrestring{ - ParamName: paramAnn.MinimalString(), + ParamName: paramAnn.MinimalString(), + AssignmentStr: v.assignmentFlow.String(), } } // VariadicParamAssignDeepPrestring is a Prestring storing the needed information to compactly encode a VariadicParamAssignDeep type VariadicParamAssignDeepPrestring struct { - ParamName string + ParamName string + AssignmentStr string } func (v VariadicParamAssignDeepPrestring) String() string { - return fmt.Sprintf("assigned deeply into variadic parameter `%s`", v.ParamName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("assigned deeply into variadic parameter `%s`", v.ParamName)) + sb.WriteString(v.AssignmentStr) + return sb.String() } // FieldAssignDeep is when a value flows to a point where it is assigned deeply into a field @@ -1064,19 +1502,33 @@ func (f *FieldAssignDeep) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (f *FieldAssignDeep) Copy() ConsumingAnnotationTrigger { + copyConsumer := *f + copyConsumer.TriggerIfDeepNonNil = f.TriggerIfDeepNonNil.Copy().(*TriggerIfDeepNonNil) + return ©Consumer +} + // Prestring returns this FieldAssignDeep as a Prestring func (f *FieldAssignDeep) Prestring() Prestring { fldAnn := f.Ann.(*FieldAnnotationKey) - return FieldAssignDeepPrestring{fldAnn.FieldDecl.Name()} + return FieldAssignDeepPrestring{ + fldAnn.FieldDecl.Name(), + f.assignmentFlow.String(), + } } // FieldAssignDeepPrestring is a Prestring storing the needed information to compactly encode a FieldAssignDeep type FieldAssignDeepPrestring struct { - FldName string + FldName string + AssignmentStr string } func (f FieldAssignDeepPrestring) String() string { - return fmt.Sprintf("assigned deeply into field `%s`", f.FldName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("assigned deeply into field `%s`", f.FldName)) + sb.WriteString(f.AssignmentStr) + return sb.String() } // GlobalVarAssignDeep is when a value flows to a point where it is assigned deeply into a global variable @@ -1092,19 +1544,33 @@ func (g *GlobalVarAssignDeep) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (g *GlobalVarAssignDeep) Copy() ConsumingAnnotationTrigger { + copyConsumer := *g + copyConsumer.TriggerIfDeepNonNil = g.TriggerIfDeepNonNil.Copy().(*TriggerIfDeepNonNil) + return ©Consumer +} + // Prestring returns this GlobalVarAssignDeep as a Prestring func (g *GlobalVarAssignDeep) Prestring() Prestring { varAnn := g.Ann.(*GlobalVarAnnotationKey) - return GlobalVarAssignDeepPrestring{varAnn.VarDecl.Name()} + return GlobalVarAssignDeepPrestring{ + varAnn.VarDecl.Name(), + g.assignmentFlow.String(), + } } // GlobalVarAssignDeepPrestring is a Prestring storing the needed information to compactly encode a GlobalVarAssignDeep type GlobalVarAssignDeepPrestring struct { - VarName string + VarName string + AssignmentStr string } func (g GlobalVarAssignDeepPrestring) String() string { - return fmt.Sprintf("assigned deeply into global variable `%s`", g.VarName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("assigned deeply into global variable `%s`", g.VarName)) + sb.WriteString(g.AssignmentStr) + return sb.String() } // ChanAccess is when a channel is accessed for sending, and thus must be non-nil @@ -1120,16 +1586,30 @@ func (c *ChanAccess) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (c *ChanAccess) Copy() ConsumingAnnotationTrigger { + copyConsumer := *c + copyConsumer.ConsumeTriggerTautology = c.ConsumeTriggerTautology.Copy().(*ConsumeTriggerTautology) + return ©Consumer +} + // Prestring returns this MapWrittenTo as a Prestring func (c *ChanAccess) Prestring() Prestring { - return ChanAccessPrestring{} + return ChanAccessPrestring{ + AssignmentStr: c.assignmentFlow.String(), + } } // ChanAccessPrestring is a Prestring storing the needed information to compactly encode a ChanAccess -type ChanAccessPrestring struct{} +type ChanAccessPrestring struct { + AssignmentStr string +} -func (ChanAccessPrestring) String() string { - return "uninitialized; nil channel accessed" +func (c ChanAccessPrestring) String() string { + var sb strings.Builder + sb.WriteString("uninitialized; nil channel accessed") + sb.WriteString(c.AssignmentStr) + return sb.String() } // LocalVarAssignDeep is when a value flows to a point where it is assigned deeply into a local variable of deeply nonnil type @@ -1146,18 +1626,32 @@ func (l *LocalVarAssignDeep) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (l *LocalVarAssignDeep) Copy() ConsumingAnnotationTrigger { + copyConsumer := *l + copyConsumer.ConsumeTriggerTautology = l.ConsumeTriggerTautology.Copy().(*ConsumeTriggerTautology) + return ©Consumer +} + // Prestring returns this LocalVarAssignDeep as a Prestring func (l *LocalVarAssignDeep) Prestring() Prestring { - return LocalVarAssignDeepPrestring{VarName: l.LocalVar.Name()} + return LocalVarAssignDeepPrestring{ + VarName: l.LocalVar.Name(), + AssignmentStr: l.assignmentFlow.String(), + } } // LocalVarAssignDeepPrestring is a Prestring storing the needed information to compactly encode a LocalVarAssignDeep type LocalVarAssignDeepPrestring struct { - VarName string + VarName string + AssignmentStr string } func (l LocalVarAssignDeepPrestring) String() string { - return fmt.Sprintf("assigned deeply into local variable `%s`", l.VarName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("assigned deeply into local variable `%s`", l.VarName)) + sb.WriteString(l.AssignmentStr) + return sb.String() } // ChanSend is when a value flows to a point where it is sent to a channel @@ -1173,19 +1667,33 @@ func (c *ChanSend) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (c *ChanSend) Copy() ConsumingAnnotationTrigger { + copyConsumer := *c + copyConsumer.TriggerIfDeepNonNil = c.TriggerIfDeepNonNil.Copy().(*TriggerIfDeepNonNil) + return ©Consumer +} + // Prestring returns this ChanSend as a Prestring func (c *ChanSend) Prestring() Prestring { typeAnn := c.Ann.(*TypeNameAnnotationKey) - return ChanSendPrestring{typeAnn.TypeDecl.Name()} + return ChanSendPrestring{ + typeAnn.TypeDecl.Name(), + c.assignmentFlow.String(), + } } // ChanSendPrestring is a Prestring storing the needed information to compactly encode a ChanSend type ChanSendPrestring struct { - TypeName string + TypeName string + AssignmentStr string } func (c ChanSendPrestring) String() string { - return fmt.Sprintf("sent to channel of deeply nonnil type `%s`", c.TypeName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("sent to channel of deeply nonnil type `%s`", c.TypeName)) + sb.WriteString(c.AssignmentStr) + return sb.String() } // FldEscape is when a nilable value flows through a field of a struct that escapes. @@ -1208,21 +1716,33 @@ func (f *FldEscape) equals(other ConsumingAnnotationTrigger) bool { return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (f *FldEscape) Copy() ConsumingAnnotationTrigger { + copyConsumer := *f + copyConsumer.TriggerIfNonNil = f.TriggerIfNonNil.Copy().(*TriggerIfNonNil) + return ©Consumer +} + // Prestring returns this FldEscape as a Prestring func (f *FldEscape) Prestring() Prestring { ann := f.Ann.(*EscapeFieldAnnotationKey) return FldEscapePrestring{ - FieldName: ann.FieldDecl.Name(), + FieldName: ann.FieldDecl.Name(), + AssignmentStr: f.assignmentFlow.String(), } } // FldEscapePrestring is a Prestring storing the needed information to compactly encode a FldEscape type FldEscapePrestring struct { - FieldName string + FieldName string + AssignmentStr string } func (f FldEscapePrestring) String() string { - return fmt.Sprintf("field `%s` escaped out of our analysis scope (presumed nilable)", f.FieldName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("field `%s` escaped out of our analysis scope (presumed nilable)", f.FieldName)) + sb.WriteString(f.AssignmentStr) + return sb.String() } // UseAsNonErrorRetDependentOnErrorRetNilability is when a value flows to a point where it is returned from an error returning function @@ -1243,6 +1763,13 @@ func (u *UseAsNonErrorRetDependentOnErrorRetNilability) equals(other ConsumingAn return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (u *UseAsNonErrorRetDependentOnErrorRetNilability) Copy() ConsumingAnnotationTrigger { + copyConsumer := *u + copyConsumer.TriggerIfNonNil = u.TriggerIfNonNil.Copy().(*TriggerIfNonNil) + return ©Consumer +} + // Prestring returns this UseAsNonErrorRetDependentOnErrorRetNilability as a Prestring func (u *UseAsNonErrorRetDependentOnErrorRetNilability) Prestring() Prestring { retAnn := u.Ann.(*RetAnnotationKey) @@ -1252,6 +1779,7 @@ func (u *UseAsNonErrorRetDependentOnErrorRetNilability) Prestring() Prestring { retAnn.FuncDecl.Type().(*types.Signature).Results().At(retAnn.RetNum).Name(), retAnn.FuncDecl.Type().(*types.Signature).Results().Len() - 1, u.IsNamedReturn, + u.assignmentFlow.String(), } } @@ -1262,6 +1790,7 @@ type UseAsNonErrorRetDependentOnErrorRetNilabilityPrestring struct { RetName string ErrRetNum int IsNamedReturn bool + AssignmentStr string } func (u UseAsNonErrorRetDependentOnErrorRetNilabilityPrestring) String() string { @@ -1270,8 +1799,11 @@ func (u UseAsNonErrorRetDependentOnErrorRetNilabilityPrestring) String() string via = fmt.Sprintf(" via named return `%s`", u.RetName) } - return fmt.Sprintf("returned from `%s()`%s in position %d when the error return in position %d is not guaranteed to be non-nil through all paths", - u.FuncName, via, u.RetNum, u.ErrRetNum) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("returned from `%s()`%s in position %d when the error return in position %d is not guaranteed to be non-nil through all paths", + u.FuncName, via, u.RetNum, u.ErrRetNum)) + sb.WriteString(u.AssignmentStr) + return sb.String() } // overriding position value to point to the raw return statement, which is the source of the potential error @@ -1300,6 +1832,13 @@ func (u *UseAsErrorRetWithNilabilityUnknown) equals(other ConsumingAnnotationTri return false } +// Copy returns a deep copy of this ConsumingAnnotationTrigger +func (u *UseAsErrorRetWithNilabilityUnknown) Copy() ConsumingAnnotationTrigger { + copyConsumer := *u + copyConsumer.TriggerIfNonNil = u.TriggerIfNonNil.Copy().(*TriggerIfNonNil) + return ©Consumer +} + // Prestring returns this UseAsErrorRetWithNilabilityUnknown as a Prestring func (u *UseAsErrorRetWithNilabilityUnknown) Prestring() Prestring { retAnn := u.Ann.(*RetAnnotationKey) @@ -1308,6 +1847,7 @@ func (u *UseAsErrorRetWithNilabilityUnknown) Prestring() Prestring { retAnn.RetNum, u.IsNamedReturn, retAnn.FuncDecl.Type().(*types.Signature).Results().At(retAnn.RetNum).Name(), + u.assignmentFlow.String(), } } @@ -1317,13 +1857,18 @@ type UseAsErrorRetWithNilabilityUnknownPrestring struct { RetNum int IsNamedReturn bool RetName string + AssignmentStr string } func (u UseAsErrorRetWithNilabilityUnknownPrestring) String() string { + var sb strings.Builder if u.IsNamedReturn { - return fmt.Sprintf("found in at least one path of `%s()` for named return `%s` in position %d", u.FuncName, u.RetName, u.RetNum) + sb.WriteString(fmt.Sprintf("found in at least one path of `%s()` for named return `%s` in position %d", u.FuncName, u.RetName, u.RetNum)) + } else { + sb.WriteString(fmt.Sprintf("found in at least one path of `%s()` for return in position %d", u.FuncName, u.RetNum)) } - return fmt.Sprintf("found in at least one path of `%s()` for return in position %d", u.FuncName, u.RetNum) + sb.WriteString(u.AssignmentStr) + return sb.String() } // overriding position value to point to the raw return statement, which is the source of the potential error @@ -1398,6 +1943,14 @@ func (c *ConsumeTrigger) equals(c2 *ConsumeTrigger) bool { } +// Copy returns a deep copy of the ConsumeTrigger +func (c *ConsumeTrigger) Copy() *ConsumeTrigger { + copyTrigger := *c + copyTrigger.Annotation = c.Annotation.Copy() + copyTrigger.Guards = c.Guards.Copy() + return ©Trigger +} + // Pos returns the source position (e.g., line) of the consumer's expression. In special cases, such as named return, it // returns the position of the stored return AST node func (c *ConsumeTrigger) Pos() token.Pos { @@ -1420,7 +1973,7 @@ func MergeConsumeTriggerSlices(left, right []*ConsumeTrigger) []*ConsumeTrigger // intersect guard sets - if a guard isn't present in both branches it can't // be considered present before the branch out[i] = &ConsumeTrigger{ - Annotation: outTrigger.Annotation, + Annotation: outTrigger.Annotation.Copy(), Expr: outTrigger.Expr, Guards: outTrigger.Guards.Intersection(trigger.Guards), GuardMatched: outTrigger.GuardMatched && trigger.GuardMatched, @@ -1448,7 +2001,7 @@ func ConsumeTriggerSliceAsGuarded(slice []*ConsumeTrigger, guards ...util.GuardN var out []*ConsumeTrigger for _, trigger := range slice { out = append(out, &ConsumeTrigger{ - Annotation: trigger.Annotation, + Annotation: trigger.Annotation.Copy(), Expr: trigger.Expr, Guards: trigger.Guards.Copy().Add(guards...), }) diff --git a/annotation/consume_trigger_test.go b/annotation/consume_trigger_test.go index 367628bd..f00a489a 100644 --- a/annotation/consume_trigger_test.go +++ b/annotation/consume_trigger_test.go @@ -17,62 +17,76 @@ package annotation import ( "testing" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" ) -type ConsumingAnnotationTriggerTestSuite struct { - EqualsTestSuite +const _interfaceNameConsumingAnnotationTrigger = "ConsumingAnnotationTrigger" + +// initStructsConsumingAnnotationTrigger initializes all structs that implement the ConsumingAnnotationTrigger interface +var initStructsConsumingAnnotationTrigger = []any{ + &TriggerIfNonNil{Ann: newMockKey()}, + &TriggerIfDeepNonNil{Ann: newMockKey()}, + &ConsumeTriggerTautology{}, + &PtrLoad{ConsumeTriggerTautology: &ConsumeTriggerTautology{}}, + &MapAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}}, + &MapWrittenTo{ConsumeTriggerTautology: &ConsumeTriggerTautology{}}, + &SliceAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}}, + &FldAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}}, + &UseAsErrorResult{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}}, + &FldAssign{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}}, + &ArgFldPass{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}}, + &GlobalVarAssign{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}}, + &ArgPass{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}}, + &RecvPass{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}}, + &InterfaceResultFromImplementation{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}}, + &MethodParamFromInterface{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}}, + &UseAsReturn{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}}, + &UseAsFldOfReturn{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}}, + &SliceAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}}, + &ArrayAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}}, + &PtrAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}}, + &MapAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}}, + &DeepAssignPrimitive{ConsumeTriggerTautology: &ConsumeTriggerTautology{}}, + &ParamAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}}, + &FuncRetAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}}, + &VariadicParamAssignDeep{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}}, + &FieldAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}}, + &GlobalVarAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}}, + &ChanAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}}, + &LocalVarAssignDeep{ConsumeTriggerTautology: &ConsumeTriggerTautology{}}, + &ChanSend{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}}, + &FldEscape{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}}, + &UseAsNonErrorRetDependentOnErrorRetNilability{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}}, + &UseAsErrorRetWithNilabilityUnknown{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}}, } -func (s *ConsumingAnnotationTriggerTestSuite) SetupTest() { - s.interfaceName = "ConsumingAnnotationTrigger" +// ConsumingAnnotationTriggerEqualsTestSuite tests for the `equals` method of all the structs that implement +// the `ConsumingAnnotationTrigger` interface. +type ConsumingAnnotationTriggerEqualsTestSuite struct { + EqualsTestSuite +} - mockedKey := new(mockKey) - mockedKey.On("equals", mock.Anything).Return(true) +func (s *ConsumingAnnotationTriggerEqualsTestSuite) SetupTest() { + s.interfaceName = _interfaceNameConsumingAnnotationTrigger + s.initStructs = initStructsConsumingAnnotationTrigger +} - // initialize all structs that implement ConsumingAnnotationTrigger - s.initStructs = []any{ - &TriggerIfNonNil{Ann: mockedKey}, - &TriggerIfDeepNonNil{Ann: mockedKey}, - &ConsumeTriggerTautology{}, - &PtrLoad{ConsumeTriggerTautology: &ConsumeTriggerTautology{}}, - &MapAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}}, - &MapWrittenTo{ConsumeTriggerTautology: &ConsumeTriggerTautology{}}, - &SliceAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}}, - &FldAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}}, - &UseAsErrorResult{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}}, - &FldAssign{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}}, - &ArgFldPass{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}}, - &GlobalVarAssign{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}}, - &ArgPass{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}}, - &RecvPass{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}}, - &InterfaceResultFromImplementation{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}}, - &MethodParamFromInterface{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}}, - &UseAsReturn{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}}, - &UseAsFldOfReturn{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}}, - &SliceAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}}, - &ArrayAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}}, - &PtrAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}}, - &MapAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}}, - &DeepAssignPrimitive{ConsumeTriggerTautology: &ConsumeTriggerTautology{}}, - &ParamAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}}, - &FuncRetAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}}, - &VariadicParamAssignDeep{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}}, - &FieldAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}}, - &GlobalVarAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}}, - &ChanAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}}, - &LocalVarAssignDeep{ConsumeTriggerTautology: &ConsumeTriggerTautology{}}, - &ChanSend{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}}, - &FldEscape{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}}, - &UseAsNonErrorRetDependentOnErrorRetNilability{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}}, - &UseAsErrorRetWithNilabilityUnknown{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}}, - } +func TestConsumingAnnotationTriggerEqualsSuite(t *testing.T) { + t.Parallel() + suite.Run(t, new(ConsumingAnnotationTriggerEqualsTestSuite)) } -// TestConsumingAnnotationTriggerEqualsSuite runs the test suite for the `equals` method of all the structs that implement +// ConsumingAnnotationTriggerCopyTestSuite tests for the `copy` method of all the structs that implement // the `ConsumingAnnotationTrigger` interface. -func TestConsumingAnnotationTriggerEqualsSuite(t *testing.T) { +type ConsumingAnnotationTriggerCopyTestSuite struct { + CopyTestSuite +} + +func (s *ConsumingAnnotationTriggerCopyTestSuite) SetupTest() { + s.interfaceName = _interfaceNameConsumingAnnotationTrigger + s.initStructs = initStructsConsumingAnnotationTrigger +} +func TestConsumingAnnotationTriggerCopySuite(t *testing.T) { t.Parallel() - suite.Run(t, new(ConsumingAnnotationTriggerTestSuite)) + suite.Run(t, new(ConsumingAnnotationTriggerCopyTestSuite)) } diff --git a/annotation/copy_test.go b/annotation/copy_test.go new file mode 100644 index 00000000..f15c5076 --- /dev/null +++ b/annotation/copy_test.go @@ -0,0 +1,121 @@ +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package annotation + +import ( + "fmt" + "reflect" + "strings" + + "github.com/stretchr/testify/suite" +) + +type CopyTestSuite struct { + suite.Suite + initStructs []any + interfaceName string + packagePath string +} + +type objInfo struct { + addr string + numFields int + typ reflect.Type +} + +func newObjInfo(addr string, numFields int, typ reflect.Type) objInfo { + return objInfo{ + addr: addr, + numFields: numFields, + typ: typ, + } +} + +// getObjInfo is a helper function that returns a map of struct and field names to their objInfo. +// The key is in the format of `struct_` or `fld_.`. +func getObjInfo(obj any) map[string]objInfo { + ptr := make(map[string]objInfo) + + val := reflect.ValueOf(obj).Elem() + ptr[fmt.Sprintf("struct_%s", val.Type().Name())] = newObjInfo(fmt.Sprintf("%p", val.Addr().Interface()), val.NumField(), val.Type()) + for i := 0; i < val.NumField(); i++ { + field := val.Field(i) + key := fmt.Sprintf("fld_%s.%s", val.Type().Name(), val.Type().Field(i).Name) + if field.Kind() == reflect.Ptr { + if !field.IsZero() { + ptr[key] = newObjInfo(fmt.Sprintf("%p", field.Interface()), field.Elem().NumField(), field.Elem().Type()) + } + } else if field.Kind() == reflect.Interface && !field.IsNil() { + // %p cannot be used directly with a reflect.Value, so we need to extract the underlying value first. + interfaceValue := field.Interface() + underlyingValue := reflect.ValueOf(interfaceValue).Elem() + ptr[key] = newObjInfo(fmt.Sprintf("%p", underlyingValue.Addr().Interface()), underlyingValue.NumField(), underlyingValue.Type()) + } else { + ptr[key] = newObjInfo("", 0, field.Type()) + } + } + return ptr +} + +// This test checks that the `Copy` method implementations perform a deep copy, i.e., copies the values but generates +// different pointer addresses for the copied struct and its fields. +// Note that here we cannot use `reflect.DeepEqual` to compare the original and copied structs because reflection +// does not work well with fields with nested struct pointers, giving incorrect results. +// Therefore, we compare the original and copied structs along with their fields for: +// - type +// - number of fields +// - pointer address (if the field is a struct and has at least one field) +func (s *CopyTestSuite) TestCopy() { + var expectedObjs, actualObjs map[string]objInfo + + for _, initStruct := range s.initStructs { + var copied any + expectedObjs = getObjInfo(initStruct) + + switch t := initStruct.(type) { + case ConsumingAnnotationTrigger: + copied = t.Copy() + actualObjs = getObjInfo(copied) + case Key: + copied = t.copy() + actualObjs = getObjInfo(copied) + default: + s.Failf("unknown type", "unknown type %T", t) + } + + for expectedKey, expectedObj := range expectedObjs { + actualObj, ok := actualObjs[expectedKey] + s.True(ok, "key `%s` should exist in copied struct object", expectedKey) + s.Equal(expectedObj.typ, actualObj.typ, "key `%s` should have the same type after deep copying", expectedKey) + s.Equal(expectedObj.numFields, actualObj.numFields, "key `%s` should have the same number of fields after deep copying", expectedKey) + + // Note that Go optimizes the memory allocation of pointers to structs. The pointer address for structs with + // no fields will be the same. E.g., consider struct `S` with no fields, then `s1 := &S{}, s2 := &S{}; + // fmt.Printf("%p %p", s1, s2)` will print the same address. Therefore, we only add the pointer address of a struct + // if it has at least one field. The reason for this being that currently, the use of this helper function is used only in + // the `CopyTestSuite` to check that the `Copy` method implementations perform a deep copy, i.e., generates different + // pointer addresses for the copied struct and its fields. We may want to modify this behavior in the future, if needed. + if expectedObj.addr != "" && actualObj.addr != "" && expectedObj.numFields > 0 && actualObj.numFields > 0 { + s.NotEqual(expectedObj.addr, actualObj.addr, "key `%s` should not have the same pointer value after deep copying", expectedKey) + } + } + } +} + +// Similar to EqualsTestSuite, this test serves as a sanity check to ensure that all the implemented consumer structs +// are tested in this file. The test fails if there are any structs that are found missing from the expected list. +func (s *CopyTestSuite) TestStructsChecked() { + missedStructs := structsCheckedTestHelper(s.interfaceName, s.packagePath, s.initStructs) + s.Equalf(0, len(missedStructs), "the following structs were not tested: [`%s`]", strings.Join(missedStructs, "`, `")) +} diff --git a/annotation/equals_test.go b/annotation/equals_test.go index e8b04d08..6ab9e5d5 100644 --- a/annotation/equals_test.go +++ b/annotation/equals_test.go @@ -15,208 +15,14 @@ package annotation import ( - "go/ast" - "go/parser" - "go/token" - "go/types" - "reflect" "strings" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" - "go.uber.org/nilaway/util" - "golang.org/x/exp/maps" - "golang.org/x/tools/go/packages" ) // This test file tests the implementation of the `equals` method defined for the interfaces `ConsumingAnnotationTrigger`, // `ProducingAnnotationTrigger` and `Key`. -// Below are the helper utilities used in the tests, such as mock implementations of the interfaces and utility functions. - -// mockKey is a mock implementation of the Key interface -type mockKey struct { - mock.Mock -} - -func (m *mockKey) Lookup(m2 Map) (Val, bool) { - args := m.Called(m2) - return args.Get(0).(Val), args.Bool(1) -} - -func (m *mockKey) Object() types.Object { - args := m.Called() - return args.Get(0).(types.Object) -} - -func (m *mockKey) equals(other Key) bool { - args := m.Called(other) - return args.Bool(0) -} - -// mockProducingAnnotationTrigger is a mock implementation of the ProducingAnnotationTrigger interface -type mockProducingAnnotationTrigger struct { - mock.Mock -} - -func (m *mockProducingAnnotationTrigger) CheckProduce(m2 Map) bool { - args := m.Called(m2) - return args.Bool(0) -} - -func (m *mockProducingAnnotationTrigger) NeedsGuardMatch() bool { - args := m.Called() - return args.Bool(0) -} - -func (m *mockProducingAnnotationTrigger) SetNeedsGuard(b bool) { - m.Called(b) -} - -func (m *mockProducingAnnotationTrigger) Prestring() Prestring { - args := m.Called() - return args.Get(0).(Prestring) -} - -func (m *mockProducingAnnotationTrigger) Kind() TriggerKind { - args := m.Called() - return args.Get(0).(TriggerKind) -} - -func (m *mockProducingAnnotationTrigger) UnderlyingSite() Key { - args := m.Called() - return args.Get(0).(Key) -} - -func (m *mockProducingAnnotationTrigger) equals(other ProducingAnnotationTrigger) bool { - args := m.Called(other) - return args.Bool(0) -} - -// getImplementedMethods is a helper function that returns all the methods implemented by the struct "t" -func getImplementedMethods(t *types.Named) []*types.Func { - visitedMethods := make(map[string]*types.Func) // helps in only storing the latest overridden implementation of a method - visitedStructs := make(map[*types.Struct]bool) // helps in avoiding infinite recursion if there is a cycle in the struct embedding - collectMethods(t, visitedMethods, visitedStructs) - return maps.Values(visitedMethods) -} - -// collectMethods is a helper function that recursively collects all `methods` implemented by the struct `t`. -// Methods inherited from the embedded and anonymous fields of `t` are collected in a DFS manner. In case of overriding, -// only the overridden implementation of the method is stored with the help of `visitedMethodNames`. For example, -// consider the following illustrative example, and the collected methods at different casting sites. -// ``` -// type S struct { ... } func (s *S) foo() { ... } s := &S{} // methods = [foo()] -// type T struct { S } func (t *T) bar() { ... } t := &T{} // methods = [bar()] -// type U struct { T } u := &U{} // methods = [foo(), bar()] -// ``` -func collectMethods(t *types.Named, visitedMethods map[string]*types.Func, visitedStructs map[*types.Struct]bool) { - for i := 0; i < t.NumMethods(); i++ { - m := t.Method(i) - if _, ok := visitedMethods[m.Name()]; !ok { - visitedMethods[m.Name()] = m - } - } - - // collect methods from embedded fields - if s := util.TypeAsDeeplyStruct(t); s != nil && !visitedStructs[s] { - visitedStructs[s] = true - for i := 0; i < s.NumFields(); i++ { - f := s.Field(i) - if f.Embedded() { - if n, ok := util.UnwrapPtr(f.Type()).(*types.Named); ok { - collectMethods(n, visitedMethods, visitedStructs) - } - } - } - } -} - -// structsImplementingInterface is a helper function that returns all the struct names implementing the given interface -// in the given package recursively -func structsImplementingInterface(interfaceName string, packageName ...string) map[string]bool { - structs := make(map[string]bool) - - // if no package name is provided, default to using the current directory - if len(packageName) == 0 { - packageName = []string{"."} - } - - cfg := &packages.Config{ - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | - packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo, - } - - for _, p := range packageName { - pkgs, err := packages.Load(cfg, p) - if err != nil { - panic(err) - } - if len(pkgs) == 0 { - panic("no packages found") - } - - for _, pkg := range pkgs { - // scan the packages to find the interface and get its *types.Interface object - obj := pkgs[0].Types.Scope().Lookup(interfaceName) - if obj == nil { - continue - } - interfaceObj, ok := obj.Type().Underlying().(*types.Interface) - if !ok { - continue - } - - // iterate over all Go files in the package to find the structs implementing the interface - for _, filepath := range pkg.GoFiles { - fset := token.NewFileSet() - node, err := parser.ParseFile(fset, filepath, nil, parser.AllErrors) - if err != nil { - panic(err) - } - - ast.Inspect(node, func(n ast.Node) bool { - if typeSpec, ok := n.(*ast.TypeSpec); ok { - if _, ok := typeSpec.Type.(*ast.StructType); ok { - sObj := pkg.Types.Scope().Lookup(typeSpec.Name.Name) - if sObj == nil { - return true - } - sType, ok := sObj.Type().(*types.Named) - if !ok { - return true - } - - structMethods := getImplementedMethods(sType) - if interfaceObj.NumMethods() > len(structMethods) { - return true - } - - // compare the methods of the interface and the struct, increment `match` if the method names match - match := 0 - for i := 0; i < interfaceObj.NumMethods(); i++ { - iMethod := interfaceObj.Method(i) - for _, sMethod := range structMethods { - if iMethod.Name() == sMethod.Name() { - match++ - } - } - } - if match == interfaceObj.NumMethods() { - // we have found a struct that implements the interface - structs[typeSpec.Name.Name] = true - } - } - } - return true - }) - } - } - } - return structs -} - -// EqualsTestSuite defines the test suite for the `equals` method. type EqualsTestSuite struct { suite.Suite initStructs []any @@ -281,21 +87,6 @@ func (s *EqualsTestSuite) TestEqualsFalse() { // using `structsImplementingInterface()`, and finds the actual list of consumer structs that are tested in the // governing test case. The test fails if there are any structs that are missing from the expected list. func (s *EqualsTestSuite) TestStructsChecked() { - expected := structsImplementingInterface(s.interfaceName, s.packagePath) - s.NotEmpty(expected, "no structs found implementing `%s` interface", s.interfaceName) - - actual := make(map[string]bool) - for _, initStruct := range s.initStructs { - actual[reflect.TypeOf(initStruct).Elem().Name()] = true - } - - // compare expected and actual, and find structs that were not tested - var missedStructs []string - for structName := range expected { - if !actual[structName] { - missedStructs = append(missedStructs, structName) - } - } - // if there are any structs that were not tested, fail the test and print the list of structs + missedStructs := structsCheckedTestHelper(s.interfaceName, s.packagePath, s.initStructs) s.Equalf(0, len(missedStructs), "the following structs were not tested: [`%s`]", strings.Join(missedStructs, "`, `")) } diff --git a/annotation/helper_test.go b/annotation/helper_test.go new file mode 100644 index 00000000..12375220 --- /dev/null +++ b/annotation/helper_test.go @@ -0,0 +1,250 @@ +// Copyright (c) 2023 Uber Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package annotation + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "go/types" + "reflect" + + "github.com/stretchr/testify/mock" + "go.uber.org/nilaway/util" + "golang.org/x/exp/maps" + "golang.org/x/tools/go/packages" +) + +// mockKey is a mock implementation of the Key interface +type mockKey struct { + mock.Mock +} + +func (m *mockKey) Lookup(m2 Map) (Val, bool) { + args := m.Called(m2) + return args.Get(0).(Val), args.Bool(1) +} + +func (m *mockKey) Object() types.Object { + args := m.Called() + return args.Get(0).(types.Object) +} + +func (m *mockKey) equals(other Key) bool { + args := m.Called(other) + return args.Bool(0) +} + +func (m *mockKey) copy() Key { + args := m.Called() + return args.Get(0).(Key) +} + +func newMockKey() *mockKey { + mockedKey := new(mockKey) + mockedKey.ExpectedCalls = nil + mockedKey.On("equals", mock.Anything).Return(true) + + copiedMockKey := new(mockKey) + mockedKey.ExpectedCalls = nil + mockedKey.On("equals", mock.Anything).Return(true) + + mockedKey.On("copy").Return(copiedMockKey) + return mockedKey +} + +// mockProducingAnnotationTrigger is a mock implementation of the ProducingAnnotationTrigger interface +type mockProducingAnnotationTrigger struct { + mock.Mock +} + +func (m *mockProducingAnnotationTrigger) CheckProduce(m2 Map) bool { + args := m.Called(m2) + return args.Bool(0) +} + +func (m *mockProducingAnnotationTrigger) NeedsGuardMatch() bool { + args := m.Called() + return args.Bool(0) +} + +func (m *mockProducingAnnotationTrigger) SetNeedsGuard(b bool) { + m.Called(b) +} + +func (m *mockProducingAnnotationTrigger) Prestring() Prestring { + args := m.Called() + return args.Get(0).(Prestring) +} + +func (m *mockProducingAnnotationTrigger) Kind() TriggerKind { + args := m.Called() + return args.Get(0).(TriggerKind) +} + +func (m *mockProducingAnnotationTrigger) UnderlyingSite() Key { + args := m.Called() + return args.Get(0).(Key) +} + +func (m *mockProducingAnnotationTrigger) equals(other ProducingAnnotationTrigger) bool { + args := m.Called(other) + return args.Bool(0) +} + +// getImplementedMethods is a helper function that returns all the methods implemented by the struct "t" +func getImplementedMethods(t *types.Named) []*types.Func { + visitedMethods := make(map[string]*types.Func) // helps in only storing the latest overridden implementation of a method + visitedStructs := make(map[*types.Struct]bool) // helps in avoiding infinite recursion if there is a cycle in the struct embedding + collectMethods(t, visitedMethods, visitedStructs) + return maps.Values(visitedMethods) +} + +// collectMethods is a helper function that recursively collects all `methods` implemented by the struct `t`. +// Methods inherited from the embedded and anonymous fields of `t` are collected in a DFS manner. In case of overriding, +// only the overridden implementation of the method is stored with the help of `visitedMethodNames`. For example, +// consider the following illustrative example, and the collected methods at different casting sites. +// ``` +// type S struct { ... } func (s *S) foo() { ... } s := &S{} // methods = [foo()] +// type T struct { S } func (t *T) bar() { ... } t := &T{} // methods = [bar()] +// type U struct { T } u := &U{} // methods = [foo(), bar()] +// ``` +func collectMethods(t *types.Named, visitedMethods map[string]*types.Func, visitedStructs map[*types.Struct]bool) { + for i := 0; i < t.NumMethods(); i++ { + m := t.Method(i) + if _, ok := visitedMethods[m.Name()]; !ok { + visitedMethods[m.Name()] = m + } + } + + // collect methods from embedded fields + if s := util.TypeAsDeeplyStruct(t); s != nil && !visitedStructs[s] { + visitedStructs[s] = true + for i := 0; i < s.NumFields(); i++ { + f := s.Field(i) + if f.Embedded() { + if n, ok := util.UnwrapPtr(f.Type()).(*types.Named); ok { + collectMethods(n, visitedMethods, visitedStructs) + } + } + } + } +} + +// structsImplementingInterface is a helper function that returns all the struct names implementing the given interface +// in the given package recursively +func structsImplementingInterface(interfaceName string, packageName ...string) map[string]bool { + structs := make(map[string]bool) + + // if no package name is provided, default to using the current directory + if len(packageName) == 0 { + packageName = []string{"."} + } + + cfg := &packages.Config{ + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | + packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo, + } + + for _, p := range packageName { + pkgs, err := packages.Load(cfg, p) + if err != nil { + panic(err) + } + if len(pkgs) == 0 { + panic("no packages found") + } + + for _, pkg := range pkgs { + // scan the packages to find the interface and get its *types.Interface object + obj := pkgs[0].Types.Scope().Lookup(interfaceName) + if obj == nil { + continue + } + interfaceObj, ok := obj.Type().Underlying().(*types.Interface) + if !ok { + continue + } + + // iterate over all Go files in the package to find the structs implementing the interface + for _, filepath := range pkg.GoFiles { + fset := token.NewFileSet() + node, err := parser.ParseFile(fset, filepath, nil, parser.AllErrors) + if err != nil { + panic(err) + } + + ast.Inspect(node, func(n ast.Node) bool { + if typeSpec, ok := n.(*ast.TypeSpec); ok { + if _, ok := typeSpec.Type.(*ast.StructType); ok { + sObj := pkg.Types.Scope().Lookup(typeSpec.Name.Name) + if sObj == nil { + return true + } + sType, ok := sObj.Type().(*types.Named) + if !ok { + return true + } + + structMethods := getImplementedMethods(sType) + if interfaceObj.NumMethods() > len(structMethods) { + return true + } + + // compare the methods of the interface and the struct, increment `match` if the method names match + match := 0 + for i := 0; i < interfaceObj.NumMethods(); i++ { + iMethod := interfaceObj.Method(i) + for _, sMethod := range structMethods { + if iMethod.Name() == sMethod.Name() { + match++ + } + } + } + if match == interfaceObj.NumMethods() { + // we have found a struct that implements the interface + structs[typeSpec.Name.Name] = true + } + } + } + return true + }) + } + } + } + return structs +} + +func structsCheckedTestHelper(interfaceName string, packagePath string, initStructs []any) []string { + expected := structsImplementingInterface(interfaceName, packagePath) + if len(expected) == 0 { + panic(fmt.Sprintf("no structs found implementing `%s` interface", interfaceName)) + } + + actual := make(map[string]bool) + for _, initStruct := range initStructs { + actual[reflect.TypeOf(initStruct).Elem().Name()] = true + } + + // compare expected and actual, and find structs that were not tested + var missedStructs []string + for structName := range expected { + if !actual[structName] { + missedStructs = append(missedStructs, structName) + } + } + return missedStructs +} diff --git a/annotation/key.go b/annotation/key.go index a39a7932..ab816047 100644 --- a/annotation/key.go +++ b/annotation/key.go @@ -42,6 +42,9 @@ type Key interface { // equals returns true if the passed key is equal to this key equals(Key) bool + + // copy returns a deep copy of this key + copy() Key } // FieldAnnotationKey allows the Lookup of a field's Annotation in the Annotation map @@ -70,6 +73,11 @@ func (k *FieldAnnotationKey) equals(other Key) bool { return false } +func (k *FieldAnnotationKey) copy() Key { + copyKey := *k + return ©Key +} + func (k *FieldAnnotationKey) String() string { return fmt.Sprintf("Field %s", k.FieldDecl.Name()) } @@ -115,6 +123,11 @@ func (pk *CallSiteParamAnnotationKey) equals(other Key) bool { return false } +func (pk *CallSiteParamAnnotationKey) copy() Key { + copyKey := *pk + return ©Key +} + func (pk *CallSiteParamAnnotationKey) String() string { argname := "" if pk.ParamName() != nil { @@ -241,6 +254,11 @@ func (pk *ParamAnnotationKey) equals(other Key) bool { return false } +func (pk *ParamAnnotationKey) copy() Key { + copyKey := *pk + return ©Key +} + func (pk *ParamAnnotationKey) String() string { argname := "" if pk.ParamName() != nil { @@ -302,6 +320,11 @@ func (rk *CallSiteRetAnnotationKey) equals(other Key) bool { return false } +func (rk *CallSiteRetAnnotationKey) copy() Key { + copyKey := *rk + return ©Key +} + func (rk *CallSiteRetAnnotationKey) String() string { return fmt.Sprintf("Result %d of Function %s at Location %v", rk.RetNum, rk.FuncDecl.Name(), rk.Location) @@ -344,6 +367,11 @@ func (rk *RetAnnotationKey) equals(other Key) bool { return false } +func (rk *RetAnnotationKey) copy() Key { + copyKey := *rk + return ©Key +} + func (rk *RetAnnotationKey) String() string { return fmt.Sprintf("Result %d of Function %s", rk.RetNum, rk.FuncDecl.Name()) @@ -383,6 +411,11 @@ func (tk *TypeNameAnnotationKey) equals(other Key) bool { return false } +func (tk *TypeNameAnnotationKey) copy() Key { + copyKey := *tk + return ©Key +} + func (tk *TypeNameAnnotationKey) String() string { return fmt.Sprintf("Type %s", tk.TypeDecl.Name()) } @@ -413,6 +446,11 @@ func (gk *GlobalVarAnnotationKey) equals(other Key) bool { return false } +func (gk *GlobalVarAnnotationKey) copy() Key { + copyKey := *gk + return ©Key +} + func (gk *GlobalVarAnnotationKey) String() string { return fmt.Sprintf("Global Variable %s", gk.VarDecl.Name()) } @@ -449,6 +487,11 @@ func (rf *RetFieldAnnotationKey) equals(other Key) bool { return false } +func (rf *RetFieldAnnotationKey) copy() Key { + copyKey := *rf + return ©Key +} + // String returns a string representation of this annotation key func (rf *RetFieldAnnotationKey) String() string { // If the function has a receiver, we add info in the error message @@ -500,6 +543,11 @@ func (ek *EscapeFieldAnnotationKey) equals(other Key) bool { return false } +func (ek *EscapeFieldAnnotationKey) copy() Key { + copyKey := *ek + return ©Key +} + func (ek *EscapeFieldAnnotationKey) String() string { return fmt.Sprintf("escaped Field %s", ek.FieldDecl.Name()) } @@ -560,6 +608,11 @@ func (pf *ParamFieldAnnotationKey) equals(other Key) bool { return false } +func (pf *ParamFieldAnnotationKey) copy() Key { + copyKey := *pf + return ©Key +} + // String returns a string representation of this annotation key for ParamFieldAnnotationKey func (pf *ParamFieldAnnotationKey) String() string { argName := "" @@ -617,6 +670,11 @@ func (rk *RecvAnnotationKey) equals(other Key) bool { return false } +func (rk *RecvAnnotationKey) copy() Key { + copyKey := *rk + return ©Key +} + func (rk *RecvAnnotationKey) String() string { return fmt.Sprintf("Receiver of Method %s", rk.FuncDecl.Name()) } diff --git a/annotation/key_test.go b/annotation/key_test.go index cdf24e09..68df3f3e 100644 --- a/annotation/key_test.go +++ b/annotation/key_test.go @@ -20,32 +20,50 @@ import ( "github.com/stretchr/testify/suite" ) -type KeyTestSuite struct { - EqualsTestSuite -} +const _interfaceNameKey = "Key" -func (s *KeyTestSuite) SetupTest() { - s.interfaceName = "Key" - - // initialize all structs that implement the Key interface - s.initStructs = []any{ - &FieldAnnotationKey{}, - &CallSiteParamAnnotationKey{}, - &ParamAnnotationKey{}, - &CallSiteRetAnnotationKey{}, - &RetAnnotationKey{}, - &TypeNameAnnotationKey{}, - &GlobalVarAnnotationKey{}, - &RecvAnnotationKey{}, - &RetFieldAnnotationKey{}, - &EscapeFieldAnnotationKey{}, - &ParamFieldAnnotationKey{}, - } +// initStructsKey initializes all structs that implement the Key interface +var initStructsKey = []any{ + &FieldAnnotationKey{}, + &CallSiteParamAnnotationKey{}, + &ParamAnnotationKey{}, + &CallSiteRetAnnotationKey{}, + &RetAnnotationKey{}, + &TypeNameAnnotationKey{}, + &GlobalVarAnnotationKey{}, + &RecvAnnotationKey{}, + &RetFieldAnnotationKey{}, + &EscapeFieldAnnotationKey{}, + &ParamFieldAnnotationKey{}, } // TestKeyEqualsSuite runs the test suite for the `equals` method of all the structs that implement // the `Key` interface. +type KeyEqualsTestSuite struct { + EqualsTestSuite +} + +func (s *KeyEqualsTestSuite) SetupTest() { + s.interfaceName = _interfaceNameKey + s.initStructs = initStructsKey +} + func TestKeyEqualsSuite(t *testing.T) { t.Parallel() - suite.Run(t, new(KeyTestSuite)) + suite.Run(t, new(KeyEqualsTestSuite)) +} + +// TestKeyCopySuite runs the test suite for the `copy` method of all the structs that implement the `Key` interface. +type KeyCopyTestSuite struct { + CopyTestSuite +} + +func (s *KeyCopyTestSuite) SetupTest() { + s.interfaceName = _interfaceNameKey + s.initStructs = initStructsKey +} + +func TestKeyCopySuite(t *testing.T) { + t.Parallel() + suite.Run(t, new(KeyCopyTestSuite)) } diff --git a/assertion/function/assertiontree/backprop.go b/assertion/function/assertiontree/backprop.go index c30bb80f..9383d866 100644 --- a/assertion/function/assertiontree/backprop.go +++ b/assertion/function/assertiontree/backprop.go @@ -27,6 +27,7 @@ import ( "go.uber.org/nilaway/annotation" "go.uber.org/nilaway/config" "go.uber.org/nilaway/util" + "go.uber.org/nilaway/util/asthelper" "golang.org/x/exp/slices" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/cfg" @@ -585,7 +586,27 @@ buildShadowMask: } lhsNode, ok := rootNode.LiftFromPath(lpath) - if ok { + // TODO: below check for `lhsNode != nil` should not be needed when NilAway supports Ok form for + // used-defined functions (tracked issue #77) + if ok && lhsNode != nil { + // Add assignment entries to the consumers of lhsNode for informative printing of errors + for _, c := range lhsNode.ConsumeTriggers() { + var lhsExprStr, rhsExprStr string + var err error + if lhsExprStr, err = asthelper.PrintExpr(lhsVal, rootNode.Pass(), true /* isShortenExpr */); err != nil { + return err + } + if rhsExprStr, err = asthelper.PrintExpr(rhsVal, rootNode.Pass(), true /* isShortenExpr */); err != nil { + return err + } + + c.Annotation.AddAssignment(annotation.Assignment{ + LHSExprStr: lhsExprStr, + RHSExprStr: rhsExprStr, + Position: util.TruncatePosition(util.PosToLocation(lhsVal.Pos(), rootNode.Pass())), + }) + } + // If the lhsVal path is not only trackable but tracked, we add it as // a deferred landing landings = append(landings, deferredLanding{ @@ -609,10 +630,36 @@ buildShadowMask: rootNode.addProductionsForAssignmentFields(fieldProducers, lhsVal) } + // beforeTriggersLastIndex is used to find the newly added triggers on the next line + beforeTriggersLastIndex := len(rootNode.triggers) + rootNode.AddProduction(&annotation.ProduceTrigger{ Annotation: rproducers[0].GetShallow().Annotation, Expr: lhsVal, }, rproducers[0].GetDeepSlice()...) + + // Update consumers of newly added triggers with assignment entries for informative printing of errors + // TODO: the below check `len(rootNode.triggers) == 0` should not be needed, however, it is added to + // satisfy NilAway's analysis + if len(rootNode.triggers) == 0 { + continue + } + for _, t := range rootNode.triggers[beforeTriggersLastIndex:len(rootNode.triggers)] { + var lhsExprStr, rhsExprStr string + var err error + if lhsExprStr, err = asthelper.PrintExpr(lhsVal, rootNode.Pass(), true /* isShortenExpr */); err != nil { + return err + } + if rhsExprStr, err = asthelper.PrintExpr(rhsVal, rootNode.Pass(), true /* isShortenExpr */); err != nil { + return err + } + + t.Consumer.Annotation.AddAssignment(annotation.Assignment{ + LHSExprStr: lhsExprStr, + RHSExprStr: rhsExprStr, + Position: util.TruncatePosition(util.PosToLocation(lhsVal.Pos(), rootNode.Pass())), + }) + } default: return errors.New("rhs expression in a 1-1 assignment was multiply returning - " + "this certainly indicates an error in control flow") diff --git a/assertion/function/assertiontree/root_assertion_node.go b/assertion/function/assertiontree/root_assertion_node.go index 6f8cc4b2..a3fdcde2 100644 --- a/assertion/function/assertiontree/root_assertion_node.go +++ b/assertion/function/assertiontree/root_assertion_node.go @@ -896,8 +896,6 @@ func getFuncLitFromAssignment(ident *ast.Ident) *ast.FuncLit { // } // // ``` -// -// nilable(path, result 0) func (r *RootAssertionNode) LiftFromPath(path TrackableExpr) (AssertionNode, bool) { if path != nil { node, whichChild := r.lookupPath(path) diff --git a/assertion/function/assertiontree/util.go b/assertion/function/assertiontree/util.go index 5207a00a..c70e3ff5 100644 --- a/assertion/function/assertiontree/util.go +++ b/assertion/function/assertiontree/util.go @@ -389,7 +389,11 @@ func CopyNode(node AssertionNode) AssertionNode { fresh.SetChildren(append(fresh.Children(), freshChild)) } - fresh.SetConsumeTriggers(append(make([]*annotation.ConsumeTrigger, 0, len(node.ConsumeTriggers())), node.ConsumeTriggers()...)) + copyConsumers := make([]*annotation.ConsumeTrigger, 0, len(node.ConsumeTriggers())) + for _, c := range node.ConsumeTriggers() { + copyConsumers = append(copyConsumers, c.Copy()) + } + fresh.SetConsumeTriggers(copyConsumers) return fresh } diff --git a/nilaway_test.go b/nilaway_test.go index 1e253d33..3dce8adb 100644 --- a/nilaway_test.go +++ b/nilaway_test.go @@ -232,6 +232,13 @@ func TestConstants(t *testing.T) { analysistest.Run(t, testdata, Analyzer, "go.uber.org/consts") } +func TestErrorMessage(t *testing.T) { + t.Parallel() + + testdata := analysistest.TestData() + analysistest.Run(t, testdata, Analyzer, "go.uber.org/errormessage") +} + func TestPrettyPrint(t *testing.T) { //nolint:paralleltest // We specifically do not set this test to be parallel such that this test is run separately // from the parallel tests. This makes it possible to set the pretty-print flag to true for diff --git a/testdata/src/go.uber.org/errormessage/errormessage.go b/testdata/src/go.uber.org/errormessage/errormessage.go new file mode 100644 index 00000000..d424d7f0 --- /dev/null +++ b/testdata/src/go.uber.org/errormessage/errormessage.go @@ -0,0 +1,264 @@ +// Copyright (c) 2023 Uber Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This package tests _single_ package inference. Due to limitations of `analysistest` framework, +// multi-package inference is tested by our integration test suites. Please see +// `testdata/README.md` for more details. + +// +package errormessage + +import "errors" + +var dummy bool + +func test1(x *int) { + x = nil + print(*x) //want "`nil` to `x`" +} + +func test2(x *int) { + x = nil + y := x + z := y + print(*z) //want "`y` to `z`" +} + +func test3(x *int) { + if dummy { + x = nil + } else { + x = new(int) + } + y := x + z := y + print(*z) //want "`nil` to `x`" +} + +// nilable(f) +type S struct { + f *int +} + +func test4(x *int) { + s := &S{} + x = nil + y := x + z := y + s.f = z + print(*s.f) //want "`z` to `s.f`" +} + +func test5() { + x := new(int) + for i := 0; i < 10; i++ { + print(*x) //want "`nil` to `y`" + var y *int = nil + z := y + x = z + } +} + +func test6() *int { + var x *int = nil + y := x + z := y + return z //want "`nil` to `x`" +} + +func test7() { + var x *int + if dummy { + y := new(int) + x = y + } + print(*x) //want "unassigned variable `x` dereferenced" +} + +func test8() { + x := new(int) + if dummy { + var y *int + x = y + } + print(*x) //want "`y` to `x`" +} + +func test9(m map[int]*int) { + x, _ := m[0] + y := x + print(*y) //want "`m\\[0\\]` to `x`" +} + +func test10(ch chan *int) { + x := <-ch //want "nil channel accessed" + y := x + print(*y) +} + +func callTest10() { + var ch chan *int + test10(ch) +} + +func test11(s []*int) { + x := s[0] //want "`s` sliced into" + y := x + print(*y) +} + +func callTest11() { + var s []*int + test11(s) +} + +func test12(mp map[int]S, i int) { + x := mp[i] // unrelated assignment, should not be printed in the error message + _ = x + + y := mp[i] // unrelated assignment, should not be printed in the error message + _ = y + + s := mp[i] // relevant assignment, should be printed in the error message + consumeS(&s) //want "`mp\\[i\\]` to `s`" +} + +func consumeS(s *S) { + print(s.f) +} + +func retErr() error { + return errors.New("error") +} + +func test13() *int { + if err := retErr(); err != nil { // unrelated assignment, should not be printed in the error message + return nil //want "literal `nil` returned" + } + return new(int) +} + +// below tests check shortening of expressions in assignment messages + +// nilable(s, result 0) +func (s *S) bar(i int) *int { + return nil +} + +// nilable(result 0) +func (s *S) foo(a int, b *int, c string, d bool) *S { + return nil +} + +func test14(x *int, i int) { + s := &S{} + x = s.foo(1, + new(int), + "abc", + true).bar(i) + y := x + print(*y) //want "`s.foo\\(...\\).bar\\(i\\)` to `x`" +} + +func test15(x *int) { + var longVarName, anotherLongVarName, yetAnotherLongName int + s := &S{} + x = s.foo(longVarName, &anotherLongVarName, "abc", true).bar(yetAnotherLongName) + y := x + print(*y) //want "`s.foo\\(...\\).bar\\(...\\)` to `x`" +} + +func test16(mp map[int]*int) { + var aVeryVeryVeryLongIndexVar int + x := mp[aVeryVeryVeryLongIndexVar] + y := x + print(*y) //want "`mp\\[...\\]` to `x`" +} + +func test17(x *int, mp map[int]*int) { + var aVeryVeryVeryLongIndexVar int + s := &S{} + + x = s.foo(1, mp[aVeryVeryVeryLongIndexVar], "abc", true).bar(2) //want "deep read" + y := x + print(*y) //want "`s.foo\\(...\\).bar\\(2\\)` to `x`" +} + +func test18(x *int, mp map[int]*int) { + s := &S{} + x = mp[*(s.foo(1, new(int), "abc", true).bar(2))] //want "dereferenced" + y := x + print(*y) //want "`mp\\[...\\]` to `x`" +} + +func test19() { + mp := make(map[string]*string) + x := mp["("] + y := x + print(*y) //want "`mp\\[\"\\(\"\\]` to `x`" + + x = mp[")"] + y = x + print(*y) //want "`mp\\[\"\\)\"\\]` to `x`" + + x = mp["))"] + y = x + print(*y) //want "`mp\\[...\\]` to `x`" + + x = mp["(("] + y = x + print(*y) //want "`mp\\[...\\]` to `x`" + + x = mp[")))((("] + y = x + print(*y) //want "`mp\\[...\\]` to `x`" + + x = mp[")))((("] + y = x + print(*y) //want "`mp\\[...\\]` to `x`" + + x = mp["(((()"] + y = x + print(*y) //want "`mp\\[...\\]` to `x`" + + x = mp["())))"] + y = x + print(*y) //want "`mp\\[...\\]` to `x`" + + s := &S{} + i := 0 + a := s.foo(1, + new(int), + "({[", + true).bar(i) + b := a + print(*b) //want "`s.foo\\(...\\).bar\\(i\\)` to `a`" +} + +func test20() { + mp := make(map[rune]*rune) + x := mp['('] + y := x + print(*y) //want "`mp\\['\\('\\]` to `x`" + + x = mp[')'] + y = x + print(*y) //want "`mp\\['\\)'\\]` to `x`" +} + +// below test checks that NilAway can handle non-English (non-ASCII) identifiers +func test21() { + var 世界 *int = nil + print(*世界) //want "`nil` to `世界`" +} diff --git a/util/asthelper/asthelper.go b/util/asthelper/asthelper.go new file mode 100644 index 00000000..7d757e57 --- /dev/null +++ b/util/asthelper/asthelper.go @@ -0,0 +1,104 @@ +// Copyright (c) 2023 Uber Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package asthelper implements utility functions for AST. +package asthelper + +import ( + "go/ast" + "go/printer" + "go/token" + "io" + "strings" + + "golang.org/x/tools/go/analysis" +) + +// PrintExpr converts AST expression to string, and shortens long expressions if isShortenExpr is true +func PrintExpr(e ast.Expr, pass *analysis.Pass, isShortenExpr bool) (string, error) { + builder := &strings.Builder{} + var err error + + if !isShortenExpr { + err = printer.Fprint(builder, pass.Fset, e) + } else { + // traverse over the AST expression's subtree and shorten long expressions + // (e.g., s.foo(longVarName, anotherLongVarName, someOtherLongVarName) --> s.foo(...)) + err = printExpr(builder, pass.Fset, e) + } + + return builder.String(), err +} + +func printExpr(writer io.Writer, fset *token.FileSet, e ast.Expr) (err error) { + // _shortenExprLen is the maximum length of an expression to be printed in full. The value is set to 3 to account for + // the length of the ellipsis ("..."), which is used to shorten long expressions. + const _shortenExprLen = 3 + + // fullExpr returns true if the expression is short enough (<= _shortenExprLen) to be printed in full + fullExpr := func(node ast.Node) (string, bool) { + switch n := node.(type) { + case *ast.Ident: + if len(n.Name) <= _shortenExprLen { + return n.Name, true + } + case *ast.BasicLit: + if len(n.Value) <= _shortenExprLen { + return n.Value, true + } + } + return "", false + } + + switch node := e.(type) { + case *ast.Ident: + _, err = io.WriteString(writer, node.Name) + + case *ast.SelectorExpr: + if err = printExpr(writer, fset, node.X); err != nil { + return + } + _, err = io.WriteString(writer, "."+node.Sel.Name) + + case *ast.CallExpr: + if err = printExpr(writer, fset, node.Fun); err != nil { + return + } + var argStr string + if len(node.Args) > 0 { + argStr = "..." + if len(node.Args) == 1 { + if a, ok := fullExpr(node.Args[0]); ok { + argStr = a + } + } + } + _, err = io.WriteString(writer, "("+argStr+")") + + case *ast.IndexExpr: + if err = printExpr(writer, fset, node.X); err != nil { + return + } + + indexExpr := "..." + if v, ok := fullExpr(node.Index); ok { + indexExpr = v + } + _, err = io.WriteString(writer, "["+indexExpr+"]") + + default: + err = printer.Fprint(writer, fset, e) + } + return +}