Skip to content

Commit

Permalink
optimize code (#579)
Browse files Browse the repository at this point in the history
* optimize code

* optimize returns & unit test
  • Loading branch information
anqiansong authored Mar 27, 2021
1 parent bd623aa commit 8885516
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 73 deletions.
30 changes: 19 additions & 11 deletions tools/goctl/api/javagen/gencomponents.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,9 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error {
}

func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type) error {
defineStruct, ok := ty.(spec.DefineStruct)
if !ok {
return errors.New("unsupported type %s" + ty.Name())
}

for _, item := range c.requestTypes {
if item.Name() == defineStruct.Name() {
if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
return nil
}
}
defineStruct, done, err := c.checkStruct(ty)
if done {
return err
}

modelFile := util.Title(ty.Name()) + ".java"
Expand Down Expand Up @@ -181,6 +173,22 @@ func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type
return err
}

func (c *componentsContext) checkStruct(ty spec.Type) (spec.DefineStruct, bool, error) {
defineStruct, ok := ty.(spec.DefineStruct)
if !ok {
return spec.DefineStruct{}, true, errors.New("unsupported type %s" + ty.Name())
}

for _, item := range c.requestTypes {
if item.Name() == defineStruct.Name() {
if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
return spec.DefineStruct{}, true, nil
}
}
}
return defineStruct, false, nil
}

func (c *componentsContext) buildProperties(defineStruct spec.DefineStruct) (string, error) {
var builder strings.Builder
if err := c.writeType(&builder, defineStruct); err != nil {
Expand Down
31 changes: 20 additions & 11 deletions tools/goctl/api/javagen/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,9 @@ func specTypeToJava(tp spec.Type) (string, error) {
return "", err
}

switch valueType {
case "int":
return "Integer[]", nil
case "long":
return "Long[]", nil
case "float":
return "Float[]", nil
case "double":
return "Double[]", nil
case "boolean":
return "Boolean[]", nil
s := getBaseType(valueType)
if len(s) == 0 {
return s, errors.New("unsupported primitive type " + tp.Name())
}

return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(valueType)), nil
Expand All @@ -118,6 +110,23 @@ func specTypeToJava(tp spec.Type) (string, error) {
return "", errors.New("unsupported primitive type " + tp.Name())
}

func getBaseType(valueType string) string {
switch valueType {
case "int":
return "Integer[]"
case "long":
return "Long[]"
case "float":
return "Float[]"
case "double":
return "Double[]"
case "boolean":
return "Boolean[]"
default:
return ""
}
}

func primitiveType(tp string) (string, bool) {
switch tp {
case "string":
Expand Down
7 changes: 6 additions & 1 deletion tools/goctl/model/sql/command/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"path/filepath"
"testing"

"github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"

"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/util"
Expand All @@ -19,7 +21,10 @@ var (
)

func TestFromDDl(t *testing.T) {
err := fromDDl("./user.sql", t.TempDir(), cfg, true, false)
err := gen.Clean()
assert.Nil(t, err)

err = fromDDl("./user.sql", t.TempDir(), cfg, true, false)
assert.Equal(t, errNotMatched, err)

// case dir is not exists
Expand Down
44 changes: 23 additions & 21 deletions tools/goctl/model/sql/gen/findonebyfield.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
var list []string
camelTableName := table.Name.ToCamel()
for _, key := range table.UniqueCacheKey {
var inJoin, paramJoin, argJoin Join
for _, f := range key.Fields {
param := stringx.From(f.Name.ToCamel()).Untitle()
inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
paramJoin = append(paramJoin, param)
argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source())))
}
var in string
if len(inJoin) > 0 {
in = inJoin.With(", ").Source()
}

var paramJoinString string
if len(paramJoin) > 0 {
paramJoinString = paramJoin.With(",").Source()
}

var originalFieldString string
if len(argJoin) > 0 {
originalFieldString = argJoin.With(" and ").Source()
}
in, paramJoinString, originalFieldString := convertJoin(key)

output, err := t.Execute(map[string]interface{}{
"upperStartCamelObject": camelTableName,
Expand Down Expand Up @@ -125,3 +105,25 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
findOneInterfaceMethod: strings.Join(listMethod, util.NL),
}, nil
}

func convertJoin(key Key) (in, paramJoinString, originalFieldString string) {
var inJoin, paramJoin, argJoin Join
for _, f := range key.Fields {
param := stringx.From(f.Name.ToCamel()).Untitle()
inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
paramJoin = append(paramJoin, param)
argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source())))
}
if len(inJoin) > 0 {
in = inJoin.With(", ").Source()
}

if len(paramJoin) > 0 {
paramJoinString = paramJoin.With(",").Source()
}

if len(argJoin) > 0 {
originalFieldString = argJoin.With(" and ").Source()
}
return in, paramJoinString, originalFieldString
}
69 changes: 40 additions & 29 deletions tools/goctl/model/sql/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,17 @@ func Parse(ddl string) (*Table, error) {
}
}

checkDuplicateUniqueIndex(uniqueIndex, tableName, normalIndex)
return &Table{
Name: stringx.From(tableName),
PrimaryKey: primaryKey,
UniqueIndex: uniqueIndex,
NormalIndex: normalIndex,
Fields: fields,
}, nil
}

func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string, normalIndex map[string][]*Field) {
log := console.NewColorConsole()
uniqueSet := collection.NewSet()
for k, i := range uniqueIndex {
Expand Down Expand Up @@ -136,14 +147,6 @@ func Parse(ddl string) (*Table, error) {

normalIndexSet.Add(joinRet)
}

return &Table{
Name: stringx.From(tableName),
PrimaryKey: primaryKey,
UniqueIndex: uniqueIndex,
NormalIndex: normalIndex,
Fields: fields,
}, nil
}

func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) {
Expand Down Expand Up @@ -289,27 +292,9 @@ func ConvertDataType(table *model.Table) (*Table, error) {
AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
}

fieldM := make(map[string]*Field)
for _, each := range table.Columns {
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
if err != nil {
return nil, err
}
columnSeqInIndex := 0
if each.Index != nil {
columnSeqInIndex = each.Index.SeqInIndex
}

field := &Field{
Name: stringx.From(each.Name),
DataBaseType: each.DataType,
DataType: dt,
Comment: each.Comment,
SeqInIndex: columnSeqInIndex,
OrdinalPosition: each.OrdinalPosition,
}
fieldM[each.Name] = field
fieldM, err := getTableFields(table)
if err != nil {
return nil, err
}

for _, each := range fieldM {
Expand Down Expand Up @@ -379,3 +364,29 @@ func ConvertDataType(table *model.Table) (*Table, error) {

return &reply, nil
}

func getTableFields(table *model.Table) (map[string]*Field, error) {
fieldM := make(map[string]*Field)
for _, each := range table.Columns {
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
if err != nil {
return nil, err
}
columnSeqInIndex := 0
if each.Index != nil {
columnSeqInIndex = each.Index.SeqInIndex
}

field := &Field{
Name: stringx.From(each.Name),
DataBaseType: each.DataType,
DataType: dt,
Comment: each.Comment,
SeqInIndex: columnSeqInIndex,
OrdinalPosition: each.OrdinalPosition,
}
fieldM[each.Name] = field
}
return fieldM, nil
}

0 comments on commit 8885516

Please sign in to comment.