Skip to content

Commit

Permalink
feat: Generate DiffMyTargetType functions to generate a changeset bas…
Browse files Browse the repository at this point in the history
…ed on updated fields in models

Closes #11
  • Loading branch information
hlubek committed Nov 18, 2024
1 parent 23c6785 commit ec25aae
Show file tree
Hide file tree
Showing 6 changed files with 283 additions and 48 deletions.
23 changes: 23 additions & 0 deletions internal/fixtures/fixture_mytype_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion internal/fixtures/my_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type MyType struct {
// LastTime is a readable, sortable and writable pointer column
LastTime *time.Time `read_col:"my_type.last_time,sortable" write_col:"last_time"`
// LastUpdate is a readable, sortable and writable non-pointer time.Time column that is also named differently
LastUpdate time.Time `read_col:"my_type.updated_at,sortable" write_col:"updated_at"`
LastUpdate time.Time `read_col:"my_type.updated_at,sortable" write_col:"updated_at,nodiff"`
// Donuts is a readable and writable slice of non-pointer Donut structs
Donuts []Donut `read_col:"my_type.donuts" write_col:"donuts,json"`
}
Expand All @@ -34,6 +34,10 @@ type MyEmbeddedType struct {
Buzz bool
}

func (t MyEmbeddedType) Equal(other MyEmbeddedType) bool {
return t.Fizz == other.Fizz && t.Buzz == other.Buzz
}

type Donut struct {
Flavor string
Size int
Expand Down
23 changes: 23 additions & 0 deletions internal/fixtures/repository/mappings_mytype_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

225 changes: 182 additions & 43 deletions internal/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,62 +31,148 @@ func GenerateMapping(f *File, m *StructMapping, goPackage string) (err error) {
return fmt.Errorf("generating ChangeSet struct: %w", err)
}

// Empty() method for ChangeSet
var emptyBlock []Code
for _, fm := range m.FieldMappings {
if fm.WriteColDef != nil {
code := If(Id("c").Dot(firstToUpper(fm.Name)).Op("!=").Nil()).Block(
Return(Lit(false)),
)
emptyBlock = append(emptyBlock, code)
}
generateChangeSetEmpty(f, m, changeSetName)
generateChangeSetToMap(f, m, changeSetName)
generateTargetTypeToChangeSet(f, m, goPackage, changeSetName)
err = generateDiffTargetType(f, m, goPackage, changeSetName)
if err != nil {
return fmt.Errorf("generating Diff function: %w", err)
}
emptyBlock = append(emptyBlock, Return(Lit(true)))

f.Func().Params(
Id("c").Id(changeSetName),
).Id("Empty").Params().Bool().Block(
emptyBlock...,
).Line()
generateDefaultSelectJsonObject(f, m)

// toMap() method for ChangeSet
return nil
}

var toMapBlock []Code
toMapBlock = append(toMapBlock, Id("m").Op(":=").Make(Map(String()).Interface()))
func generateDiffTargetType(f *File, m *StructMapping, goPackage string, changeSetName string) error {
// Generate Diff function
var diffBlock []Code

// For each field that's writable, generate comparison code
for _, fm := range m.FieldMappings {
if fm.WriteColDef != nil {
fieldName := fm.Name

var prepareStmt *Statement
if fm.WriteColDef.ToJSON {
prepareStmt = Id("data").Op(",").Id("_").Op(":=").Qual("encoding/json", "Marshal").Call(Id("c").Dot(fieldName))
}
if fm.WriteColDef == nil {
continue
}
if fm.WriteColDef.NoDiff {
continue
}

mapAssign := Id("m").Index(Lit(fm.WriteColDef.Col)).Op("=")
if fm.WriteColDef.ToJSON {
mapAssign.Id("data")
} else if _, ok := fm.FieldType.(*types.Slice); ok {
// Do not indirect slice values
mapAssign.Id("c").Dot(fieldName)
fieldName := firstToUpper(fm.Name)
sourceField := Id("source").Dot(fieldName)
targetField := Id("target").Dot(fieldName)

var comparison *Statement
switch v := fm.FieldType.(type) {
case *types.Slice:
// For slices, use slices.Equal
comparison = Op("!").Qual("slices", "Equal").Call(sourceField, targetField)

case *types.Map:
// For maps, use maps.Equal
comparison = Op("!").Qual("maps", "Equal").Call(sourceField, targetField)

case *types.Basic:
// For basic types, direct comparison
comparison = sourceField.Op("!=").Add(targetField)

case *types.Pointer:
// For pointers, check if both nil or both non-nil and values equal
comparison = Op("!").Parens(generatePointerComparison(fm, sourceField, targetField))

case *types.Named:
// Check if type implements Equal method
if fm.HasEqual {
// Use Equal method
comparison = Op("!").Add(sourceField).Dot("Equal").Call(targetField)
} else {
mapAssign.Op("*").Id("c").Dot(fieldName)
// Direct comparison
comparison = sourceField.Op("!=").Add(targetField)
}
code := If(Id("c").Dot(fieldName).Op("!=").Nil()).Block(prepareStmt, mapAssign)
toMapBlock = append(toMapBlock, code)

default:
return fmt.Errorf("unsupported field type for diff: %T", v)
}

// If different, set target value in changeset
var assignStmt *Statement
if _, ok := fm.FieldType.(*types.Slice); ok {
// For slices, assign directly
assignStmt = Id("c").Dot(fieldName).Op("=").Add(targetField)
} else {
// For other types, take address
assignStmt = Id("c").Dot(fieldName).Op("=").Op("&").Add(targetField)
}

diffBlock = append(diffBlock, If(comparison).Block(
assignStmt,
))
}

toMapBlock = append(toMapBlock, Return(Id("m")))
diffBlock = append(diffBlock, Return(Id("c")))

f.Func().Params(
Id("c").Id(changeSetName),
).Id("toMap").Params().Map(String()).Interface().Block(
toMapBlock...,
// Add the Diff function to the file
mtp := m.MappingTypePackage
pkgName := m.MappingTypePackage[strings.LastIndex(m.MappingTypePackage, "/")+1:]
if goPackage == pkgName {
mtp = ""
}

f.Func().Id("Diff"+m.TargetName).Params(
Id("source").Qual(mtp, m.MappingTypeName),
Id("target").Qual(mtp, m.MappingTypeName),
).Params(Id("c").Id(changeSetName)).Block(
diffBlock...,
).Line()

// myRecordToChangeSet() function
return nil
}

// generatePointerComparison generates direct pointer comparison code
func generatePointerComparison(fm FieldMapping, sourceField, targetField *Statement) *Statement {
switch fm.FieldType.(*types.Pointer).Elem().(type) {
case *types.Named:
if fm.HasEqual {
// If type has Equal method, use it for comparison
return Parens(
sourceField.Clone().Op("==").Nil().Op("&&").Add(targetField.Clone().Op("==").Nil()),
).Op("||").Parens(
sourceField.Clone().Op("!=").Nil().Op("&&").Add(targetField.Clone().Op("!=").Nil().Op("&&")).
Add(sourceField.Clone().Dot("Equal").Call(Op("*").Add(targetField))),
)
}
}

// Default pointer comparison
return Parens(
sourceField.Clone().Op("==").Nil().Op("&&").Add(targetField.Clone().Op("==").Nil()),
).Op("||").Parens(
sourceField.Clone().Op("!=").Nil().Op("&&").Add(targetField.Clone().Op("!=").Nil()).Op("&&").
Op("*").Add(sourceField).Op("==").Op("*").Add(targetField),
)
}

// typeReference generates the type reference code for a given type
func typeReference(t types.Type) *Statement {
switch v := t.(type) {
case *types.Basic:
return Id(v.String())
case *types.Named:
pkg := v.Obj().Pkg()
if pkg == nil {
return Id(v.Obj().Name())
}
return Qual(pkg.Path(), v.Obj().Name())
case *types.Pointer:
return Op("*").Add(typeReference(v.Elem()))
case *types.Slice:
return Index().Add(typeReference(v.Elem()))
default:
// Add other cases as needed
return Id(t.String())
}
}

// generateTargetTypeToChangeSet generates a myRecordToChangeSet() function
func generateTargetTypeToChangeSet(f *File, m *StructMapping, goPackage string, changeSetName string) {
var toChangeSetBlock []Code

for _, fm := range m.FieldMappings {
Expand Down Expand Up @@ -132,10 +218,63 @@ func GenerateMapping(f *File, m *StructMapping, goPackage string) (err error) {
).Params(Id("c").Id(changeSetName)).Block(
toChangeSetBlock...,
).Line()
}

generateDefaultSelectJsonObject(f, m)
// generateChangeSetToMap generates a toMap() method for ChangeSet
func generateChangeSetToMap(f *File, m *StructMapping, changeSetName string) {
var toMapBlock []Code
toMapBlock = append(toMapBlock, Id("m").Op(":=").Make(Map(String()).Interface()))

return nil
for _, fm := range m.FieldMappings {
if fm.WriteColDef != nil {
fieldName := fm.Name

var prepareStmt *Statement
if fm.WriteColDef.ToJSON {
prepareStmt = Id("data").Op(",").Id("_").Op(":=").Qual("encoding/json", "Marshal").Call(Id("c").Dot(fieldName))
}

mapAssign := Id("m").Index(Lit(fm.WriteColDef.Col)).Op("=")
if fm.WriteColDef.ToJSON {
mapAssign.Id("data")
} else if _, ok := fm.FieldType.(*types.Slice); ok {
// Do not indirect slice values
mapAssign.Id("c").Dot(fieldName)
} else {
mapAssign.Op("*").Id("c").Dot(fieldName)
}
code := If(Id("c").Dot(fieldName).Op("!=").Nil()).Block(prepareStmt, mapAssign)
toMapBlock = append(toMapBlock, code)
}
}

toMapBlock = append(toMapBlock, Return(Id("m")))

f.Func().Params(
Id("c").Id(changeSetName),
).Id("toMap").Params().Map(String()).Interface().Block(
toMapBlock...,
).Line()
}

// generateChangeSetEmpty generates an Empty() method for ChangeSet
func generateChangeSetEmpty(f *File, m *StructMapping, changeSetName string) {
var emptyBlock []Code
for _, fm := range m.FieldMappings {
if fm.WriteColDef != nil {
code := If(Id("c").Dot(firstToUpper(fm.Name)).Op("!=").Nil()).Block(
Return(Lit(false)),
)
emptyBlock = append(emptyBlock, code)
}
}
emptyBlock = append(emptyBlock, Return(Lit(true)))

f.Func().Params(
Id("c").Id(changeSetName),
).Id("Empty").Params().Bool().Block(
emptyBlock...,
).Line()
}

func getBaseFilename(goFile string) string {
Expand Down
Loading

0 comments on commit ec25aae

Please sign in to comment.