Skip to content

Commit

Permalink
Support for handling of ok form for functions/methods (#157)
Browse files Browse the repository at this point in the history
This pull request introduces support for the `ok` form in both
user-defined and library functions and methods. The implementation
addresses false positives, such as those identified in issue #77.

Currently, the feature is designed to handle explicit boolean returns,
specifically in the form of `return r0, r1, ..., true`. Support for
expression-based returns (e.g., `return r0, r1, ..., flag` or `return
r0, r1, ..., isOk()`) is tricky, as tracking boolean types is currently
not supported in NilAway. We can handle this scenario in the future.

[Closes #77 ]
[Depends on #156 ]
  • Loading branch information
sonalmahajan15 authored Feb 8, 2024
1 parent 71d1444 commit c4313bf
Show file tree
Hide file tree
Showing 8 changed files with 974 additions and 43 deletions.
8 changes: 4 additions & 4 deletions assertion/function/assertiontree/backprop.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,12 @@ func backpropAcrossReturn(rootNode *RootAssertionNode, node *ast.ReturnStmt) err
// this nil check reflects programmer logic
return errors.New("producers variable is nil")
}

isErrReturning := util.FuncIsErrReturning(funcObj)

// since we don't individually track the returns of a multiply returning function,
// we form full triggers for each return whose type doesn't bar nilness
if !util.TypeBarsNilness(funcObj.Type().(*types.Signature).Results().At(i).Type()) {
isErrReturning := util.FuncIsErrReturning(funcObj)
isOkReturning := util.FuncIsOkReturning(funcObj)

rootNode.AddNewTriggers(annotation.FullTrigger{
Producer: &annotation.ProduceTrigger{
// since the value is being returned directly, only its shallow nilability
Expand All @@ -225,7 +225,7 @@ func backpropAcrossReturn(rootNode *RootAssertionNode, node *ast.ReturnStmt) err
// if an error returning function returns directly as the result of
// another error returning function, then its results can safely be
// interpreted as guarded
GuardMatched: isErrReturning,
GuardMatched: isErrReturning || isOkReturning,
},
})
}
Expand Down
60 changes: 54 additions & 6 deletions assertion/function/assertiontree/backprop_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"errors"
"fmt"
"go/ast"
"go/constant"
"go/token"
"go/types"

Expand Down Expand Up @@ -115,6 +116,7 @@ func exprCallsKnownNilableErrFunc(expr ast.Expr) bool {
// in particular, this function is responsible for splitting returns into the cases:
// 1: Normal Return - all results yield consume triggers eventually enforcing their annotated/inferred nilability
// 2: Error Return - consume triggers are created based on the error contract. i.e., based on the nilabiity status of the error return expression
// 3. Ok return - consume triggers are created based on the nilability status of the boolean (`ok`) return expression
func computeAndConsumeResults(rootNode *RootAssertionNode, node *ast.ReturnStmt) error {
// no matter what case the consumption of these returns ends up as - each must be computed
for i := range node.Results {
Expand All @@ -135,8 +137,11 @@ func computeAndConsumeResults(rootNode *RootAssertionNode, node *ast.ReturnStmt)
}

// if the function has named error return variable, then handle specially using the error handling logic
if util.FuncIsErrReturning(rootNode.FuncObj()) {
handleErrorReturns(rootNode, node, results, true /* isNamedReturn */)
if ok := handleErrorReturns(rootNode, node, results, true /* isNamedReturn */); ok {
return nil
}

if ok := handleBooleanReturns(rootNode, node, results, true /* isNamedReturn */); ok {
return nil
}

Expand Down Expand Up @@ -199,8 +204,10 @@ func computeAndConsumeResults(rootNode *RootAssertionNode, node *ast.ReturnStmt)
)
}

if util.FuncIsErrReturning(rootNode.FuncObj()) {
handleErrorReturns(rootNode, node, node.Results, false /* isNamedReturn */)
if ok := handleErrorReturns(rootNode, node, node.Results, false /* isNamedReturn */); ok {
return nil
}
if ok := handleBooleanReturns(rootNode, node, node.Results, false /* isNamedReturn */); ok {
return nil
}

Expand Down Expand Up @@ -263,7 +270,11 @@ func isErrorReturnNonnil(rootNode *RootAssertionNode, errRet ast.Expr) bool {
// (3) if error return value = unknown, create consumers for all returns (error and non-error), and defer applying of the error contract when the nilability status is known, such as at `ProcessEntry`
//
// Note that `results` should be explicitly passed since `retStmt` of a named return will contain no results
func handleErrorReturns(rootNode *RootAssertionNode, retStmt *ast.ReturnStmt, results []ast.Expr, isNamedReturn bool) {
func handleErrorReturns(rootNode *RootAssertionNode, retStmt *ast.ReturnStmt, results []ast.Expr, isNamedReturn bool) bool {
if !util.FuncIsErrReturning(rootNode.FuncObj()) {
return false
}

errRetIndex := len(results) - 1
errRetExpr := results[errRetIndex] // n-th expression
nonErrRetExpr := results[:errRetIndex] // n-1 expressions
Expand All @@ -272,7 +283,7 @@ func handleErrorReturns(rootNode *RootAssertionNode, retStmt *ast.ReturnStmt, re
for _, r := range nonErrRetExpr {
if util.ExprBarsNilness(rootNode.Pass(), r) {
// no need to further analyze and create triggers
return
return true
}
}

Expand Down Expand Up @@ -305,6 +316,43 @@ func handleErrorReturns(rootNode *RootAssertionNode, retStmt *ast.ReturnStmt, re
}
}
}
return true
}

// handleBooleanReturns handles the special case for boolean (`ok`) returning functions (n-th result of type `bool`
// which guards at least one of the first n-1 non-bool results). Similar to the handling of error returning functions,
// for boolean returns, we generate consumers by applying the following boolean contract:
// (1) if boolean return value = true, create consumers for the non-boolean returns
// TODO: currently we support only explicit boolean returns (i.e., `return r0, r1, ..., {true|false}`). We should also support implicit boolean returns, i.e., `return` or `return <expr>` in the future.
//
// handleBooleanReturns returns true if the above contract is satisfied and consumers are created, false otherwise
func handleBooleanReturns(rootNode *RootAssertionNode, retStmt *ast.ReturnStmt, results []ast.Expr, isNamedReturn bool) bool {
// FuncIsOkReturning checks that the length of the results defined for the current function is at least 2, and that
// the last return type is a boolean, the value of which can be determined at compile time (e.g., return true)
if !util.FuncIsOkReturning(rootNode.FuncObj()) {
return false
}

nRetIndex := len(results) - 1
nRetExpr := results[nRetIndex] // n-th expression
nMinusOneRetExpr := results[:nRetIndex] // n-1 expressions

// check if the return statement is of the currently supported explicit boolean return form (`return ..., {true|false}`)
typeAndValue, ok := rootNode.Pass().TypesInfo.Types[nRetExpr]
if !ok {
return false
}
val, ok := constant.Val(typeAndValue.Value).(bool)
if !ok {
return false
}

// If return is "true", then track its n-1 returns. Create return consume triggers for all n-1 return expressions.
// If return is "false", then do nothing, since we don't track boolean values.
if val {
createGeneralReturnConsumers(rootNode, nMinusOneRetExpr, retStmt, isNamedReturn)
}
return true
}

// createConsumerForErrorReturn creates a consumer for the error return enforcing it to be non-nil
Expand Down
78 changes: 62 additions & 16 deletions assertion/function/assertiontree/contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ type RichCheckEffect interface {
// assumed nilable
//
// For proper invalidation, each stored return of a function is treated as a separate effect
// nonnil(err, ret)
type FuncErrRet struct {
root *RootAssertionNode // an associated root node
err TrackableExpr // the `error`-typed return of the function
Expand Down Expand Up @@ -100,16 +99,15 @@ func (f *FuncErrRet) equals(effect RichCheckEffect) bool {
f.guard == otherFuncErrRet.guard
}

// okRead provides a general implementation for the special return form: `v1, v2 := expr`.
// okRead provides a general implementation for the special return form: `v1, v2, ..., ok := expr`.
// Concrete examples of patterns supported are:
// - map ok read: `v, ok := m[k]`
// - channel ok receive: `v, ok := <-ch`
// - function error return: `r0, r1, r2, ..., err := f()`
// nonnil(value, ok)
// - function ok return: `r0, r1, r2, ..., ok := f()`
type okRead struct {
root *RootAssertionNode // an associated root node
value TrackableExpr // `value` could be a value for read from a map or channel, or the return value of a function
ok TrackableExpr // `ok` is boolean "ok" for read from a map or channel, or "err" for return from a function
ok TrackableExpr // `ok` is boolean "ok" for read from a map or channel, or return from a function
guard util.GuardNonce // the guard to be applied on a matching check
}

Expand Down Expand Up @@ -173,6 +171,13 @@ type ChannelOkRecvRefl struct {
okRead
}

// A FuncOkReturn is a RichCheckEffect for the `ok` in `r0, r1, r2, ..., ok := f()`, where the
// function `f` has a final result of type `bool` - and until this is checked all other results are
// assumed nilable. For proper invalidation, each stored return of a function is treated as a separate effect
type FuncOkReturn struct {
okRead
}

// A RichCheckNoop is a placeholder instance of RichCheckEffect that functions as a total noop.
// It is used to allow in place modification of collections of RichCheckEffects.
type RichCheckNoop struct{}
Expand All @@ -195,7 +200,6 @@ func (RichCheckNoop) equals(effect RichCheckEffect) bool {
// RichCheckFromNode analyzes the passed `ast.Node` to see if it generates a rich check effect.
// If it does, that effect is returned along with the boolean true
// If it does not, then `nil, false` is returned.
// nilable(result 0)
func RichCheckFromNode(rootNode *RootAssertionNode, nonceGenerator *util.GuardNonceGenerator, node ast.Node) ([]RichCheckEffect, bool) {
var effects []RichCheckEffect
someEffects := false
Expand Down Expand Up @@ -228,18 +232,18 @@ func parseExpr(rootNode *RootAssertionNode, expr ast.Expr) TrackableExpr {
return parsed
}

// NodeTriggersOkRead is a case of a node creating a rich bool effect for map read and channel receive in the "ok" form.
// It matches on `AssignStmt`s of the form `v, ok := mp[k]` and `v, ok := <-ch`
// nilable(result 0)
// NodeTriggersOkRead is a case of a node creating a rich bool effect for map reads, channel receives, and user-defined
// functions in the "ok" form. Specifically, it matches on `AssignStmt`s of the form
// - `v, ok := mp[k]`
// - `v, ok := <-ch`
// - `r0, r1, r2, ..., ok := f()`
func NodeTriggersOkRead(rootNode *RootAssertionNode, nonceGenerator *util.GuardNonceGenerator, node ast.Node) ([]RichCheckEffect, bool) {
lhs, rhs := asthelper.ExtractLHSRHS(node)
if len(lhs) != 2 || len(rhs) != 1 {
if len(lhs) < 2 || len(rhs) != 1 {
return nil, false
}

valueExpr := lhs[0]
okExpr := lhs[1]
lhsValueParsed := parseExpr(rootNode, valueExpr)
okExpr := lhs[len(lhs)-1]
lhsOkParsed := parseExpr(rootNode, okExpr)
if lhsOkParsed == nil {
// here, the lhs `ok` operand is not trackable so there are no rich effects
Expand All @@ -250,16 +254,22 @@ func NodeTriggersOkRead(rootNode *RootAssertionNode, nonceGenerator *util.GuardN

switch rhs := rhs[0].(type) {
case *ast.IndexExpr:
// this is the case of `v, ok := mp[k]`. Early return if the lhs is not a map read of the expected format
if len(lhs) != 2 {
return nil, false
}

rhsXType := rootNode.Pass().TypesInfo.Types[rhs.X].Type
if util.TypeIsDeeplyMap(rhsXType) {
lhsValueParsed := parseExpr(rootNode, lhs[0])
if lhsValueParsed != nil {
// here, the lhs `value` operand is trackable
effects = append(effects, &MapOkRead{
okRead{
root: rootNode,
value: lhsValueParsed,
ok: lhsOkParsed,
guard: nonceGenerator.Next(valueExpr),
guard: nonceGenerator.Next(lhs[0]),
}})
}

Expand All @@ -275,16 +285,22 @@ func NodeTriggersOkRead(rootNode *RootAssertionNode, nonceGenerator *util.GuardN
}
}
case *ast.UnaryExpr:
// this is the case of `v, ok := <-ch`. Early return if the lhs is not a channel receive of the expected format
if len(lhs) != 2 {
return nil, false
}

rhsXType := rootNode.Pass().TypesInfo.Types[rhs.X].Type
if rhs.Op == token.ARROW && util.TypeIsDeeplyChan(rhsXType) {
lhsValueParsed := parseExpr(rootNode, lhs[0])
if lhsValueParsed != nil {
// here, the lhs `value` operand is trackable
effects = append(effects, &ChannelOkRecv{
okRead{
root: rootNode,
value: lhsValueParsed,
ok: lhsOkParsed,
guard: nonceGenerator.Next(valueExpr),
guard: nonceGenerator.Next(lhs[0]),
}})
}

Expand All @@ -299,6 +315,37 @@ func NodeTriggersOkRead(rootNode *RootAssertionNode, nonceGenerator *util.GuardN
}})
}
}
case *ast.CallExpr:
callIdent := util.FuncIdentFromCallExpr(rhs)
if callIdent == nil {
// this discards the case of an anonymous function
// perhaps in the future we could change this
return nil, false
}

rhsFuncDecl, ok := rootNode.ObjectOf(callIdent).(*types.Func)

if !ok || !util.FuncIsOkReturning(rhsFuncDecl) {
return nil, false
}

// we've found an assignment of vars to an "ok" form function!
for i := 0; i < len(lhs)-1; i++ {
lhsExpr := lhs[i]
lhsValueParsed := parseExpr(rootNode, lhsExpr)
if lhsValueParsed == nil || util.ExprBarsNilness(rootNode.Pass(), lhsExpr) {
// ignore assignments to any variables whose type bars nilness, such as 'int'
continue
}
// here, the lhs `value` operand is trackable
effects = append(effects, &FuncOkReturn{
okRead{
root: rootNode,
value: lhsValueParsed,
ok: lhsOkParsed,
guard: nonceGenerator.Next(lhs[i]),
}})
}
}
if len(effects) > 0 {
return effects, true
Expand All @@ -308,7 +355,6 @@ func NodeTriggersOkRead(rootNode *RootAssertionNode, nonceGenerator *util.GuardN

// NodeTriggersFuncErrRet is a case of a node creating a rich check effect.
// it matches on calls to functions with error-returning types
// nilable(result 0)
func NodeTriggersFuncErrRet(rootNode *RootAssertionNode, nonceGenerator *util.GuardNonceGenerator, node ast.Node) ([]RichCheckEffect, bool) {
lhs, rhs := asthelper.ExtractLHSRHS(node)

Expand Down
4 changes: 3 additions & 1 deletion assertion/function/assertiontree/parse_expr_producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ func (r *RootAssertionNode) getFuncReturnProducers(ident *ast.Ident, expr *ast.C

numResults := util.FuncNumResults(funcObj)
isErrReturning := util.FuncIsErrReturning(funcObj)
isOkReturning := util.FuncIsOkReturning(funcObj)

producers := make([]producer.ParsedProducer, numResults)

Expand All @@ -436,10 +437,11 @@ func (r *RootAssertionNode) getFuncReturnProducers(ident *ast.Ident, expr *ast.C
Annotation: &annotation.FuncReturn{
TriggerIfNilable: &annotation.TriggerIfNilable{
Ann: retKey,

// for an error-returning function, all but the last result are guarded
// TODO: add an annotation that allows more results to escape from guarding
// such as "error-nonnil" or "always-nonnil"
NeedsGuard: isErrReturning && i != numResults-1,
NeedsGuard: (isErrReturning || isOkReturning) && i != numResults-1,
},
},
Expr: expr,
Expand Down
2 changes: 1 addition & 1 deletion nilaway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestContracts(t *testing.T) {
t.Parallel()

testdata := analysistest.TestData()
analysistest.Run(t, testdata, Analyzer, "go.uber.org/contracts")
analysistest.Run(t, testdata, Analyzer, "go.uber.org/contracts", "go.uber.org/contracts/namedtypes")
}

func TestTesting(t *testing.T) {
Expand Down
44 changes: 44 additions & 0 deletions testdata/src/go.uber.org/contracts/namedtypes/namedtypes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) 2023 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package namedtypes: This package tests that named types are correctly handled for contract types (e.g., error
// returning functions and ok-form for functions)
package namedtypes

// the below test uses the built-in name `bool` for creating a user-defined named type. However, the logic for determining
// an ok-form function should not depend on the name `bool`, but the underlying type. This test ensures that the logic.
type bool int

func retPtrBoolNamed() (*int, bool) {
return nil, 0
}

func testNamedBool() {
if v, ok := retPtrBoolNamed(); ok == 0 {
_ = *v // want "dereferenced"
}
}

// Similar to the above test, but with the built-in name `error`
type error int

func retPtrErrorNamed() (*int, error) {
return nil, 0
}

func testNamedError() {
if v, ok := retPtrErrorNamed(); ok == 0 {
_ = *v // want "dereferenced"
}
}
Loading

0 comments on commit c4313bf

Please sign in to comment.