Skip to content

Commit

Permalink
refactors and copy test
Browse files Browse the repository at this point in the history
  • Loading branch information
sonalmahajan15 committed Nov 5, 2023
1 parent b90fccf commit 597e0e6
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 87 deletions.
4 changes: 2 additions & 2 deletions annotation/consume_trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (t *TriggerIfNonNil) equals(other ConsumingAnnotationTrigger) bool {
// Copy returns a deep copy of this ConsumingAnnotationTrigger
func (t *TriggerIfNonNil) Copy() ConsumingAnnotationTrigger {
copyConsumer := *t
t.Ann = t.Ann.copy()
copyConsumer.Ann = t.Ann.copy()
return &copyConsumer
}

Expand Down Expand Up @@ -140,7 +140,7 @@ func (t *TriggerIfDeepNonNil) equals(other ConsumingAnnotationTrigger) bool {
// Copy returns a deep copy of this ConsumingAnnotationTrigger
func (t *TriggerIfDeepNonNil) Copy() ConsumingAnnotationTrigger {
copyConsumer := *t
t.Ann = t.Ann.copy()
copyConsumer.Ann = t.Ann.copy()
return &copyConsumer
}

Expand Down
108 changes: 61 additions & 47 deletions annotation/consume_trigger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,62 +17,76 @@ package annotation
import (
"testing"

"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
)

type ConsumingAnnotationTriggerTestSuite struct {
EqualsTestSuite
const _interfaceNameConsumingAnnotationTrigger = "ConsumingAnnotationTrigger"

// initStructsConsumingAnnotationTrigger initializes all structs that implement the ConsumingAnnotationTrigger interface
var initStructsConsumingAnnotationTrigger = []any{
&TriggerIfNonNil{Ann: newMockKey()},
&TriggerIfDeepNonNil{Ann: newMockKey()},
&ConsumeTriggerTautology{},
&PtrLoad{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&MapAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&MapWrittenTo{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&SliceAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&FldAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&UseAsErrorResult{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&FldAssign{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&ArgFldPass{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&GlobalVarAssign{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&ArgPass{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&RecvPass{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&InterfaceResultFromImplementation{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&MethodParamFromInterface{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&UseAsReturn{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&UseAsFldOfReturn{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&SliceAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&ArrayAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&PtrAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&MapAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&DeepAssignPrimitive{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&ParamAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&FuncRetAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&VariadicParamAssignDeep{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&FieldAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&GlobalVarAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&ChanAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&LocalVarAssignDeep{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&ChanSend{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&FldEscape{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&UseAsNonErrorRetDependentOnErrorRetNilability{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&UseAsErrorRetWithNilabilityUnknown{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
}

func (s *ConsumingAnnotationTriggerTestSuite) SetupTest() {
s.interfaceName = "ConsumingAnnotationTrigger"
// ConsumingAnnotationTriggerEqualsTestSuite tests for the `equals` method of all the structs that implement
// the `ConsumingAnnotationTrigger` interface.
type ConsumingAnnotationTriggerEqualsTestSuite struct {
EqualsTestSuite
}

mockedKey := new(mockKey)
mockedKey.On("equals", mock.Anything).Return(true)
func (s *ConsumingAnnotationTriggerEqualsTestSuite) SetupTest() {
s.interfaceName = _interfaceNameConsumingAnnotationTrigger
s.initStructs = initStructsConsumingAnnotationTrigger
}

// initialize all structs that implement ConsumingAnnotationTrigger
s.initStructs = []any{
&TriggerIfNonNil{Ann: mockedKey},
&TriggerIfDeepNonNil{Ann: mockedKey},
&ConsumeTriggerTautology{},
&PtrLoad{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&MapAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&MapWrittenTo{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&SliceAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&FldAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&UseAsErrorResult{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&FldAssign{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&ArgFldPass{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&GlobalVarAssign{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&ArgPass{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&RecvPass{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&InterfaceResultFromImplementation{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&MethodParamFromInterface{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&UseAsReturn{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&UseAsFldOfReturn{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&SliceAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}},
&ArrayAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}},
&PtrAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}},
&MapAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}},
&DeepAssignPrimitive{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&ParamAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}},
&FuncRetAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}},
&VariadicParamAssignDeep{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&FieldAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}},
&GlobalVarAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}},
&ChanAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&LocalVarAssignDeep{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&ChanSend{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}},
&FldEscape{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&UseAsNonErrorRetDependentOnErrorRetNilability{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&UseAsErrorRetWithNilabilityUnknown{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
}
func TestConsumingAnnotationTriggerEqualsSuite(t *testing.T) {
t.Parallel()
suite.Run(t, new(ConsumingAnnotationTriggerEqualsTestSuite))
}

// TestConsumingAnnotationTriggerEqualsSuite runs the test suite for the `equals` method of all the structs that implement
// ConsumingAnnotationTriggerCopyTestSuite tests for the `copy` method of all the structs that implement
// the `ConsumingAnnotationTrigger` interface.
func TestConsumingAnnotationTriggerEqualsSuite(t *testing.T) {
type ConsumingAnnotationTriggerCopyTestSuite struct {
CopyTestSuite
}

func (s *ConsumingAnnotationTriggerCopyTestSuite) SetupTest() {
s.interfaceName = _interfaceNameConsumingAnnotationTrigger
s.initStructs = initStructsConsumingAnnotationTrigger
}
func TestConsumingAnnotationTriggerCopySuite(t *testing.T) {
t.Parallel()
suite.Run(t, new(ConsumingAnnotationTriggerTestSuite))
suite.Run(t, new(ConsumingAnnotationTriggerCopyTestSuite))
}
121 changes: 121 additions & 0 deletions annotation/copy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package annotation

import (
"fmt"
"reflect"
"strings"

"github.com/stretchr/testify/suite"
)

type CopyTestSuite struct {
suite.Suite
initStructs []any
interfaceName string
packagePath string
}

type objInfo struct {
addr string
numFields int
typ reflect.Type
}

func newObjInfo(addr string, numFields int, typ reflect.Type) objInfo {
return objInfo{
addr: addr,
numFields: numFields,
typ: typ,
}
}

// getObjInfo is a helper function that returns a map of struct and field names to their objInfo.
// The key is in the format of `struct_<struct name>` or `fld_<struct name>.<field name>`.
func getObjInfo(obj any) map[string]objInfo {
ptr := make(map[string]objInfo)

val := reflect.ValueOf(obj).Elem()
ptr[fmt.Sprintf("struct_%s", val.Type().Name())] = newObjInfo(fmt.Sprintf("%p", val.Addr().Interface()), val.NumField(), val.Type())
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
key := fmt.Sprintf("fld_%s.%s", val.Type().Name(), val.Type().Field(i).Name)
if field.Kind() == reflect.Ptr {
if !field.IsZero() {
ptr[key] = newObjInfo(fmt.Sprintf("%p", field.Interface()), field.Elem().NumField(), field.Elem().Type())
}
} else if field.Kind() == reflect.Interface && !field.IsNil() {
// %p cannot be used directly with a reflect.Value, so we need to extract the underlying value first.
interfaceValue := field.Interface()
underlyingValue := reflect.ValueOf(interfaceValue).Elem()
ptr[key] = newObjInfo(fmt.Sprintf("%p", underlyingValue.Addr().Interface()), underlyingValue.NumField(), underlyingValue.Type())
} else {
ptr[key] = newObjInfo("", 0, field.Type())
}
}
return ptr
}

// This test checks that the `Copy` method implementations perform a deep copy, i.e., copies the values but generates
// different pointer addresses for the copied struct and its fields.
// Note that here we cannot use `reflect.DeepEqual` to compare the original and copied structs because reflection
// does not work well with fields with nested struct pointers, giving incorrect results.
// Therefore, we compare the original and copied structs along with their fields for:
// - type
// - number of fields
// - pointer address (if the field is a struct and has at least one field)
func (s *CopyTestSuite) TestCopy() {
expectedObjs := make(map[string]objInfo)
actualObjs := make(map[string]objInfo)

for _, initStruct := range s.initStructs {
expectedObjs = getObjInfo(initStruct)

switch t := initStruct.(type) {
case ConsumingAnnotationTrigger:
tCopied := t.Copy()
actualObjs = getObjInfo(tCopied)
case Key:
kCopied := t.copy()
actualObjs = getObjInfo(kCopied)
default:
s.Failf("unknown type", "unknown type %T", t)
}

for expectedKey, expectedObj := range expectedObjs {
actualObj, ok := actualObjs[expectedKey]
s.True(ok, "key `%s` should exist in copied struct object", expectedKey)
s.Equal(expectedObj.typ, actualObj.typ, "key `%s` should have the same type after deep copying", expectedKey)
s.Equal(expectedObj.numFields, actualObj.numFields, "key `%s` should have the same number of fields after deep copying", expectedKey)

// Note that Go optimizes the memory allocation of pointers to structs. The pointer address for structs with
// no fields will be the same. E.g., consider struct `S` with no fields, then `s1 := &S{}, s2 := &S{};
// fmt.Printf("%p %p", s1, s2)` will print the same address. Therefore, we only add the pointer address of a struct
// if it has at least one field. The reason for this being that currently, the use of this helper function is used only in
// the `CopyTestSuite` to check that the `Copy` method implementations perform a deep copy, i.e., generates different
// pointer addresses for the copied struct and its fields. We may want to modify this behavior in the future, if needed.
if expectedObj.addr != "" && actualObj.addr != "" && expectedObj.numFields > 0 && actualObj.numFields > 0 {
s.NotEqual(expectedObj.addr, actualObj.addr, "key `%s` should not have the same pointer value after deep copying", expectedKey)
}
}
}
}

// Similar to EqualsTestSuite, this test serves as a sanity check to ensure that all the implemented consumer structs
// are tested in this file. The test fails if there are any structs that are found missing from the expected list.
func (s *CopyTestSuite) TestStructsChecked() {
missedStructs := structsCheckedTestHelper(s.interfaceName, s.packagePath, s.initStructs)
s.Equalf(0, len(missedStructs), "the following structs were not tested: [`%s`]", strings.Join(missedStructs, "`, `"))
}
18 changes: 1 addition & 17 deletions annotation/equals_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package annotation

import (
"reflect"
"strings"

"github.com/stretchr/testify/suite"
Expand Down Expand Up @@ -88,21 +87,6 @@ func (s *EqualsTestSuite) TestEqualsFalse() {
// using `structsImplementingInterface()`, and finds the actual list of consumer structs that are tested in the
// governing test case. The test fails if there are any structs that are missing from the expected list.
func (s *EqualsTestSuite) TestStructsChecked() {
expected := structsImplementingInterface(s.interfaceName, s.packagePath)
s.NotEmpty(expected, "no structs found implementing `%s` interface", s.interfaceName)

actual := make(map[string]bool)
for _, initStruct := range s.initStructs {
actual[reflect.TypeOf(initStruct).Elem().Name()] = true
}

// compare expected and actual, and find structs that were not tested
var missedStructs []string
for structName := range expected {
if !actual[structName] {
missedStructs = append(missedStructs, structName)
}
}
// if there are any structs that were not tested, fail the test and print the list of structs
missedStructs := structsCheckedTestHelper(s.interfaceName, s.packagePath, s.initStructs)
s.Equalf(0, len(missedStructs), "the following structs were not tested: [`%s`]", strings.Join(missedStructs, "`, `"))
}
36 changes: 36 additions & 0 deletions annotation/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
package annotation

import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"go/types"
"reflect"

"github.com/stretchr/testify/mock"
"go.uber.org/nilaway/util"
Expand Down Expand Up @@ -51,6 +53,19 @@ func (m *mockKey) copy() Key {
return args.Get(0).(Key)
}

func newMockKey() *mockKey {
mockedKey := new(mockKey)
mockedKey.ExpectedCalls = nil
mockedKey.On("equals", mock.Anything).Return(true)

copiedMockKey := new(mockKey)
mockedKey.ExpectedCalls = nil
mockedKey.On("equals", mock.Anything).Return(true)

mockedKey.On("copy").Return(copiedMockKey)
return mockedKey
}

// mockProducingAnnotationTrigger is a mock implementation of the ProducingAnnotationTrigger interface
type mockProducingAnnotationTrigger struct {
mock.Mock
Expand Down Expand Up @@ -212,3 +227,24 @@ func structsImplementingInterface(interfaceName string, packageName ...string) m
}
return structs
}

func structsCheckedTestHelper(interfaceName string, packagePath string, initStructs []any) []string {
expected := structsImplementingInterface(interfaceName, packagePath)
if len(expected) == 0 {
panic(fmt.Sprintf("no structs found implementing `%s` interface", interfaceName))
}

actual := make(map[string]bool)
for _, initStruct := range initStructs {
actual[reflect.TypeOf(initStruct).Elem().Name()] = true
}

// compare expected and actual, and find structs that were not tested
var missedStructs []string
for structName := range expected {
if !actual[structName] {
missedStructs = append(missedStructs, structName)
}
}
return missedStructs
}
Loading

0 comments on commit 597e0e6

Please sign in to comment.