Skip to content

Commit

Permalink
Add deep copy for consumers (#86)
Browse files Browse the repository at this point in the history
This PR modifies the consumers to add support for deep copy when the
consumers are being copied in backpropagation.

[depends on #80 ]
  • Loading branch information
sonalmahajan15 authored Jan 17, 2024
1 parent 6814686 commit b2995c6
Show file tree
Hide file tree
Showing 13 changed files with 1,641 additions and 412 deletions.
813 changes: 683 additions & 130 deletions annotation/consume_trigger.go

Large diffs are not rendered by default.

108 changes: 61 additions & 47 deletions annotation/consume_trigger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,62 +17,76 @@ package annotation
import (
"testing"

"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
)

type ConsumingAnnotationTriggerTestSuite struct {
EqualsTestSuite
const _interfaceNameConsumingAnnotationTrigger = "ConsumingAnnotationTrigger"

// initStructsConsumingAnnotationTrigger initializes all structs that implement the ConsumingAnnotationTrigger interface
var initStructsConsumingAnnotationTrigger = []any{
&TriggerIfNonNil{Ann: newMockKey()},
&TriggerIfDeepNonNil{Ann: newMockKey()},
&ConsumeTriggerTautology{},
&PtrLoad{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&MapAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&MapWrittenTo{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&SliceAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&FldAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&UseAsErrorResult{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&FldAssign{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&ArgFldPass{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&GlobalVarAssign{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&ArgPass{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&RecvPass{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&InterfaceResultFromImplementation{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&MethodParamFromInterface{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&UseAsReturn{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&UseAsFldOfReturn{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&SliceAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&ArrayAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&PtrAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&MapAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&DeepAssignPrimitive{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&ParamAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&FuncRetAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&VariadicParamAssignDeep{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&FieldAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&GlobalVarAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&ChanAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&LocalVarAssignDeep{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&ChanSend{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&FldEscape{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&UseAsNonErrorRetDependentOnErrorRetNilability{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&UseAsErrorRetWithNilabilityUnknown{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
}

func (s *ConsumingAnnotationTriggerTestSuite) SetupTest() {
s.interfaceName = "ConsumingAnnotationTrigger"
// ConsumingAnnotationTriggerEqualsTestSuite tests for the `equals` method of all the structs that implement
// the `ConsumingAnnotationTrigger` interface.
type ConsumingAnnotationTriggerEqualsTestSuite struct {
EqualsTestSuite
}

mockedKey := new(mockKey)
mockedKey.On("equals", mock.Anything).Return(true)
func (s *ConsumingAnnotationTriggerEqualsTestSuite) SetupTest() {
s.interfaceName = _interfaceNameConsumingAnnotationTrigger
s.initStructs = initStructsConsumingAnnotationTrigger
}

// initialize all structs that implement ConsumingAnnotationTrigger
s.initStructs = []any{
&TriggerIfNonNil{Ann: mockedKey},
&TriggerIfDeepNonNil{Ann: mockedKey},
&ConsumeTriggerTautology{},
&PtrLoad{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&MapAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&MapWrittenTo{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&SliceAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&FldAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&UseAsErrorResult{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&FldAssign{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&ArgFldPass{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&GlobalVarAssign{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&ArgPass{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&RecvPass{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&InterfaceResultFromImplementation{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&MethodParamFromInterface{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&UseAsReturn{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&UseAsFldOfReturn{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&SliceAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}},
&ArrayAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}},
&PtrAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}},
&MapAssign{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}},
&DeepAssignPrimitive{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&ParamAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}},
&FuncRetAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}},
&VariadicParamAssignDeep{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&FieldAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}},
&GlobalVarAssignDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}},
&ChanAccess{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&LocalVarAssignDeep{ConsumeTriggerTautology: &ConsumeTriggerTautology{}},
&ChanSend{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: mockedKey}},
&FldEscape{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&UseAsNonErrorRetDependentOnErrorRetNilability{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
&UseAsErrorRetWithNilabilityUnknown{TriggerIfNonNil: &TriggerIfNonNil{Ann: mockedKey}},
}
func TestConsumingAnnotationTriggerEqualsSuite(t *testing.T) {
t.Parallel()
suite.Run(t, new(ConsumingAnnotationTriggerEqualsTestSuite))
}

// TestConsumingAnnotationTriggerEqualsSuite runs the test suite for the `equals` method of all the structs that implement
// ConsumingAnnotationTriggerCopyTestSuite tests for the `copy` method of all the structs that implement
// the `ConsumingAnnotationTrigger` interface.
func TestConsumingAnnotationTriggerEqualsSuite(t *testing.T) {
type ConsumingAnnotationTriggerCopyTestSuite struct {
CopyTestSuite
}

func (s *ConsumingAnnotationTriggerCopyTestSuite) SetupTest() {
s.interfaceName = _interfaceNameConsumingAnnotationTrigger
s.initStructs = initStructsConsumingAnnotationTrigger
}
func TestConsumingAnnotationTriggerCopySuite(t *testing.T) {
t.Parallel()
suite.Run(t, new(ConsumingAnnotationTriggerTestSuite))
suite.Run(t, new(ConsumingAnnotationTriggerCopyTestSuite))
}
121 changes: 121 additions & 0 deletions annotation/copy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package annotation

import (
"fmt"
"reflect"
"strings"

"github.com/stretchr/testify/suite"
)

type CopyTestSuite struct {
suite.Suite
initStructs []any
interfaceName string
packagePath string
}

type objInfo struct {
addr string
numFields int
typ reflect.Type
}

func newObjInfo(addr string, numFields int, typ reflect.Type) objInfo {
return objInfo{
addr: addr,
numFields: numFields,
typ: typ,
}
}

// getObjInfo is a helper function that returns a map of struct and field names to their objInfo.
// The key is in the format of `struct_<struct name>` or `fld_<struct name>.<field name>`.
func getObjInfo(obj any) map[string]objInfo {
ptr := make(map[string]objInfo)

val := reflect.ValueOf(obj).Elem()
ptr[fmt.Sprintf("struct_%s", val.Type().Name())] = newObjInfo(fmt.Sprintf("%p", val.Addr().Interface()), val.NumField(), val.Type())
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
key := fmt.Sprintf("fld_%s.%s", val.Type().Name(), val.Type().Field(i).Name)
if field.Kind() == reflect.Ptr {
if !field.IsZero() {
ptr[key] = newObjInfo(fmt.Sprintf("%p", field.Interface()), field.Elem().NumField(), field.Elem().Type())
}
} else if field.Kind() == reflect.Interface && !field.IsNil() {
// %p cannot be used directly with a reflect.Value, so we need to extract the underlying value first.
interfaceValue := field.Interface()
underlyingValue := reflect.ValueOf(interfaceValue).Elem()
ptr[key] = newObjInfo(fmt.Sprintf("%p", underlyingValue.Addr().Interface()), underlyingValue.NumField(), underlyingValue.Type())
} else {
ptr[key] = newObjInfo("", 0, field.Type())
}
}
return ptr
}

// This test checks that the `Copy` method implementations perform a deep copy, i.e., copies the values but generates
// different pointer addresses for the copied struct and its fields.
// Note that here we cannot use `reflect.DeepEqual` to compare the original and copied structs because reflection
// does not work well with fields with nested struct pointers, giving incorrect results.
// Therefore, we compare the original and copied structs along with their fields for:
// - type
// - number of fields
// - pointer address (if the field is a struct and has at least one field)
func (s *CopyTestSuite) TestCopy() {
var expectedObjs, actualObjs map[string]objInfo

for _, initStruct := range s.initStructs {
var copied any
expectedObjs = getObjInfo(initStruct)

switch t := initStruct.(type) {
case ConsumingAnnotationTrigger:
copied = t.Copy()
actualObjs = getObjInfo(copied)
case Key:
copied = t.copy()
actualObjs = getObjInfo(copied)
default:
s.Failf("unknown type", "unknown type %T", t)
}

for expectedKey, expectedObj := range expectedObjs {
actualObj, ok := actualObjs[expectedKey]
s.True(ok, "key `%s` should exist in copied struct object", expectedKey)
s.Equal(expectedObj.typ, actualObj.typ, "key `%s` should have the same type after deep copying", expectedKey)
s.Equal(expectedObj.numFields, actualObj.numFields, "key `%s` should have the same number of fields after deep copying", expectedKey)

// Note that Go optimizes the memory allocation of pointers to structs. The pointer address for structs with
// no fields will be the same. E.g., consider struct `S` with no fields, then `s1 := &S{}, s2 := &S{};
// fmt.Printf("%p %p", s1, s2)` will print the same address. Therefore, we only add the pointer address of a struct
// if it has at least one field. The reason for this being that currently, the use of this helper function is used only in
// the `CopyTestSuite` to check that the `Copy` method implementations perform a deep copy, i.e., generates different
// pointer addresses for the copied struct and its fields. We may want to modify this behavior in the future, if needed.
if expectedObj.addr != "" && actualObj.addr != "" && expectedObj.numFields > 0 && actualObj.numFields > 0 {
s.NotEqual(expectedObj.addr, actualObj.addr, "key `%s` should not have the same pointer value after deep copying", expectedKey)
}
}
}
}

// Similar to EqualsTestSuite, this test serves as a sanity check to ensure that all the implemented consumer structs
// are tested in this file. The test fails if there are any structs that are found missing from the expected list.
func (s *CopyTestSuite) TestStructsChecked() {
missedStructs := structsCheckedTestHelper(s.interfaceName, s.packagePath, s.initStructs)
s.Equalf(0, len(missedStructs), "the following structs were not tested: [`%s`]", strings.Join(missedStructs, "`, `"))
}
Loading

0 comments on commit b2995c6

Please sign in to comment.