Skip to content

Commit

Permalink
Merge branch 'main' into handle-init-func
Browse files Browse the repository at this point in the history
  • Loading branch information
k4n4ry authored Aug 11, 2024
2 parents dce532a + 4f33d8c commit 00ccc33
Show file tree
Hide file tree
Showing 24 changed files with 1,016 additions and 95 deletions.
9 changes: 9 additions & 0 deletions accumulation/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,16 @@ func checkErrors(triggers []annotation.FullTrigger, annMap annotation.Map, diagn
},
)

// Delete all "always safe" special handlers, since they are not meant to be tested for the no infer case
finalTriggers := make([]annotation.FullTrigger, 0, len(filteredTriggers))
for _, trigger := range filteredTriggers {
if c, ok := trigger.Consumer.Annotation.(*annotation.UseAsReturn); ok && c.IsTrackingAlwaysSafe {
continue
}
finalTriggers = append(finalTriggers, trigger)
}

for _, trigger := range finalTriggers {
// Skip checking any full triggers we created by duplicating from contracted functions
// to the caller function.
if !trigger.CreatedFromDuplication && trigger.Check(annMap) {
Expand Down
6 changes: 4 additions & 2 deletions annotation/consume_trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -1092,15 +1092,17 @@ func DuplicateReturnConsumer(t *ConsumeTrigger, location token.Position) *Consum
// used for functions with contracts since we need to duplicate the sites for context sensitivity.
type UseAsReturn struct {
*TriggerIfNonNil
IsNamedReturn bool
RetStmt *ast.ReturnStmt
IsNamedReturn bool
IsTrackingAlwaysSafe bool
RetStmt *ast.ReturnStmt
}

// equals returns true if the passed ConsumingAnnotationTrigger is equal to this one
func (u *UseAsReturn) equals(other ConsumingAnnotationTrigger) bool {
if other, ok := other.(*UseAsReturn); ok {
return u.TriggerIfNonNil.equals(other.TriggerIfNonNil) &&
u.IsNamedReturn == other.IsNamedReturn &&
u.IsTrackingAlwaysSafe == other.IsTrackingAlwaysSafe &&
u.RetStmt == other.RetStmt
}
return false
Expand Down
5 changes: 4 additions & 1 deletion annotation/produce_trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -755,12 +755,15 @@ func (f FldReturnPrestring) String() string {
// context sensitivity.
type FuncReturn struct {
*TriggerIfNilable

IsFromRichCheckEffectFunc bool
}

// equals returns true if the passed ProducingAnnotationTrigger is equal to this one
func (f *FuncReturn) equals(other ProducingAnnotationTrigger) bool {
if other, ok := other.(*FuncReturn); ok {
return f.TriggerIfNilable.equals(other.TriggerIfNilable)
return f.TriggerIfNilable.equals(other.TriggerIfNilable) &&
f.IsFromRichCheckEffectFunc == other.IsFromRichCheckEffectFunc
}
return false
}
Expand Down
12 changes: 6 additions & 6 deletions assertion/function/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,9 @@ func duplicateFullTriggersFromContractedFunctionsToCallers(
for ctrtFunc, calls := range callsByCtrtFunc {
r := funcResults[ctrtFunc]
if r == nil {
// should not happen since funcResults should contain all the functions including any
// contracted functions.
panic(fmt.Sprintf("Did not find the contracted function %s in funcResults", ctrtFunc.Id()))
// The contracted function is imported from upstream, and the local package analysis
// does not involve it.
continue
}
for _, trigger := range r.triggers {
// If the full trigger has a FuncParam producer or a UseAsReturn consumer, then create
Expand Down Expand Up @@ -312,9 +312,9 @@ func duplicateFullTriggersFromContractedFunctionsToCallers(
for funcObj, triggers := range dupTriggers {
r := funcResults[funcObj]
if r == nil {
// should not happen since funcResults should contain all the functions including any
// contracted functions.
panic(fmt.Sprintf("Did not find the contracted function %s in funcResults", funcObj.Id()))
// Should not happen since we would not have created the duplicated triggers if the
// contracted function is not involved in the analysis of local package.
panic(fmt.Sprintf("did not find the contracted function %s in funcResults", funcObj.Id()))
}
funcTriggers[r.index] = append(funcTriggers[r.index], triggers...)
}
Expand Down
26 changes: 24 additions & 2 deletions assertion/function/assertiontree/backprop.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func backpropAcrossReturn(rootNode *RootAssertionNode, node *ast.ReturnStmt) err
isErrReturning := util.FuncIsErrReturning(funcObj)
isOkReturning := util.FuncIsOkReturning(funcObj)

rootNode.AddNewTriggers(annotation.FullTrigger{
trigger := annotation.FullTrigger{
Producer: &annotation.ProduceTrigger{
// since the value is being returned directly, only its shallow nilability
// matters (but deep would matter if we were enforcing correct variance)
Expand All @@ -230,7 +230,29 @@ func backpropAcrossReturn(rootNode *RootAssertionNode, node *ast.ReturnStmt) err
// interpreted as guarded
GuardMatched: isErrReturning || isOkReturning,
},
})
}

// This is a duplicate trigger for tracking "always safe" paths. The analysis of these triggers
// will be processed at the inference stage.
triggerAlwaysSafe := annotation.FullTrigger{
Producer: trigger.Producer,
Consumer: &annotation.ConsumeTrigger{
Annotation: &annotation.UseAsReturn{
TriggerIfNonNil: &annotation.TriggerIfNonNil{
Ann: annotation.RetKeyFromRetNum(
rootNode.ObjectOf(rootNode.FuncNameIdent()).(*types.Func),
i,
)},
RetStmt: node,
IsTrackingAlwaysSafe: true,
},
Expr: trigger.Consumer.Expr,
Guards: trigger.Consumer.Guards,
GuardMatched: trigger.Consumer.GuardMatched,
},
}

rootNode.AddNewTriggers(trigger, triggerAlwaysSafe)
}
}
rootNode.AddComputation(call)
Expand Down
32 changes: 32 additions & 0 deletions assertion/function/assertiontree/backprop_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,9 @@ func handleErrorReturns(rootNode *RootAssertionNode, retStmt *ast.ReturnStmt, re
errRetExpr := results[errRetIndex] // n-th expression
nonErrRetExpr := results[:errRetIndex] // n-1 expressions

// default tracking to support potential "always safe" cases
createReturnConsumersForAlwaysSafe(rootNode, nonErrRetExpr, retStmt, isNamedReturn)

// check if the error return is at all guarding any nilable returns, such as pointers, maps, and slices
if isErrorReturnNil(rootNode, errRetExpr) {
// if error is the only return expression in the statement, then create a consumer for it, else create consumers for the non-error return expressions
Expand Down Expand Up @@ -329,6 +332,9 @@ func handleBooleanReturns(rootNode *RootAssertionNode, retStmt *ast.ReturnStmt,
return false
}

// default tracking to support potential "always safe" cases
createReturnConsumersForAlwaysSafe(rootNode, nMinusOneRetExpr, retStmt, isNamedReturn)

// If return is "true", then track its n-1 returns. Create return consume triggers for all n-1 return expressions.
// If return is "false", then do nothing, since we don't track boolean values.
if val {
Expand Down Expand Up @@ -371,6 +377,32 @@ func createGeneralReturnConsumers(rootNode *RootAssertionNode, results []ast.Exp
}
}

// createReturnConsumersForAlwaysSafe creates return consumers for the non-return expressions in the return statement
// for tracking potential "always safe" cases
func createReturnConsumersForAlwaysSafe(rootNode *RootAssertionNode, nonErrResults []ast.Expr, retStmt *ast.ReturnStmt, isNamedReturn bool) {
for i := range nonErrResults {
// don't do anything if the expression is a blank identifier ("_")
if util.IsEmptyExpr(nonErrResults[i]) {
continue
}

rootNode.AddConsumption(&annotation.ConsumeTrigger{
Annotation: &annotation.UseAsReturn{
TriggerIfNonNil: &annotation.TriggerIfNonNil{
Ann: &annotation.RetAnnotationKey{
FuncDecl: rootNode.FuncObj(),
RetNum: i,
},
},
IsNamedReturn: isNamedReturn,
IsTrackingAlwaysSafe: true,
RetStmt: retStmt},
Expr: nonErrResults[i],
Guards: util.NoGuards(),
})
}
}

// createSpecialConsumersForAllReturns conservatively creates specially designed consumers for all return expressions, error and non-error
func createSpecialConsumersForAllReturns(rootNode *RootAssertionNode, nonErrRetExpr []ast.Expr, errRetExpr ast.Expr, errRetIndex int, retStmt *ast.ReturnStmt, isNamedReturn bool) {
for i := range nonErrRetExpr {
Expand Down
1 change: 1 addition & 0 deletions assertion/function/assertiontree/parse_expr_producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ func (r *RootAssertionNode) getFuncReturnProducers(ident *ast.Ident, expr *ast.C
// such as "error-nonnil" or "always-nonnil"
NeedsGuard: (isErrReturning || isOkReturning) && i != numResults-1,
},
IsFromRichCheckEffectFunc: isErrReturning || isOkReturning,
},
Expr: expr,
},
Expand Down
13 changes: 12 additions & 1 deletion assertion/function/assertiontree/root_assertion_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,18 @@ func (r *RootAssertionNode) AddConsumption(consumer *annotation.ConsumeTrigger)
path, producers := r.ParseExprAsProducer(consumer.Expr, false)
if path == nil { // expr is not trackable
if producers == nil {
return // expr is not trackable, but cannot be nil, so do nothing
// Here we can infer that the expression is non-nil by definition. Instead of ignoring creation of a trigger,
// particularly for always safe tracking, we create a trigger with ProduceTriggerNever.
if c, ok := consumer.Annotation.(*annotation.UseAsReturn); ok && c.IsTrackingAlwaysSafe {
r.AddNewTriggers(annotation.FullTrigger{
Producer: &annotation.ProduceTrigger{
Annotation: &annotation.ProduceTriggerNever{},
Expr: consumer.Expr,
},
Consumer: consumer,
})
}
return
}
if len(producers) != 1 {
panic("multiply-returning function call was passed to AddConsumption")
Expand Down
62 changes: 57 additions & 5 deletions assertion/function/functioncontracts/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,27 +43,78 @@ var Analyzer = &analysis.Analyzer{
Doc: _doc,
Run: analysishelper.WrapRun(run),
ResultType: reflect.TypeOf((*analysishelper.Result[Map])(nil)),
FactTypes: []analysis.Fact{new(Contracts)},
Requires: []*analysis.Analyzer{config.Analyzer, buildssa.Analyzer},
}

// Contracts represents the list of contracts for a function.
type Contracts []Contract

// AFact enables use of the facts passing mechanism in Go's analysis framework.
func (*Contracts) AFact() {}

// Map stores the mappings from *types.Func to associated function contracts.
type Map map[*types.Func]Contracts

func run(pass *analysis.Pass) (Map, error) {
conf := pass.ResultOf[config.Analyzer].(*config.Config)

if !conf.IsPkgInScope(pass.Pkg) {
return Map{}, nil
return make(Map), nil
}

// Collect contracts from the current package.
contracts, err := collectFunctionContracts(pass)
if err != nil {
return nil, err
}

// The fact mechanism only allows exporting pointer types. However, internally we are using
// `Contract` as a value type because it is an underlying slice type (such that making it a
// pointer type will make the rest of the logic more complicated). Therefore, we strictly
// only convert it from/to a pointer type _here_ during the fact import/exports. Everywhere
// else in NilAway (this sub-analyzer, as well as the other analyzers) we treat `Contract`
// simply as a value type.

// Import contracts from upstream packages and merge it with the local contract map.
for _, fact := range pass.AllObjectFacts() {
fn, ok := fact.Object.(*types.Func)
if !ok {
continue
}
ctrts, ok := fact.Fact.(*Contracts)
if !ok || ctrts == nil {
continue
}
// The existing contracts are imported from upstream packages about upstream functions,
// therefore there should not be any conflicts with contracts collected from the current package.
if _, ok := contracts[fn]; ok {
return nil, fmt.Errorf("function %s has multiple contracts", fn.Name())
}
contracts[fn] = *ctrts
}

// Now, export the contracts for the _exported_ functions in the current package only.
for fn, ctrts := range contracts {
// Check if the function is (1) exported by name (i.e., starts with a capital letter), (2)
// it is directly inside the package scope (such that it is really visible in downstream
// packages).
if fn.Exported() &&
// fn.Scope() -> the scope of the function body.
fn.Scope() != nil &&
// fn.Scope().Parent() -> the scope of the file.
fn.Scope().Parent() != nil &&
// fn.Scope().Parent().Parent() -> the scope of the package.
fn.Scope().Parent().Parent() == pass.Pkg.Scope() {
pass.ExportObjectFact(fn, &ctrts)
}
}
return contracts, nil
}

// functionResult is the struct that is received from the channel for each function.
type functionResult struct {
funcObj *types.Func
contracts []*FunctionContract
contracts Contracts
err error
}

Expand All @@ -72,8 +123,9 @@ type functionResult struct {
// the comments at the top of each function. Only when there are no handwritten contracts there,
// do we try to automatically infer contracts.
func collectFunctionContracts(pass *analysis.Pass) (Map, error) {
// Collect ssa for every function.
conf := pass.ResultOf[config.Analyzer].(*config.Config)

// Collect ssa for every function.
ssaInput := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA)
ssaOfFunc := make(map[*types.Func]*ssa.Function, len(ssaInput.SrcFuncs))
for _, fnssa := range ssaInput.SrcFuncs {
Expand Down Expand Up @@ -155,7 +207,7 @@ func collectFunctionContracts(pass *analysis.Pass) (Map, error) {
defer func() {
if r := recover(); r != nil {
e := fmt.Errorf("INTERNAL PANIC: %s\n%s", r, string(debug.Stack()))
funcChan <- functionResult{err: e, funcObj: funcObj, contracts: []*FunctionContract{}}
funcChan <- functionResult{err: e, funcObj: funcObj}
}
}()

Expand Down
Loading

0 comments on commit 00ccc33

Please sign in to comment.