Skip to content

Commit

Permalink
address code review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sonalmahajan15 committed Oct 17, 2023
1 parent e66fcd9 commit 9a5bbf7
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
13 changes: 9 additions & 4 deletions assertion/function/assertiontree/trusted_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ var requireComparators action = func(call *ast.CallExpr, startIndex int, pass *a
funcName := sel.Sel.Name

// We now find the actual and expected expressions, where expected is the constant value that actual expression is
// compared against. For example, in `Equal(1, len(s))`, expected is 1, and actual is `len(s)`. However, the position
// compared against. For example, in `Equal(1, len(s))`, expected is 1, and actual is `s`. However, the position
// of the actual and expected expressions can be swapped, e.g., `Equal(len(s), 1)`. We handle both cases below. For
// example, for length comparison, we search for the slice expression, the other will be treated as length expression.

Expand Down Expand Up @@ -253,6 +253,11 @@ var requireComparators action = func(call *ast.CallExpr, startIndex int, pass *a

// Now, based on the semantics of the function, we can create artificial nonnil checks for
// the following cases.
// - slice length comparison. E.g., `Equal(1, len(s))`, implying len(s) > 0, meaning s is nonnil.
// Here, actualExpr is `s` and expectedExprValue is `_greaterThanZero`, which translates to the binary expression
// `s != nil` being added to the CFG. Similarly, for `Equal(len(s), 0)`, we add `s == nil` to the CFG.
// - nil comparison. E.g., `Equal(nil, err)`, where actualExpr is `err` and expectedExprValue is `_nil`, which
// translates to the binary expression `err == nil` being added to the CFG.
switch funcName {
case "Equal", "Equalf": // len(s) == [positive_int], expr == nil
if expectedExprValue == _greaterThanZero {
Expand All @@ -265,7 +270,7 @@ var requireComparators action = func(call *ast.CallExpr, startIndex int, pass *a
return newNilBinaryExpr(actualExpr, token.NEQ)
}

// Note the check for `argIndex` in the following cases, we need to make sure the slice expr
// Note the check for `actualExprIndex` in the following cases, we need to make sure the slice expr
// is at the correct position since these are inequality checks.
case "Greater", "Greaterf": // len(s) > [non_negative_int]
if actualExprIndex == 0 && (expectedExprValue == _zero || expectedExprValue == _greaterThanZero) {
Expand Down Expand Up @@ -348,7 +353,7 @@ var trustedFuncs = map[trustedFuncSig]trustedFuncAction{
{
kind: _method,
enclosingRegex: regexp.MustCompile(`github\.com/stretchr/testify/(suite\.Suite|assert\.Assertions|require\.Assertions)$`),
funcNameRegex: regexp.MustCompile(`^(Greater(f)?|Less(f)?|(GreaterOr|LessOr)?Equal(f)?|NotEqual(f)?)$`),
funcNameRegex: regexp.MustCompile(`^(Greater(f)?|Less(f)?|Equal(f)?|GreaterOrEqual(f)?|LessOrEqual(f)?|NotEqual(f)?)$`),
}: {action: requireComparators, argIndex: 0},
{
kind: _method,
Expand Down Expand Up @@ -380,7 +385,7 @@ var trustedFuncs = map[trustedFuncSig]trustedFuncAction{
{
kind: _func,
enclosingRegex: regexp.MustCompile(`github\.com/stretchr/testify/(assert|require)$`),
funcNameRegex: regexp.MustCompile(`^(Greater(f)?|Less(f)?|(GreaterOr|LessOr)?Equal(f)?|NotEqual(f)?)$`),
funcNameRegex: regexp.MustCompile(`^(Greater(f)?|Less(f)?|Equal(f)?|GreaterOrEqual(f)?|LessOrEqual(f)?|NotEqual(f)?)$`),
}: {action: requireComparators, argIndex: 1},
{
kind: _func,
Expand Down
10 changes: 10 additions & 0 deletions testdata/src/go.uber.org/testing/trustedfuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,16 @@ func testEqual(t *testing.T, i int, a []int) interface{} {
var x *int
require.Equal(t, nil, x)
print(*x) //want "unassigned variable `x` dereferenced"

case 16:
var x *int
require.Equal(t, x, nil)
print(*x) //want "unassigned variable `x` dereferenced"

case 17:
var x *int
require.NotEqual(t, x, nil)
print(*x)
}
return 0
}
Expand Down

0 comments on commit 9a5bbf7

Please sign in to comment.